from typing import Callable, Iterable, Mapping, Union, Optional, TYPE_CHECKING
import glob
import numbers
import logging
import enum
import functools
import itertools
import re
import os
from pathlib import Path
import numpy as np
import numpy.typing as npt
import numpy.lib.mixins
import xarray as xr
import cftime
import pygetm.util.interpolate
from pygetm.constants import CENTERS, TimeVarying
from pygetm.parallel import MPI
if TYPE_CHECKING:
import pygetm.core
# Values of a coordinate's "units" attribute that mark it as a latitude.
# Matched exactly (case-sensitive) in GETMAccessor.coordinates.
LATITUDE_UNITS = (
    "degrees_north",
    "degree_north",
    "degree_N",
    "degrees_N",
    "degreeN",
    "degreesN",
)
# Values of a coordinate's "units" attribute that mark it as a longitude.
# Matched exactly (case-sensitive) in GETMAccessor.coordinates.
LONGITUDE_UNITS = (
    "degrees_east",
    "degree_east",
    "degree_E",
    "degrees_E",
    "degreeE",
    "degreesE",
)
# Values of a coordinate's "standard_name" attribute that mark it as a
# vertical (z) coordinate in GETMAccessor.coordinates.
Z_STANDARD_NAMES = (
    "height",
    "height_above_mean_sea_level",
    "depth",
    "depth_below_geoid",
)
[docs]
@xr.register_dataarray_accessor("getm")
class GETMAccessor:
    """Accessor registered as ``DataArray.getm`` that identifies which of an
    array's coordinates represent longitude, latitude, depth/height and time.
    """

    def __init__(self, xarray_obj: xr.DataArray):
        self._obj = xarray_obj

    @property
    def longitude(self) -> Optional[xr.DataArray]:
        """The longitude coordinate, or None if none was recognized."""
        return self.coordinates.get("longitude")

    @property
    def latitude(self) -> Optional[xr.DataArray]:
        """The latitude coordinate, or None if none was recognized."""
        return self.coordinates.get("latitude")

    @property
    def z(self) -> Optional[xr.DataArray]:
        """The vertical coordinate, or None if none was recognized."""
        return self.coordinates.get("z")

    @property
    def time(self) -> Optional[xr.DataArray]:
        """The time coordinate, or None if none was recognized."""
        return self.coordinates.get("time")

    @functools.cached_property
    def coordinates(self) -> Mapping[str, xr.DataArray]:
        """Mapping from role ("longitude", "latitude", "z", "time") to the
        coordinate that fulfils it. Roles without a match are absent.

        The result is computed once and cached on the accessor.
        """
        listed = self._obj.encoding.get("coordinates", "").split()
        indexes = self._obj.indexes

        def rank(coord_name: str) -> int:
            # Coordinates are visited in order of increasing rank, so that
            # higher-ranked ones overwrite lower-ranked ones with the same role
            if coord_name in listed:
                return 2
            return 1 if coord_name in indexes else 0

        def role_of(coord_name: str, coord: xr.DataArray) -> Optional[str]:
            attrs = coord.attrs
            standard_name = attrs.get("standard_name")
            if standard_name in ("latitude", "longitude"):
                return standard_name
            if (
                attrs.get("positive")
                or standard_name in Z_STANDARD_NAMES
                or attrs.get("axis") == "Z"
            ):
                return "z"
            units = attrs.get("units")
            if units in LATITUDE_UNITS:
                return "latitude"
            if units in LONGITUDE_UNITS:
                return "longitude"
            if coord.size > 0 and isinstance(coord.values.flat[0], cftime.datetime):
                return "time"
            if coord_name == "zax":
                return "z"
            return None

        found = {}
        for coord_name in sorted(self._obj.coords, key=rank):
            coord = self._obj.coords[coord_name]
            role = role_of(coord_name, coord)
            if role is not None:
                found[role] = coord
        return found
# Datasets opened so far, stored as (key, dataset) pairs where the key is the
# (path, preprocess, keyword arguments) combination they were opened with.
open_nc_files = []


def _open(path: Union[str, os.PathLike[str]], preprocess=None, **kwargs):
    """Open a dataset with :func:`xarray.open_dataset`, reusing a previously
    opened one if path, preprocess function and keyword arguments all match.

    Args:
        path: file path or URL to open
        preprocess: optional function applied to the dataset after opening
        **kwargs: keyword arguments passed to :func:`xarray.open_dataset`
    """
    request = (path, preprocess, dict(kwargs))
    known = next((ds for key, ds in open_nc_files if key == request), None)
    if known is not None:
        return known
    dataset = xr.open_dataset(path, **kwargs)
    if preprocess:
        dataset = preprocess(dataset)
    open_nc_files.append((request, dataset))
    return dataset
# Time decoder handed to xarray.open_dataset by from_nc (as "decode_times").
# If xr.coders.CFDatetimeCoder cannot be looked up (AttributeError), this is
# None and from_nc passes use_cftime=True to open_dataset directly instead.
try:
    DEFAULT_TIME_DECODER = xr.coders.CFDatetimeCoder(use_cftime=True)
except AttributeError:
    DEFAULT_TIME_DECODER = None
[docs]
def from_nc(
    paths: Union[str, os.PathLike[str], Iterable[Union[str, os.PathLike[str]]]],
    name: str,
    preprocess: Optional[Callable[[xr.Dataset], xr.Dataset]] = None,
    **kwargs,
) -> xr.DataArray:
    """Obtain a variable from one or more NetCDF files that can be used as value
    provided to :meth:`InputManager.add` and :meth:`pygetm.core.Array.set`.

    Args:
        paths: single file path, a pathname pattern containing `*` and/or `?`, or a
            sequence of file paths. If multiple paths are provided (or the pattern
            resolves to multiple valid path names), the files will be concatenated
            along their time dimension.
        name: name of the variable to retrieve from each file
        preprocess: function that transforms the :class:`xarray.Dataset` opened for
            every path provided. This can be used to modify the datasets before
            concatenation in time is attempted, for instance, to cut off time indices
            that overlap between files.
        **kwargs: additional keyword arguments to be passed to
            :func:`xarray.open_dataset`

    Returns:
        The (lazily wrapped) variable. If multiple files were found, their
        variables are ordered by their first time value and concatenated along
        the time dimension.

    Raises:
        FileNotFoundError: if a pattern matches no files, or if a single
            path-like object does not point to an existing file
    """
    if DEFAULT_TIME_DECODER is not None:
        kwargs.setdefault("decode_times", DEFAULT_TIME_DECODER)
    else:
        # xarray < 2025.01.1
        kwargs.setdefault("decode_times", True)
        kwargs["use_cftime"] = True
    kwargs["cache"] = False
    if isinstance(paths, (str, os.PathLike)):
        # This is a single item (string or PathLike), not a list of such items.
        # Check if it is a URL or a pattern
        # https://github.com/pydata/xarray/blob/40c27d19d169ccf1c469255c6c6da327f5822d01/xarray/core/utils.py#L692C17-L692C63
        if isinstance(paths, str) and not re.match(r"[a-z][a-z0-9]*(\://|\:\:)", paths):
            # Not a URL, but a file path or glob pattern. Cast to list of valid paths.
            pattern = paths
            paths = glob.glob(pattern)
            if not paths:
                # FileNotFoundError (a subclass of Exception) matches the error
                # raised below for a missing path-like object
                raise FileNotFoundError(f"No files found matching {pattern!r}")
            paths = map(Path, paths)
        else:
            # A URL or a single file path (PathLike)
            if isinstance(paths, os.PathLike) and not os.path.exists(paths):
                raise FileNotFoundError(f"File not found: {paths}")
            paths = (paths,)
    arrays = []
    for path in paths:
        ds = _open(path, preprocess, **kwargs)
        array = ds[name]
        # Note: we wrap the netCDF array ourselves, in order to support lazy operators
        # (e.g., add, multiply)
        arrays.append(wrap(array, name=f'from_nc("{path}", {name!r})'))
    if len(arrays) == 1:
        return arrays[0]
    # Multiple files: order by first time value and concatenate in time
    assert all(array.getm.time is not None for array in arrays), (
        f"Variable {name!r} must have a time coordinate in every file"
        " in order to be concatenated"
    )
    return xr.concat(
        sorted(arrays, key=lambda a: a.getm.time.values.flat[0]),
        dim=arrays[0].getm.time.dims[0],
        coords="minimal",
        combine_attrs="drop_conflicts",
    )
[docs]
def wrap(array: xr.DataArray, name: Optional[str] = None) -> xr.DataArray:
    """Return a copy of a data array whose data is a lazy :class:`Wrap` around
    the original variable. Dimensions, coordinates, attributes and encoding are
    carried over.

    Args:
        array: array to wrap
        name: name for the wrapped array. If None, the name of ``array`` is
            used, or "wrapped_array" if that is not set.
    """
    label = name
    if label is None:
        label = array.name or "wrapped_array"
    lazy = Wrap(array.variable, name=label)
    result = xr.DataArray(
        lazy,
        dims=array.dims,
        coords=array.coords,
        attrs=array.attrs,
        name=lazy.name,
    )
    result.encoding.update(array.encoding)
    return result
[docs]
class LazyArray(numpy.lib.mixins.NDArrayOperatorsMixin):
    """Base class for array-like objects with a known shape and data type whose
    values are only produced on request (through ``__array__`` or indexing).

    Arithmetic operators and NumPy ufuncs applied to a lazy array return a new
    lazy array (:class:`UFunc`) rather than evaluating immediately.
    """

    def __init__(self, shape: Iterable[int], dtype: npt.DTypeLike, name: str):
        self.shape = tuple(shape)
        self.ndim = len(self.shape)
        self.dtype = np.dtype(dtype)
        self.name = name

    @property
    def size(self) -> int:
        """Total number of elements."""
        return int(np.prod(self.shape))

    def update(self, time: cftime.datetime, numtime: np.longdouble) -> bool:
        """Update the array for the given time.

        Returns:
            Whether the value changed. This base implementation always
            returns False.
        """
        return False

    def astype(self, dtype, **kwargs) -> np.ndarray:
        """Return the array itself if it already has the requested data type,
        otherwise the evaluated array converted to that type."""
        if dtype == self.dtype:
            return self
        return self.__array__(dtype)

    def __array_function__(self, func, types, args, kwargs):
        # Functions that can be answered without evaluating the array
        if func == np.result_type:
            args = tuple(x.dtype if isinstance(x, LazyArray) else x for x in args)
            return np.result_type(*args)
        elif func == np.concatenate:
            return Concatenate(*args, **kwargs)
        elif func == np.ndim:
            return self.ndim
        elif func == np.shape:
            return self.shape
        elif func == np.size:
            return self.size
        # Any other function: evaluate all lazy arguments, then call it
        args = tuple(np.asarray(x) if isinstance(x, LazyArray) else x for x in args)
        kwargs = {
            k: np.asarray(v) if isinstance(v, LazyArray) else v
            for k, v in kwargs.items()
        }
        return func(*args, **kwargs)

    def __array_ufunc__(self, ufunc, method: str, *args, **kwargs):
        # Only plain calls without explicit output array can be made lazy
        if method != "__call__":
            return NotImplemented
        if "out" in kwargs:
            return NotImplemented
        COMPATIBLE_TYPES = (np.ndarray, numbers.Number, LazyArray, xr.Variable)
        for x in args:
            if not isinstance(x, COMPATIBLE_TYPES):
                return NotImplemented
        return UFunc(ufunc, method, *args, **kwargs)

    def __array__(self, dtype=None, copy=None) -> np.ndarray:
        raise NotImplementedError

    def __getitem__(self, slices) -> np.ndarray:
        return self.__array__()[slices]

    def is_time_varying(self) -> bool:
        """Whether the array's value can change over time.
        This base implementation always returns False."""
        return False

    def _finalize_slices(self, slices) -> tuple:
        """Expand an indexing key to a tuple with one entry per dimension.

        A key that is not a tuple is treated as the index for the first
        dimension. An Ellipsis is replaced by as many full slices as needed.
        If there is no Ellipsis, missing trailing dimensions are selected in
        full, as NumPy does.
        """
        if not isinstance(slices, tuple):
            slices = (slices,)
        for i, s in enumerate(slices):
            if s is Ellipsis:
                slices = (
                    slices[:i]
                    + (slice(None),) * (self.ndim + 1 - len(slices))
                    + slices[i + 1 :]
                )
                break
        else:
            # No Ellipsis found (a negative count yields no padding)
            slices = slices + (slice(None),) * (self.ndim - len(slices))
        assert len(slices) == self.ndim
        return slices
[docs]
class Operator(LazyArray):
    """Lazy array whose value is computed from one or more input arguments by
    the :meth:`apply` method that subclasses must implement.

    When the operator is indexed, slices for "passthrough" dimensions are
    applied to the inputs before :meth:`apply` is called; slices for all other
    dimensions are applied to the result of :meth:`apply`.
    """

    # Name used when generating a name for the result.
    # If None, the name of the class is used.
    _operator_name = None

    def __init__(
        self,
        *args,
        passthrough=(),
        dtype: npt.DTypeLike = None,
        shape: Optional[tuple[int]] = None,
        name: Optional[str] = None,
        kwslice: Iterable[str] = (),
        **kwargs,
    ):
        """
        Args:
            *args: inputs; each must be a NumPy array, a number, a
                :class:`LazyArray` or an :class:`xarray.Variable`
            passthrough: dimensions for which slices can be forwarded to the
                inputs: True (all dimensions), an iterable of dimension
                indices, or a dictionary mapping dimension indices of this
                operator to dimension indices of its input
            dtype: data type of the result. If None, it is the result type of
                the inputs.
            shape: shape of the result. If None, it is the broadcast shape of
                the inputs, and non-scalar inputs are broadcast to that shape.
            name: name of the result. If None, it is generated from the
                operator name and the names of the arguments.
            kwslice: names of keyword arguments that are broadcast and sliced
                along with the positional inputs
            **kwargs: keyword arguments passed on to :meth:`apply`
        """

        def _repr(a) -> str:
            if isinstance(a, LazyArray):
                return a.name
            elif np.ndim(a) == 0:
                return str(a)
            # Other datatype, typically xarray.Variable.
            # Do not call object's custom str/repr, as that will cause evaluation
            # (e.g. read from file) of the entire array
            return f"{type(a).__name__}(shape={np.shape(a)}, dtype={np.result_type(a)})"

        # Unpack unnamed arguments
        self.args = []
        self.arg_names = []
        for arg in args:
            assert isinstance(
                arg, (np.ndarray, numbers.Number, LazyArray, xr.Variable)
            ), f"Argument has unknown type {type(arg)}"
            # Unpack to LazyArray if possible
            if isinstance(arg, xr.Variable) and isinstance(arg._data, LazyArray):
                arg = arg._data
            self.arg_names.append(_repr(arg))
            # If this is a Wrap, unwrap
            # (the wrapping was only for ufunc support)
            if isinstance(arg, Wrap):
                arg = arg._source
            self.args.append(arg)
        # Inputs that are lazy themselves; update and is_time_varying are
        # forwarded to these
        self._lazy_args = [arg for arg in self.args if isinstance(arg, LazyArray)]
        # Store keyword arguments as-is (no unpacking)
        self.kwargs = kwargs
        kwarg_names = {k: _repr(v) for k, v in kwargs.items()}
        # Infer shape from positional arguments if not provided
        self._sliced_kwargs = []
        if shape is None:
            shapes = [np.shape(input) for input in args]
            shape = np.broadcast_shapes(*shapes)

            def broadcast(a):
                # Scalars are left alone; they broadcast implicitly
                if np.ndim(a) != 0 and np.shape(a) != shape:
                    return np.broadcast_to(a, shape)
                return a

            self.args = tuple(map(broadcast, self.args))
            for k in kwslice:
                if k in self.kwargs:
                    self.kwargs[k] = broadcast(self.kwargs[k])
                    self._sliced_kwargs.append(k)
        if dtype is None:
            dtype = np.result_type(*self.args)
        # Process dimensions for which we can passthrough slices to inputs
        # This can be True (= all dimensions), an iterable, or a dictionary mapping
        # sliced dimensions to input dimensions (if the current operator adds or
        # removes dimensions)
        if passthrough is True:
            passthrough = range(len(shape))
        if not isinstance(passthrough, dict):
            passthrough = {i: i for i in passthrough}
        self.passthrough = passthrough
        assert all([isinstance(dim, int) for dim in self.passthrough]), (
            f"Invalid passthrough: {self.passthrough}."
            " All entries should be of type int"
        )
        # Determine which arguments get slices passed through
        self._sliced_args = [np.ndim(a) != 0 for a in self.args]
        # Generate a name for the variable if not provided
        if name is None:
            operator_name = self._operator_name or self.__class__.__name__
            strargs = ", ".join(self.arg_names)
            strkwargs = "".join(f", {k}={v}" for (k, v) in kwarg_names.items())
            name = f"{operator_name}({strargs}{strkwargs})"
        super().__init__(shape, dtype, name)

    def update(self, *args) -> bool:
        """Update all lazy inputs and return whether any of them changed."""
        # A list (not a generator) is used so that every input is updated,
        # even if an earlier one already reported a change
        return any([arg.update(*args) for arg in self._lazy_args])

    def is_time_varying(self) -> bool:
        """Whether any of the lazy inputs is time-varying."""
        return any(input.is_time_varying() for input in self._lazy_args)

    def __getitem__(self, slices) -> np.ndarray:
        # preslices are applied to the inputs, postslices to the result of apply
        preslices, postslices = [], []
        for i, slc in enumerate(self._finalize_slices(slices)):
            if i in self.passthrough:
                preslices.append(slc)
                # An integer index already removes the dimension from the
                # inputs, so there is nothing left to index in the result
                if not isinstance(slc, (int, np.integer)):
                    postslices.append(slice(None))
            else:
                preslices.append(slice(None))
                postslices.append(slc)
        preslices = tuple(preslices)
        # Scalar arguments are not sliced
        args = [
            np.asarray(arg[preslices] if s else arg)
            for arg, s in zip(self.args, self._sliced_args)
        ]
        kwargs = self.kwargs
        if self._sliced_kwargs:
            kw_override = {
                k: np.asarray(kwargs[k][preslices]) for k in self._sliced_kwargs
            }
            kwargs = {**kwargs, **kw_override}
        return self.apply(*args, **kwargs)[tuple(postslices)]

    def apply(self, *args: np.ndarray, dtype=None, **kwargs) -> np.ndarray:
        """Compute the result from evaluated inputs.
        Must be implemented by subclasses."""
        raise NotImplementedError

    def __array__(self, dtype=None, copy=None) -> np.ndarray:
        # Evaluate all inputs in full and apply the operator
        args = map(np.asarray, self.args)
        kwargs = self.kwargs
        if self._sliced_kwargs:
            kw_override = {k: np.asarray(kwargs[k]) for k in self._sliced_kwargs}
            kwargs = {**kwargs, **kw_override}
        return self.apply(*args, dtype=dtype, **kwargs)
[docs]
class UnaryOperator(Operator):
    """Operator with a single input, which is made available as ``_source``
    (and its name as ``_source_name``)."""

    def __init__(self, arg, **kwargs):
        super().__init__(arg, **kwargs)
        # The base class may have unpacked or broadcast the argument,
        # so take it from there rather than using arg directly
        self._source, self._source_name = self.args[0], self.arg_names[0]
[docs]
class UFunc(Operator):
    """Lazy application of a NumPy universal function to its arguments."""

    def __init__(self, ufunc, method: str, *args, **kwargs):
        # Must be set before the base class initializer runs,
        # as that uses it to generate the name of the result
        self._operator_name = ufunc.__name__
        super().__init__(*args, passthrough=True, kwslice=("where",), **kwargs)
        # Bound ufunc method (e.g., __call__) to invoke upon evaluation
        self.ufunc = getattr(ufunc, method)

    def apply(self, *args: np.ndarray, **kwargs) -> np.ndarray:
        """Call the ufunc with evaluated arguments."""
        bound = self.ufunc
        return bound(*args, **kwargs)
[docs]
class Wrap(UnaryOperator):
    """Lazy array that returns the values of an :class:`xarray.Variable`
    unchanged."""

    def __init__(self, source: xr.Variable, name: str):
        assert isinstance(source, xr.Variable)
        super().__init__(source, name=name, passthrough=True)

    def apply(self, source: np.ndarray, dtype=None) -> np.ndarray:
        """Return the evaluated source as-is."""
        return source

    def update(self, *args) -> bool:
        """The wrapped variable never changes; always returns False."""
        return False
[docs]
class Slice(UnaryOperator):
    """Lazy array assembled from one or more slices of a source array.

    Each entry of ``_slices`` is a (source slices, target slices) pair: the
    values of the source at the source slices are stored in the result at the
    target slices. Entries are appended by the code that creates the object.
    """

    def __init__(self, source, shape: tuple[int], passthrough):
        super().__init__(source, shape=shape, passthrough=passthrough)
        # (source slices, target slices) pairs
        self._slices = []
        # If True, own slices are combined with requested slices and applied to
        # the source in one go. If False, only the requested (passed-through)
        # slices are applied to the source; own slices are applied afterwards.
        self.passthrough_own_slices = True

    def __array__(self, dtype=None, copy=None) -> np.ndarray:
        data = np.empty(self.shape, dtype or self.dtype)
        for src_slice, tgt_slice in self._slices:
            if self.passthrough_own_slices:
                data[tgt_slice] = self._source[src_slice]
            else:
                data[tgt_slice] = np.asarray(self._source, dtype=dtype)[src_slice]
        return data

    def __getitem__(self, slices) -> np.ndarray:
        slices = self._finalize_slices(slices)
        shape = []
        # Dimensions kept in the result: all but passthrough dimensions
        # that are indexed with an integer
        keep = [
            not (i in self.passthrough and isinstance(s, (int, np.integer)))
            for i, s in enumerate(slices)
        ]
        for i, (l, s, k) in enumerate(zip(self.shape, slices, keep)):
            if not k:
                # This dimension will be sliced out
                continue
            assert isinstance(s, slice), (
                f"Dimension {i} has unsupported slice type {type(s)} with value {s!r}."
                f" Passthrough: {list(self.passthrough)}"
            )
            start, stop, step = s.indices(l)
            # Dimensions that are not passed through can only be selected in full
            assert i in self.passthrough or (start == 0 and stop == l and step == 1), (
                f"invalid slice for dimension {i} with length {l}:"
                f" {start}:{stop}:{step}"
            )
            shape.append(len(range(start, stop, step)))
        data = np.empty(shape, self.dtype)
        for src_slice, tgt_slice in self._slices:
            if self.passthrough_own_slices:
                # Forward our own slices to source array
                src_slice = list(src_slice)
            else:
                # Apply our own slices only after retrieving data from source array
                # Only passed-through slices provided as argument are passed to source
                mid_slice = tuple(itertools.compress(src_slice, keep))
                src_slice = [slice(None)] * len(src_slice)
            # Forward passed-through slices provided as argument to source array
            for iout, iin in self.passthrough.items():
                if slices[iout] != slice(None):
                    assert src_slice[iin] == slice(None), "Merging slices not supported"
                    src_slice[iin] = slices[iout]
            # Drop target slices for dimensions that are sliced out
            tgt_slice = tuple(itertools.compress(tgt_slice, keep))
            values = self._source[tuple(src_slice)]
            if not self.passthrough_own_slices:
                values = values[mid_slice]
            data[tgt_slice] = values
        return data
[docs]
class Concatenate(Operator):
    """Lazy concatenation of arrays along an existing axis.

    When indexed, only the input arrays that overlap with the requested range
    along the concatenation axis are accessed.
    """

    def __init__(self, arrays, axis: int = 0, **kwargs):
        """
        Args:
            arrays: sequence of arrays to concatenate. All must have the same
                shape, except along the concatenation axis.
            axis: axis along which to concatenate. Negative values count from
                the last axis.
            **kwargs: additional keyword arguments passed to :class:`Operator`
        """
        shape = list(arrays[0].shape)
        if axis < 0:
            # Normalize, as the logic below and in __getitem__
            # compares and slices with the axis index
            axis += len(shape)
        assert 0 <= axis < len(shape), f"Invalid concatenation axis {axis}"
        for array in arrays[1:]:
            assert all(array.shape[i] == l for i, l in enumerate(shape) if i != axis)
            shape[axis] += array.shape[axis]
        self.axis = axis
        super().__init__(*arrays, shape=shape, **kwargs)

    def __array__(self, dtype=None, copy=None) -> np.ndarray:
        arrays = [np.asarray(array, dtype=dtype) for array in self.args]
        return np.concatenate(arrays, axis=self.axis)

    def __getitem__(self, slices) -> np.ndarray:
        slices = list(self._finalize_slices(slices))
        axis = self.axis
        if isinstance(slices[axis], (int, np.integer)):
            # Single index along the concatenation axis:
            # find the one array that contains it
            if slices[axis] < 0:
                slices[axis] += self.shape[axis]
            assert slices[axis] >= 0
            for array in self.args:
                if slices[axis] < array.shape[axis]:
                    return np.asarray(array[tuple(slices)])
                slices[axis] -= array.shape[axis]
            raise IndexError
        else:
            assert isinstance(
                slices[axis], slice
            ), f"Invalid slice for concatenation axis {axis}: {slices[axis]!r}"
            outaxis = axis
            for s in slices[:axis]:
                if isinstance(s, (int, np.integer)):
                    # This dimension precedes the axis over which we concatenate
                    # and it will be sliced out from the source arrays
                    # Therefore, the axis over which we concatenate decreases by 1
                    outaxis -= 1
            start, stop, step = slices[axis].indices(self.shape[axis])
            assert step > 0
            arrays = []
            # start and stop are kept relative to the beginning of the
            # array that is currently being processed
            for array in self.args:
                if start < array.shape[axis]:
                    slices[axis] = slice(start, stop, step)
                    current = np.asarray(array[tuple(slices)])
                    # Position of the first element beyond those just retrieved
                    start += current.shape[outaxis] * step
                    arrays.append(current)
                start -= array.shape[axis]
                stop -= array.shape[axis]
                if stop <= 0:
                    break
            if not arrays:
                # The requested range starts at the very end of the axis, so no
                # array contributed. np.concatenate does not accept an empty
                # list; return an empty selection from the first array instead.
                slices[axis] = slice(0, 0)
                return np.asarray(self.args[0][tuple(slices)])
            return np.concatenate(arrays, axis=outaxis)
[docs]
def limit_region(
    source: xr.DataArray,
    minlon: float,
    maxlon: float,
    minlat: float,
    maxlat: float,
    periodic_lon: bool = False,
    verbose: bool = False,
    require_2d: bool = True,
) -> xr.DataArray:
    """Lazily restrict a variable to the part of its longitude-latitude grid
    that encloses the specified region.

    Args:
        source: variable to restrict. Its longitude and latitude coordinates
            (as found by the ``getm`` accessor) must be 1D.
        minlon: minimum longitude of the region
        maxlon: maximum longitude of the region
        minlat: minimum latitude of the region
        maxlat: maximum latitude of the region
        periodic_lon: whether the region is allowed to extend beyond the
            first/last source longitude. If so, values from the other end of
            the source are added, with their longitude shifted by 360.
        verbose: whether to print diagnostic information
        require_2d: whether to extend the selection to at least two points
            in either direction

    Returns:
        Variable backed by a :class:`Slice`, with longitude and latitude
        coordinates limited to the selected range.

    Raises:
        Exception: if the longitude or latitude range is not finite or its
            minimum exceeds its maximum, or if the source longitude or
            latitude is not 1D
        AssertionError: if the source does not fully cover the region
    """
    if not np.isfinite(minlon) or not np.isfinite(maxlon):
        raise Exception(f"Longitude range {minlon} - {maxlon} is not valid")
    if not np.isfinite(minlat) or not np.isfinite(maxlat):
        raise Exception(f"Latitude range {minlat} - {maxlat} is not valid")
    if minlon > maxlon:
        raise Exception(
            f"Invalid longitude range: maximum {maxlon} must be >= minimum {minlon}."
        )
    if minlat > maxlat:
        raise Exception(
            f"Invalid latitude range: maximum {maxlat} must be >= minimum {minlat}."
        )
    # NOTE(review): the accessor returns None for coordinates it cannot find;
    # that case is not handled here and would fail on the .ndim access below
    source_lon, source_lat = source.getm.longitude, source.getm.latitude
    if source_lon.ndim != 1:
        raise Exception(f"Source longitude must be 1D but has shape {source_lon.shape}")
    if source_lat.ndim != 1:
        raise Exception(f"Source latitude must be 1D but has shape {source_lat.shape}")
    # Index range [imin, imax) of source longitudes that encloses the region.
    # searchsorted presumably requires ascending longitudes - TODO confirm
    imin = source_lon.values.searchsorted(minlon, side="right") - 1
    imax = source_lon.values.searchsorted(maxlon, side="left") + 1
    if source_lat.values[1] < source_lat.values[0]:
        # Latitude is stored in descending order: search the reversed array
        # and convert the result to indices in the original array
        jmin = (
            source_lat.size
            - source_lat.values[::-1].searchsorted(maxlat, side="left")
            - 1
        )
        jmax = (
            source_lat.size
            - source_lat.values[::-1].searchsorted(minlat, side="right")
            + 1
        )
    else:
        jmin = source_lat.values.searchsorted(minlat, side="right") - 1
        jmax = source_lat.values.searchsorted(maxlat, side="left") + 1
    if verbose:
        print(imin, imax, source_lon.values.size, jmin, jmax, source_lat.values.size)
    assert (imin >= 0 and imax <= source_lon.values.size) or periodic_lon, (
        f"Requested longitude section {minlon} - {maxlon} is not fully covered"
        f" by available range {source_lon.values[0]} - {source_lon.values[-1]}"
    )
    assert jmin >= 0 and jmax <= source_lat.values.size, (
        f"Requested latitude section {minlat} - {maxlat} is not fully covered"
        f" by available range {source_lat.values[0]} - {source_lat.values[-1]}"
    )
    # If only a single point was selected in a direction, add a neighbour
    if require_2d and jmax - jmin == 1:
        jmin, jmax = (jmin, jmax + 1) if jmin == 0 else (jmin - 1, jmax)
    if require_2d and imax - imin == 1:
        imin, imax = (imin, imax + 1) if imin == 0 else (imin - 1, imax)
    # Determine whether values from the other end of the source are needed
    # NOTE(review): add_right is also True if imax equals the source size, in
    # which case slice(imin, imax) already lies within the source. This is
    # not symmetric with add_left - confirm whether ">" was intended.
    add_left = imin < 0
    add_right = imax >= source_lon.values.size
    imin = max(imin, 0)
    imax = min(imax, source_lon.values.size)
    ilondim = source.dims.index(source_lon.dims[0])
    ilatdim = source.dims.index(source_lat.dims[0])
    shape = list(source.shape)
    shape[ilondim] = imax - imin
    shape[ilatdim] = jmax - jmin
    # Slices into the source (center_source) and into the result
    # (center_target) for the part that lies within the source range
    center_source = tuple(
        [
            {ilondim: slice(imin, imax), ilatdim: slice(jmin, jmax)}.get(i, slice(None))
            for i in range(len(shape))
        ]
    )
    center_target = [
        {ilondim: slice(0, imax - imin), ilatdim: slice(0, jmax - jmin)}.get(
            i, slice(None)
        )
        for i in range(len(shape))
    ]
    target_lon = source_lon[center_source[ilondim]]
    target_lat = source_lat[center_source[ilatdim]]
    # Whether the first and last source longitude are 360 apart. If so, one of
    # the two is skipped when adding values from the other end of the source.
    overlap = abs(source_lon.values[-1] - source_lon.values[0] - 360.0) < 1e-5
    if verbose:
        print(
            f"periodic longitude? {periodic_lon} Overlap?"
            f" {abs(source_lon.values[-1] - source_lon.values[0] - 360.0)} = {overlap}"
        )
    left_target = None
    right_target = None
    if add_left:
        # Periodic domain and we need to read beyond left boundary
        imin_left = source_lon.values.searchsorted(minlon + 360.0, side="right") - 1
        left_source = tuple(
            [
                {ilondim: slice(imin_left, -1 if overlap else None)}.get(i, s)
                for i, s in enumerate(center_source)
            ]
        )
        nleft = source_lon.values.size - imin_left + (-1 if overlap else 0)
        if verbose:
            print(f"adding {nleft} values on the left")
        shape[ilondim] += nleft
        left_target = tuple(
            [{ilondim: slice(0, nleft)}.get(i, s) for i, s in enumerate(center_target)]
        )
        # The center part shifts to the right to make room for the added values
        center_target[ilondim] = slice(nleft, nleft + imax - imin)
        target_lon = xr.concat(
            (source_lon[left_source[ilondim]] - 360.0, target_lon),
            source_lon.dims[0],
            combine_attrs="no_conflicts",
        )
    if add_right:
        # Periodic domain and we need to read beyond right boundary
        imax_right = source_lon.values.searchsorted(maxlon - 360.0, side="left") + 1
        right_source = tuple(
            [
                {ilondim: slice(1 if overlap else 0, imax_right)}.get(i, s)
                for i, s in enumerate(center_source)
            ]
        )
        nright = imax_right + (-1 if overlap else 0)
        if verbose:
            print(f"adding {nright} values on the right")
        shape[ilondim] += nright
        # The added values directly follow the center part
        right_target = tuple(
            [
                {ilondim: slice(s.stop, None)}.get(i, s)
                for i, s in enumerate(center_target)
            ]
        )
        target_lon = xr.concat(
            (target_lon, source_lon[right_source[ilondim]] + 360.0),
            source_lon.dims[0],
            combine_attrs="no_conflicts",
        )
    center_target = tuple(center_target)
    shape = tuple(shape)
    if verbose:
        print(f"final shape: {shape}")
    # Slices for all dimensions other than longitude and latitude
    # can be forwarded to the source
    lazyvar = Slice(
        _as_lazyarray(source),
        shape=shape,
        passthrough=[i for i in range(len(shape)) if i not in (ilondim, ilatdim)],
    )
    lazyvar._slices.append((center_source, center_target))
    if left_target:
        lazyvar._slices.append((left_source, left_target))
    if right_target:
        lazyvar._slices.append((right_source, right_target))
    coords = dict(source.coords.items())
    coords[source_lon.name] = target_lon
    coords[source_lat.name] = target_lat
    return xr.DataArray(
        lazyvar, dims=source.dims, coords=coords, attrs=source.attrs, name=lazyvar.name
    )
[docs]
def concatenate_slices(
    source: xr.DataArray, idim: int, slices: Iterable[slice], verbose=False
) -> xr.DataArray:
    """Lazily select several ranges along one dimension of a variable and
    join them, in the order given, along that same dimension.

    Args:
        source: variable to select from
        idim: index of the dimension along which to select
        slices: ranges to select. Each must be a :class:`slice` with explicit
            start and stop; the length of each range is taken to be
            ``stop - start``.
        verbose: whether to print the shape of the result

    Returns:
        Variable backed by a :class:`Slice`. Coordinates that use the sliced
        dimension are dropped.
    """
    # Materialize first, as the slices are iterated over multiple times
    # (a one-shot iterator would be exhausted after the first pass)
    slices = list(slices)
    assert idim < source.ndim
    assert all(isinstance(s, slice) for s in slices)
    lengths = [s.stop - s.start for s in slices]
    shape = list(source.shape)
    shape[idim] = sum(lengths)
    shape = tuple(shape)
    if verbose:
        print(f"final shape: {shape}")
    istart = 0
    final_slices = []
    for s, n in zip(slices, lengths):
        source_slice = [slice(None)] * source.ndim
        target_slice = [slice(None)] * source.ndim
        source_slice[idim] = s
        target_slice[idim] = slice(istart, istart + n)
        final_slices.append((tuple(source_slice), tuple(target_slice)))
        istart += n
    assert istart == shape[idim]
    lazyvar = Slice(
        _as_lazyarray(source),
        shape=shape,
        passthrough=[i for i in range(len(shape)) if i != idim],
    )
    lazyvar._slices.extend(final_slices)
    # Coordinates defined along the sliced dimension no longer match the result
    sliced_dim = source.dims[idim]
    coords = {
        name: c for name, c in source.coords.items() if sliced_dim not in c.dims
    }
    return xr.DataArray(
        lazyvar, dims=source.dims, coords=coords, attrs=source.attrs, name=lazyvar.name
    )
[docs]
class Transpose(UnaryOperator):
    """Lazy array with the dimensions of its source in a different order."""

    def __init__(self, a, axes: Iterable[int], **kwargs):
        """
        Args:
            a: source array
            axes: for every dimension of the result, the index of the
                corresponding dimension of the source
            **kwargs: additional keyword arguments passed to :class:`Operator`
        """
        super().__init__(a, **kwargs)
        # Materialize, as axes is used multiple times (it may be an iterator)
        self.axes = tuple(axes)
        # Inverse permutation: for every dimension of the source,
        # the index of the corresponding dimension of the result
        self.oldaxes = list(self.axes)
        for inew, iold in enumerate(self.axes):
            self.oldaxes[iold] = inew

    def __array__(self, dtype=None, copy=None) -> np.ndarray:
        return np.asarray(self._source, dtype=dtype).transpose(self.axes)

    def __getitem__(self, slices) -> np.ndarray:
        newslices = list(self._finalize_slices(slices))
        finalslices = [slice(None)] * len(newslices)
        # Iterate over the finalized slices (one per dimension), not over the
        # key as provided, which may contain an Ellipsis
        for inew, s in enumerate(newslices):
            if isinstance(s, (int, np.integer)):
                # Retrieve integer-indexed dimensions as length-1 slices to
                # preserve the number of dimensions for the transpose below.
                # These dimensions are dropped at the very end.
                if s < 0:
                    # slice(-1, 0) would be empty; use the equivalent
                    # non-negative index instead
                    s += self.shape[inew]
                newslices[inew] = slice(s, s + 1)
                finalslices[inew] = 0
        oldslices = tuple(newslices[inew] for inew in self.oldaxes)
        # Convert to a NumPy array first: the source may be an xarray.Variable,
        # whose transpose method takes dimension names rather than indices
        values = np.asarray(self._source[oldslices])
        return values.transpose(self.axes)[tuple(finalslices)]
[docs]
def transpose(
    source: xr.DataArray, axes: Optional[Iterable[int]] = None
) -> xr.DataArray:
    """Lazily change the order of the dimensions of a variable.

    Args:
        source: variable to transpose
        axes: for every dimension of the result, the index of the
            corresponding dimension of the source. If None, the order of the
            dimensions is reversed.

    Returns:
        Variable backed by a :class:`Transpose`, with the coordinates of the
        source.
    """
    if axes is None:
        axes = range(source.ndim)[::-1]
    # Materialize, as axes is used multiple times (it may be an iterator)
    axes = tuple(axes)
    dims = [source.dims[i] for i in axes]
    shape = [source.shape[i] for i in axes]
    lazyvar = Transpose(
        _as_lazyarray(source),
        axes,
        shape=shape,
        passthrough=range(source.ndim),
        dtype=source.dtype,
    )
    # Coordinates are passed on as they are: each carries its own dimension
    # names, so they need not follow the dimension order of the data.
    # (A transposed copy of multi-dimensional coordinates used to be created
    # here, but it was always replaced by the original before use.)
    coords = dict(source.coords.items())
    return xr.DataArray(
        lazyvar,
        dims=dims,
        coords=coords,
        attrs=source.attrs,
        name=lazyvar.name,
    )
[docs]
def isel(source: xr.DataArray, **indices) -> xr.DataArray:
    """Index named dimensions with integers, slice objects or integer arrays

    Args:
        source: variable to index
        **indices: index per dimension name. Integers (including NumPy integer
            scalars) remove the dimension, slices shorten it, and anything else
            is converted to an integer array and replaces the dimension by the
            broadcast shape of all such arrays. Dimensions indexed by arrays
            must be adjacent.

    Returns:
        Variable backed by a :class:`Slice`. Coordinates that use any of the
        indexed dimensions are dropped.
    """
    advanced_indices = []
    for dim in list(indices):
        assert dim in source.dims, (
            f"indexed dimension {dim} not used by source,"
            f" which has dimensions {source.dims}"
        )
        # NumPy integer scalars are not instances of int, but must be treated
        # the same: as advanced index they would become a 0-d array, which
        # cannot be combined with the single dimension name used below
        if not isinstance(indices[dim], (int, np.integer, slice)):
            advanced_indices.append(source.dims.index(dim))
            indices[dim] = xr.Variable(
                ["__newdim"], np.asarray(indices[dim], dtype=np.intp)
            )
    # Final slices per dimension
    slices = tuple([indices.get(dim, slice(None)) for dim in source.dims])
    # Determine final shape
    shape = []
    dims = []
    passthrough = {}
    advanced_added = False
    for i, (dim, slc, l) in enumerate(zip(source.dims, slices, source.shape)):
        if i not in advanced_indices:
            # Slice is integer or slice object. if integer, it will be sliced out
            # so it does not contribute to the final shape
            if isinstance(slc, slice):
                start, stop, stride = slc.indices(l)
                dims.append(dim)
                # Map the dimension of the result to that of the source
                passthrough[len(shape)] = i
                shape.append(len(range(start, stop, stride)))
        elif not advanced_added:
            # First advanced slice. Add the shape produced by the broadcast combination
            # of advanced indices
            assert max(advanced_indices) - min(advanced_indices) + 1 == len(
                advanced_indices
            ), "advanced indices must be side-by-side for now"
            advanced_shapes = [indices[source.dims[i]].shape for i in advanced_indices]
            for length in np.broadcast_shapes(*advanced_shapes):
                dims.append(f"dim_{len(shape)}")
                shape.append(length)
            advanced_added = True
    lazyvar = Slice(_as_lazyarray(source), shape=shape, passthrough=passthrough)
    lazyvar._slices.append((slices, (slice(None),) * len(shape)))
    coords = {}
    for name, c in source.coords.items():
        if all(dim not in c.dims for dim in indices):
            coords[name] = c
    return xr.DataArray(
        lazyvar, dims=dims, coords=coords, attrs=source.attrs, name=lazyvar.name
    )
[docs]
def horizontal_interpolation(
    source: xr.DataArray,
    x: xr.DataArray,
    y: xr.DataArray,
    *,
    mask: Optional[npt.ArrayLike] = None,
    xp: Union[str, xr.DataArray, None] = None,
    yp: Union[str, xr.DataArray, None] = None,
) -> xr.DataArray:
    """Two-dimensional linear interpolation

    For target coordinates that fall into cells with one or more masked corners,
    nearest-neighbor interpolation is used to find the nearest unmasked point.

    Args:
        source: source variable
        x: x-coordinates at which to evaluate the interpolated values
        y: y-coordinates at which to evaluate the interpolated values
        mask: mask for the source variable. If None, all points of the source
            variable are assumed to be valid.
        xp: x-coordinates of the source variable, or the name of the
            coordinate that holds them.
            If None, the longitudes of the source variable are used.
        yp: y-coordinates of the source variable, or the name of the
            coordinate that holds them.
            If None, the latitudes of the source variable are used.

    Returns:
        Variable backed by a :class:`HorizontalInterpolation`, in which the
        two horizontal dimensions of the source are replaced by the
        dimensions of the (broadcast) target coordinates.
    """
    # Resolve source coordinates that were given by name or not at all
    if isinstance(xp, str):
        xp = source.coords[xp]
    elif xp is None:
        xp = source.getm.longitude
        if xp is None:
            raise Exception(
                f"Variable {source.name} does not have a valid longitude coordinate."
            )
    if isinstance(yp, str):
        yp = source.coords[yp]
    elif yp is None:
        yp = source.getm.latitude
        if yp is None:
            raise Exception(
                f"Variable {source.name} does not have a valid latitude coordinate."
            )
    assert xp.ndim == 1
    assert yp.ndim == 1
    assert np.isfinite(x).all(), f"Some target x-coordinates are non-finite: {x}"
    assert np.isfinite(y).all(), f"Some target y-coordinates are non-finite: {y}"
    x, y = np.broadcast_arrays(x, y)

    # Position of the horizontal dimensions within the source
    ixdim = source.dims.index(xp.dims[0])
    iydim = source.dims.index(yp.dims[0])
    assert abs(ixdim - iydim) == 1, "x and y dimensions must be distinct and adjacent"
    ifirst = min(ixdim, iydim)
    ilast = max(ixdim, iydim)
    ntrailing = source.ndim - ilast - 1

    # Names of the dimensions that replace the horizontal dimensions of the source
    dimensions = {0: (), 1: (xp.dims[0],), 2: (yp.dims[0], xp.dims[-1])}[x.ndim]
    shape = source.shape[:ifirst] + x.shape + source.shape[ilast + 1 :]
    interpolator = pygetm.util.interpolate.Linear2DGridInterpolator
    if ixdim > iydim:
        # Dimension order: y first, then x
        ip = interpolator(y, x, yp, xp, ndim_trailing=ntrailing, mask=mask)
    else:
        # Dimension order: x first, then y
        dimensions = dimensions[::-1]
        ip = interpolator(x, y, xp, yp, ndim_trailing=ntrailing, mask=mask)

    # Coordinates for the interpolated variable. If the target coordinates are
    # multi-dimensional, they cannot bear the name of one of their dimensions.
    x_name, y_name = xp.name, yp.name
    if x_name in dimensions and x.ndim > 1:
        x_name = x_name + "_"
    if y_name in dimensions and y.ndim > 1:
        y_name = y_name + "_"
    replaced = {xp.name, yp.name}
    coords = {k: v for k, v in source.coords.items() if k not in replaced}
    coords[x_name] = xr.DataArray(x, dims=dimensions, name=x_name, attrs=xp.attrs)
    coords[y_name] = xr.DataArray(y, dims=dimensions, name=y_name, attrs=yp.attrs)

    # Dimensions of the interpolated variable
    dims = source.dims[:ifirst] + dimensions + source.dims[ilast + 1 :]
    lazyvar = HorizontalInterpolation(
        ip, _as_lazyarray(source), shape, ifirst, ntrailing
    )
    return xr.DataArray(
        lazyvar, dims=dims, coords=coords, attrs=source.attrs, name=lazyvar.name
    )
[docs]
class HorizontalInterpolation(UnaryOperator):
    """Lazy array that applies a two-dimensional interpolator to its source."""

    def __init__(
        self,
        ip: pygetm.util.interpolate.Linear2DGridInterpolator,
        source: LazyArray,
        shape: Iterable[int],
        npre: int,
        npost: int,
        **kwargs,
    ):
        """
        Args:
            ip: interpolator; called with the values of the source
            source: array to interpolate
            shape: shape of the interpolated array
            npre: number of dimensions that precede the interpolated dimensions
            npost: number of dimensions that follow the interpolated dimensions
            **kwargs: additional keyword arguments passed to :class:`Operator`
        """
        super().__init__(source, shape=shape, dtype=float, **kwargs)
        self._ip = ip
        self.npre = npre
        self.npost = npost

    def __array__(self, dtype=None, copy=None) -> np.ndarray:
        return self._ip(np.asarray(self._source))

    def __getitem__(self, slices) -> np.ndarray:
        # src_slice is applied to the source before interpolation,
        # tgt_slice to the result of the interpolation.
        # The Ellipsis stands for the interpolated dimensions.
        src_slice, tgt_slice = [Ellipsis], [Ellipsis]
        ntrailing_dim_removed = 0
        for i, s in enumerate(self._finalize_slices(slices)):
            if i < self.npre:
                # prefixed dimension
                src_slice.insert(i, s)
            elif i >= self.ndim - self.npost:
                # trailing dimension
                if isinstance(s, (int, np.integer)):
                    # The source loses this dimension. A dimension of length 1
                    # is appended below to make up for that; index 0 of the
                    # result then removes it again.
                    ntrailing_dim_removed += 1
                    tgt_slice.append(0)
                src_slice.append(s)
            else:
                # Interpolated dimension: can only be selected in full
                assert (
                    isinstance(s, slice)
                    and (s.start is None or s.start == 0)
                    and (s.stop is None or s.stop == self.shape[i])
                    and s.step is None
                ), repr(s)
        source = np.asarray(self._source[tuple(src_slice)])
        # Restore the number of trailing dimensions
        # (presumably the interpolator expects a fixed number - TODO confirm)
        source.shape = source.shape + (1,) * ntrailing_dim_removed
        result = self._ip(source)
        return result[tuple(tgt_slice)]
[docs]
def vertical_interpolation(
    source: xr.DataArray, target_z: npt.ArrayLike, itargetdim: int = 0
) -> xr.DataArray:
    """Lazily interpolate a variable to different vertical coordinates.

    Args:
        source: variable to interpolate. Its vertical coordinate (as found by
            the ``getm`` accessor) must be 1D.
        target_z: vertical coordinates to interpolate to
        itargetdim: index of the vertical dimension within ``target_z``

    Returns:
        Variable backed by a :class:`VerticalInterpolation`. The vertical
        coordinate of the source is replaced by one holding ``target_z``,
        named after the original with an underscore appended.

    Raises:
        Exception: if the source does not have a vertical coordinate
    """
    source_z = source.getm.z
    target_z = np.asarray(target_z)
    if source_z is None:
        raise Exception(
            f"Variable {source.name} does not have a valid depth coordinate."
        )
    assert source_z.ndim == 1
    izdim = source.dims.index(source_z.dims[0])
    # assert source.ndim - izdim == target_z.ndim
    # assert source.shape[izdim + 1:izdim + 3] == target_z.shape[1:], f'{source.shape[izdim + 1:izdim + 3]} vs {target_z.shape[1:]}'
    # Match every dimension of target_z to a dimension of the source: the
    # vertical dimension by position, all others by looking for the next
    # source dimension with the same length
    target2sourcedim = {}
    isourcedim = 0
    for i, l in enumerate(target_z.shape):
        if i == itargetdim:
            isourcedim = izdim
        else:
            while (
                isourcedim != izdim
                and isourcedim < source.ndim
                and l != source.shape[isourcedim]
            ):
                isourcedim += 1
            assert isourcedim != izdim, (
                f"Dimension with length {l} should precede depth dimension {izdim}"
                f" in {source.name}, which has shape {source.shape}"
            )
            assert isourcedim < source.ndim, (
                f"Dimension with length {l} expected after depth dimension {izdim}"
                f" in {source.name}, which has shape {source.shape}"
            )
        target2sourcedim[i] = isourcedim
        isourcedim += 1
    coords = {}
    for n, c in source.coords.items():
        if n == source_z.name:
            # Replace the vertical coordinate by the target coordinates,
            # using the names of the matching source dimensions
            coords[n + "_"] = (
                [source.dims[target2sourcedim[i]] for i in range(target_z.ndim)],
                target_z,
            )
        else:
            coords[n] = c
    lazyvar = VerticalInterpolation(
        _as_lazyarray(source), target_z, izdim, source_z.values, itargetdim
    )
    return xr.DataArray(
        lazyvar, dims=source.dims, coords=coords, attrs=source.attrs, name=lazyvar.name
    )
[docs]
class VerticalInterpolation(UnaryOperator):
    """Lazy array that interpolates its source along the vertical dimension."""

    def __init__(
        self,
        source: LazyArray,
        z: np.ndarray,
        izdim: int,
        source_z: np.ndarray,
        axis: int = 0,
        **kwargs,
    ):
        """
        Args:
            source: array to interpolate
            z: vertical coordinates to interpolate to
            izdim: index of the vertical dimension of the source
            source_z: vertical coordinates of the source
            axis: index of the vertical dimension of ``z``
            **kwargs: additional keyword arguments passed to :class:`Operator`
        """
        self.izdim = izdim
        # Same shape as the source, except along the vertical dimension
        target_shape = list(source.shape)
        target_shape[izdim] = z.shape[axis]
        # Slices of all dimensions but the vertical can be forwarded to the source
        forwarded = [idim for idim in range(source.ndim) if idim != izdim]
        super().__init__(
            source, shape=target_shape, dtype=float, passthrough=forwarded, **kwargs
        )
        self.z = z
        self.axis = axis
        # If no source coordinate is negative, flip the sign of all of them
        all_nonnegative = (source_z >= 0.0).all()
        self.source_z = -source_z if all_nonnegative else source_z

    def apply(self, source: np.ndarray, dtype=None) -> np.ndarray:
        """Interpolate evaluated source values to the target coordinates."""
        interpolate = pygetm.util.interpolate.interp_1d
        return interpolate(self.z, self.source_z, source, axis=self.axis)
[docs]
def temporal_interpolation(
    source: xr.DataArray,
    climatology: bool = False,
    comm: MPI.Comm = MPI.COMM_SELF,
    logger: Optional[logging.Logger] = None,
) -> xr.DataArray:
    """Wrap ``source`` in a :class:`TemporalInterpolation` lazy array.

    Args:
        source: time-dependent array; its time coordinate is taken from
            ``source.getm.time``
        climatology: passed on to :class:`TemporalInterpolation`
        comm: MPI communicator passed on to :class:`TemporalInterpolation`
        logger: optional logger passed on to :class:`TemporalInterpolation`

    Returns:
        An array without the time dimension of ``source``. It keeps all
        coordinates of ``source`` that do not depend on that dimension and
        carries the interpolator's own time coordinate under the name of the
        original time dimension.
    """
    time_coord = source.getm.time
    assert time_coord is not None, "No time coordinate found"

    time_dim = time_coord.dims[0]
    itimedim = source.dims.index(time_dim)
    interpolator = TemporalInterpolation(
        _as_lazyarray(source),
        itimedim,
        time_coord.values,
        climatology,
        comm=comm,
        logger=logger,
    )

    # All dimensions of the source, minus the time dimension
    remaining_dims = list(source.dims)
    del remaining_dims[interpolator._itimedim]

    # Time coordinate of the interpolator first, then every source
    # coordinate that does not vary along the time dimension
    coords = {time_dim: interpolator._timecoord}
    coords.update(
        (name, coord)
        for name, coord in source.coords.items()
        if time_dim not in coord.dims
    )

    return xr.DataArray(
        interpolator,
        dims=remaining_dims,
        coords=coords,
        attrs=source.attrs,
        name=interpolator.name,
    )
def _as_lazyarray(array: xr.DataArray) -> LazyArray:
    """Return the :class:`LazyArray` behind ``array``, wrapping it if needed.

    If the data of the underlying variable already is a :class:`LazyArray`,
    it is returned as is. Otherwise the variable is wrapped in :class:`Wrap`,
    named after the array (or after the default object representation of its
    data if the array is unnamed), with the source recorded in
    ``array.encoding`` appended when available.
    """
    variable = array.variable
    data = variable._data
    if isinstance(data, LazyArray):
        return data

    name = object.__repr__(data) if array.name is None else array.name
    encoding = array.encoding
    if "source" in encoding:
        name = name + (" from " + encoding["source"])
    return Wrap(variable, name=name)
[docs]
class Cache(UnaryOperator):
    """Read-ahead cache along one dimension of a lazy array.

    When a single index ``i`` along dimension ``idim`` is requested, the
    window ``i:i + n`` along that dimension is read from the source in one go
    and kept in memory. Later requests for an index inside that window are
    served from memory.
    """

    def __init__(self, source: LazyArray, idim: int, n: int):
        """
        Args:
            source: array to read from
            idim: index of the dimension along which to cache
            n: number of consecutive indices along ``idim`` to read at once
        """
        super().__init__(source)
        self._cache: Optional[np.ndarray] = None
        self.istart = 0
        self.n = n
        self.idim = idim

    def __getitem__(self, slices) -> np.ndarray:
        """Return the requested part of the source, using the cache when a
        single integer index along the cached dimension is requested.

        Requests that do not use an integer index along ``idim`` bypass the
        cache and are forwarded to the source unchanged.
        """
        slices = list(self._finalize_slices(slices))
        i = slices[self.idim]
        if not isinstance(i, (int, np.integer)):
            return self._source[tuple(slices)]

        if self._cache is None or i < self.istart or i >= self.istart + self.n:
            # Cache miss: read the window over the FULL extent of all other
            # dimensions. The caller's slices for those dimensions are applied
            # once, to the cached block below. Reading with those slices
            # already applied (and then applying them again) would return the
            # wrong subset, and would tie the cache to one specific request.
            window = [slice(None)] * len(slices)
            window[self.idim] = slice(i, i + self.n)
            block = self._source[tuple(window)]

            # Only update the cache state once the read has succeeded
            self._cache = block
            self.istart = i

        # Translate the index along the cached dimension to the cached window
        slices[self.idim] = i - self.istart
        return self._cache[tuple(slices)]

    def __array__(self, dtype=None, copy=None) -> np.ndarray:
        """Return the complete source as array; this does not use the cache."""
        return np.asarray(self._source, dtype=dtype)
[docs]
class TemporalInterpolation(UnaryOperator):
    """Lazy array that linearly interpolates its source in time.

    The array has the shape of the source minus the time dimension. Its
    current value is held in memory and is brought up to date by calling
    :meth:`update` with the time of interest.
    """

    # NOTE(review): __init__ also assigns times, _timecoord, _cache and
    # _use_cache, which are not listed here. That presumably relies on a base
    # class providing a __dict__ - confirm.
    __slots__ = (
        "_current",
        "_itimedim",
        "_numnow",
        "_numnext",
        "_slope",
        "_inext",
        "_next",
        "_slices",
        "climatology",
        "_year",
        "_timevalues",
    )

    #: Maximum size (in MB) of a climatology that may be cached in memory.
    #: With the default of 0, climatologies are never cached.
    MAX_CACHE_SIZE = 0

    #: Number of time records to read from the source at once.
    #: Values above 1 wrap the source in a :class:`Cache`.
    NTIME_CACHE = 1

    def __init__(
        self,
        source: LazyArray,
        itimedim: int,
        times: npt.ArrayLike,
        climatology: bool,
        dtype: npt.DTypeLike = float,
        comm: MPI.Comm = MPI.COMM_SELF,
        logger: Optional[logging.Logger] = None,
        **kwargs,
    ):
        """
        Args:
            source: time-dependent array to interpolate
            itimedim: index of the time dimension of ``source``
            times: time coordinate of ``source``; its elements must provide
                ``year``, ``calendar``, ``replace`` and ``toordinal``
            climatology: whether to treat the time series as a single year
                that repeats for every simulated year
            dtype: data type of the interpolated values
            comm: communicator used to make all subdomains take the same
                caching decision
            logger: optional logger that is told about the caching decision
            **kwargs: forwarded to the base class initializer

        Raises:
            Exception: if the time dimension has length 0 or 1, or if
                ``climatology`` is set while ``times`` spans multiple years
        """
        if self.NTIME_CACHE > 1:
            source = Cache(source, itimedim, self.NTIME_CACHE)
        shape = list(source.shape)
        self._itimedim = itimedim
        ntime = shape[self._itimedim]
        if ntime <= 1:
            raise Exception(
                f"Cannot interpolate {source.name} in time because"
                f" its time dimension has length {ntime}."
            )
        shape.pop(self._itimedim)
        super().__init__(source, shape=shape, dtype=dtype, **kwargs)

        self._current = np.empty(shape, dtype=self.dtype)

        self.times = np.asarray(times)
        self._timecoord = xr.DataArray(self.times[0])
        self._timevalues = self._timecoord.values

        # State of the interpolation window; it is (re)initialized by _start
        self._numnow = None
        self._numnext = 0.0
        self._slope = 0.0
        self._inext = -1
        self._next = 0.0
        self._slices: list[Union[int, slice]] = [slice(None)] * source.ndim

        self.climatology = climatology

        # Initially the year of the data. Once interpolation has started,
        # this is the (simulation) year assigned to the record at _inext.
        self._year = self.times[0].year

        self._cache = {}
        self._use_cache = False

        if climatology and not all(time.year == self._year for time in self.times):
            raise Exception(
                f"{self._source_name} cannot be used as climatology because"
                " it spans more than one calendar year"
                f" ({self.times[0]} - {self.times[-1]})."
            )

        if climatology and self.MAX_CACHE_SIZE > 0:
            memory = np.asarray(source.size * (self.dtype.itemsize / 1024 / 1024))

            # Ensure all subdomains take same caching decision
            # In part because performance generally does not improve until all
            # subdomains cache, but foremost because underlying LazyArrays may
            # use collective MPI operations in which all subdomains MUST participate
            comm.Allreduce(MPI.IN_PLACE, memory, MPI.MAX)

            self._use_cache = memory < self.MAX_CACHE_SIZE
            if logger:
                if self._use_cache:
                    logger.info(
                        f"{source.name} is a climatology that will be cached in memory."
                        f" This requires {memory:.1f} MB, which is less than the"
                        f" maximum allowed cache size of {self.MAX_CACHE_SIZE} MB"
                    )
                else:
                    logger.info(
                        f"Not caching {source.name} because this would require"
                        f" {memory:.1f} MB, which is more than the maximum allowed"
                        f" cache size of {self.MAX_CACHE_SIZE} MB"
                    )

    def __array__(self, dtype=None, copy=None) -> np.ndarray:
        """Return the values set by the last call to :meth:`update`.

        ``dtype`` and ``copy`` are accepted but not used.
        """
        return self._current

    def __getitem__(self, slices) -> np.ndarray:
        """Return part of the values set by the last call to :meth:`update`."""
        return self._current[slices]

    def is_time_varying(self) -> bool:
        """Return ``True``: values change when :meth:`update` is called."""
        return True

    def update(
        self, time: cftime.datetime, numtime: Optional[np.longdouble] = None
    ) -> bool:
        """Interpolate to the specified time.

        Args:
            time: time to interpolate to
            numtime: numerical equivalent of ``time``; if not provided, it is
                obtained with ``time.toordinal(fractional=True)``

        Returns:
            ``False`` if the time equals that of the previous call, in which
            case nothing has changed; ``True`` otherwise.
        """
        if numtime is None:
            numtime = time.toordinal(fractional=True)

        if self._numnow is None:
            # First call to update
            self._start(time)
        elif numtime <= self._numnow:
            # Subsequent call to update, but time has not increased
            # If equal to previous time, we are done. If smaller, rewind
            if numtime == self._numnow:
                return False
            self._start(time)

        # Advance until the right bound of the window is at or beyond the
        # requested time
        while self._numnext < numtime:
            self._move_to_next(time)

        # Do linear interpolation
        if numtime == self._numnext:
            self._current[...] = self._next
        else:
            np.multiply(self._slope, numtime - self._numnext, out=self._current)
            self._current += self._next

        # Save current time
        self._numnow = numtime
        self._timevalues[...] = time
        return True

    def _start(self, time: cftime.datetime):
        """Load the time window that encompasses ``time``.

        This is used for the first call to :meth:`update` and again whenever
        time moves backward.

        Raises:
            Exception: if ``time`` cannot be converted to the calendar of the
                time series, or if (for non-climatological data) the time
                series only starts after ``time``
        """
        if time.calendar != self.times[0].calendar:
            try:
                time = time.change_calendar(self.times[0].calendar)
            except ValueError as e:
                raise Exception(
                    f"Calendar {self.times[0].calendar} used by {self._source_name}"
                    f" is incompatible with simulation calendar {time.calendar}."
                ) from e

        # Find the time index (_inext) just before the left bound of the window we need.
        # We will subsequently call _move_to_next twice to load the actual window.
        # If a time in the series exactly matches the requested time, this will be the
        # right bound of the time window (hence side="left" below), except if that
        # is the very first time in the time series. Then it will be the left bound.
        if self.climatology:
            # Map the requested time onto the year of the data. That year must
            # be taken from the time series itself and not from _year: _year
            # is overwritten with the simulation year below and in
            # _move_to_next, so it no longer is the data year when rewinding.
            clim_time = time.replace(year=self.times[0].year)
            self._inext = self.times.searchsorted(clim_time, side="left") - 2
            self._year = time.year
            if self._inext < -1:
                # The window starts with a record from the preceding year
                self._inext += self.times.size
                self._year -= 1
        else:
            # Make sure the time series does not start after the requested time.
            if time < self.times[0]:
                raise Exception(
                    f"Cannot interpolate {self._source_name} to value at {time},"
                    f" because time series starts only at {self.times[0]}."
                )
            self._inext = max(-1, self.times.searchsorted(time, side="left") - 2)

        # Load left and right bound of the window encompassing the current time.
        # After that, all information for linear interpolation (_next, _slope)
        # is available.
        self._move_to_next(time)
        self._move_to_next(time)

    def _move_to_next(self, time: cftime.datetime):
        """Shift the time window one record forward.

        The current right bound becomes the left bound, and the next record
        is read (or taken from the cache) to serve as the new right bound.

        Args:
            time: time being interpolated to; only used in the error message

        Raises:
            Exception: if the end of a non-climatological time series has
                been reached
        """
        # Move to next record
        self._inext += 1
        if self._inext == self.times.size:
            if self.climatology:
                # Wrap around to the first record, now in the following year
                self._inext = 0
                self._year += 1
            else:
                raise Exception(
                    f"Cannot interpolate {self._source_name} to value at {time}"
                    f" because end of time series was reached ({self.times[-1]})."
                )

        old = self._next
        numold = self._numnext

        if self._inext in self._cache:
            self._next = self._cache[self._inext]
        else:
            self._slices[self._itimedim] = self._inext
            self._next = np.asarray(self._source[tuple(self._slices)], dtype=self.dtype)
            if self._use_cache:
                self._cache[self._inext] = self._next

        next_time = self.times[self._inext]
        if self.climatology:
            next_time = next_time.replace(year=self._year)
        self._numnext = next_time.toordinal(fractional=True)

        # Rate of change between the left and right bound of the window.
        # The value from the first of the two calls made by _start is not
        # meaningful (it is based on stale values for the left bound), but it
        # is replaced by the second call.
        self._slope = (self._next - old) / (self._numnext - numold)
[docs]
def slicespec2string(s: Union[tuple, slice, int]) -> str:
    """Convert an index specification into the text used between ``[]``.

    Slices become ``start:stop`` or ``start:stop:step`` (the step is left out
    if it is ``None`` or 1, and missing bounds are left empty), tuples become
    the comma-separated conversion of their items, and anything else is
    represented by its ``repr``.
    """
    if isinstance(s, tuple):
        return ",".join(map(slicespec2string, s))
    if not isinstance(s, slice):
        return f"{s!r}"

    parts = [
        "" if s.start is None else f"{s.start}",
        "" if s.stop is None else f"{s.stop}",
    ]
    if s.step not in (None, 1):
        parts.append(f"{s.step}")
    return ":".join(parts)
[docs]
def debug_nc_reads(logger: Optional[logging.Logger] = None):
    """Make :mod:`xarray` log every read from a NetCDF file.

    This replaces ``NetCDF4ArrayWrapper`` in :mod:`xarray.backends.netCDF4_`
    by a subclass that writes a debug message for each read before
    performing it.

    Args:
        logger: logger to write to. If not provided, the ``pygetm.input``
            logger is used, with its level set to ``DEBUG``.
    """
    import xarray.backends.netCDF4_ as nc4_backend

    if logger is None:
        logger = logging.getLogger("pygetm.input")
        logger.setLevel(logging.DEBUG)

    class NetCDF4ArrayWrapper2(nc4_backend.NetCDF4ArrayWrapper):
        __slots__ = ()

        def _getitem(self, key):
            # Report which part of which variable is read from which file
            message = (
                f"Reading {self.variable_name}[{slicespec2string(key)}]"
                f" from {self.datastore._filename}"
            )
            logger.debug(message)
            return super()._getitem(key)

    # From here on, xarray looks up the logging subclass under the original name
    nc4_backend.NetCDF4ArrayWrapper = NetCDF4ArrayWrapper2
[docs]
class OnGrid(enum.Enum):
    """Extent to which the grid of spatially explicit data matches another
    grid (presumably that of the model - confirm against callers).

    The members indicate which interpolation steps the data still require.
    """

    #: Grids do not match. Spatially explicit data will require horizontal
    #: and - if vertically resolved - vertical interpolation.
    NONE = enum.auto()
    #: Horizontal grid matches, but vertical does not.
    #: Vertically resolved data will require vertical interpolation.
    HORIZONTAL = enum.auto()
    #: Horizontal and vertical grids match
    ALL = enum.auto()