Source code for pygetm.input

from typing import Callable, Iterable, Mapping, Union, Optional, TYPE_CHECKING
import glob
import numbers
import logging
import enum
import functools
import itertools
import re
import os
from pathlib import Path

import numpy as np
import numpy.typing as npt
import numpy.lib.mixins
import xarray as xr
import cftime

import pygetm.util.interpolate
from pygetm.constants import CENTERS, TimeVarying
from pygetm.parallel import MPI

if TYPE_CHECKING:
    import pygetm.core

# Values of a coordinate's "units" attribute that mark it as a latitude.
LATITUDE_UNITS = (
    "degrees_north",
    "degree_north",
    "degree_N",
    "degrees_N",
    "degreeN",
    "degreesN",
)

# Values of a coordinate's "units" attribute that mark it as a longitude.
LONGITUDE_UNITS = (
    "degrees_east",
    "degree_east",
    "degree_E",
    "degrees_E",
    "degreeE",
    "degreesE",
)

# Values of a coordinate's "standard_name" attribute that mark it as a
# vertical (z) coordinate.
Z_STANDARD_NAMES = (
    "height",
    "height_above_mean_sea_level",
    "depth",
    "depth_below_geoid",
)


@xr.register_dataarray_accessor("getm")
class GETMAccessor:
    """``getm`` accessor for :class:`xarray.DataArray` objects.

    It identifies which coordinates of the array represent longitude,
    latitude, the vertical (``z``) and time, based on coordinate attributes,
    coordinate values and (as a last resort) the coordinate name.
    """

    def __init__(self, xarray_obj: xr.DataArray):
        self._obj = xarray_obj

    @property
    def longitude(self) -> Optional[xr.DataArray]:
        """Longitude coordinate, or None if none was recognized."""
        return self.coordinates.get("longitude")

    @property
    def latitude(self) -> Optional[xr.DataArray]:
        """Latitude coordinate, or None if none was recognized."""
        return self.coordinates.get("latitude")

    @property
    def z(self) -> Optional[xr.DataArray]:
        """Vertical coordinate, or None if none was recognized."""
        return self.coordinates.get("z")

    @property
    def time(self) -> Optional[xr.DataArray]:
        """Time coordinate, or None if none was recognized."""
        return self.coordinates.get("time")

    @staticmethod
    def _role(name: str, coord: xr.DataArray) -> Optional[str]:
        """Return the role ("longitude", "latitude", "z" or "time") of a
        coordinate, or None if it is not recognized.

        Tests are ordered from cheapest/most explicit to least; the coordinate
        values are only inspected if none of the attribute tests matched.
        """
        attrs = coord.attrs
        units = attrs.get("units")
        standard_name = attrs.get("standard_name")
        if standard_name in ("latitude", "longitude"):
            return standard_name
        if (
            attrs.get("positive")
            or standard_name in Z_STANDARD_NAMES
            or attrs.get("axis") == "Z"
        ):
            return "z"
        if units in LATITUDE_UNITS:
            return "latitude"
        if units in LONGITUDE_UNITS:
            return "longitude"
        if coord.size > 0 and isinstance(coord.values.flat[0], cftime.datetime):
            return "time"
        if name == "zax":
            return "z"
        return None

    @functools.cached_property
    def coordinates(self) -> Mapping[str, xr.DataArray]:
        """Mapping from role ("longitude", "latitude", "z", "time") to the
        coordinate that fulfils it.

        If several coordinates qualify for the same role, coordinates listed
        in the "coordinates" entry of the array's encoding win over indexed
        coordinates, which in turn win over all others.
        """
        obj = self._obj
        prescribed = obj.encoding.get("coordinates", "").split()

        def priority(name: str) -> int:
            if name in prescribed:
                return 2
            return 1 if name in obj.indexes else 0

        # Process in order of increasing priority, so that higher-priority
        # coordinates overwrite lower-priority ones with the same role.
        found = {}
        for name in sorted(obj.coords, key=priority):
            coord = obj.coords[name]
            role = self._role(name, coord)
            if role is not None:
                found[role] = coord
        return found
# Datasets opened so far, as (key, dataset) pairs. A list rather than a dict
# because the key contains a dict of keyword arguments and is not hashable.
open_nc_files = []


def _open(path: Union[str, os.PathLike[str]], preprocess=None, **kwargs):
    """Open a dataset with :func:`xarray.open_dataset`, reusing a previously
    opened dataset if the path, preprocess function and keyword arguments
    all match.

    Args:
        path: path of the file to open
        preprocess: optional function applied to the dataset after opening
        **kwargs: keyword arguments passed to :func:`xarray.open_dataset`
    """
    key = (path, preprocess, kwargs.copy())
    matches = [dataset for known, dataset in open_nc_files if known == key]
    if matches:
        return matches[0]
    dataset = xr.open_dataset(path, **kwargs)
    if preprocess:
        dataset = preprocess(dataset)
    open_nc_files.append((key, dataset))
    return dataset


# Decoder that makes xarray return cftime objects for time coordinates.
# It stays None if this xarray version does not provide xr.coders.CFDatetimeCoder.
DEFAULT_TIME_DECODER = None
try:
    DEFAULT_TIME_DECODER = xr.coders.CFDatetimeCoder(use_cftime=True)
except AttributeError:
    pass
def from_nc(
    paths: Union[str, os.PathLike[str], Iterable[Union[str, os.PathLike[str]]]],
    name: str,
    preprocess: Optional[Callable[[xr.Dataset], xr.Dataset]] = None,
    **kwargs,
) -> xr.DataArray:
    """Obtain a variable from one or more NetCDF files that can be used as value
    provided to :meth:`InputManager.add` and :meth:`pygetm.core.Array.set`.

    Args:
        paths: single file path, a pathname pattern containing `*` and/or `?`,
            or a sequence of file paths. If multiple paths are provided (or the
            pattern resolves to multiple valid path names), the files will be
            concatenated along their time dimension.
        name: name of the variable to retrieve from the file(s)
        preprocess: function that transforms the :class:`xarray.Dataset` opened
            for every path provided. This can be used to modify the datasets
            before concatenation in time is attempted, for instance, to cut off
            time indices that overlap between files.
        **kwargs: additional keyword arguments to be passed to
            :func:`xarray.open_dataset`

    Returns:
        The (lazily evaluated) variable. If multiple files were found, it is
        the concatenation in time of the variable from each file.
    """
    if DEFAULT_TIME_DECODER is None:
        # xarray < 2025.01.1
        kwargs.setdefault("decode_times", True)
        kwargs["use_cftime"] = True
    else:
        kwargs.setdefault("decode_times", DEFAULT_TIME_DECODER)
    kwargs["cache"] = False

    if isinstance(paths, (str, os.PathLike)):
        # This is a single item (string or PathLike), not a list of such items.
        # Check if it is a URL or a pattern
        # https://github.com/pydata/xarray/blob/40c27d19d169ccf1c469255c6c6da327f5822d01/xarray/core/utils.py#L692C17-L692C63
        is_local_string = isinstance(paths, str) and not re.match(
            r"[a-z][a-z0-9]*(\://|\:\:)", paths
        )
        if is_local_string:
            # Not a URL, but a file path or glob pattern.
            # Cast to list of valid paths.
            pattern = paths
            paths = glob.glob(pattern)
            if not paths:
                raise Exception(f"No files found matching {pattern!r}")
            paths = map(Path, paths)
        else:
            # A URL or a single file path (PathLike)
            if isinstance(paths, os.PathLike) and not os.path.exists(paths):
                raise FileNotFoundError(f"File not found: {paths}")
            paths = (paths,)

    # Note: we wrap the netCDF array ourselves, in order to support lazy
    # operators (e.g., add, multiply)
    arrays = [
        wrap(
            _open(path, preprocess, **kwargs)[name],
            name=f'from_nc("{path}", {name!r})',
        )
        for path in paths
    ]

    if len(arrays) == 1:
        return arrays[0]

    assert all(array.getm.time is not None for array in arrays)
    chronological = sorted(arrays, key=lambda a: a.getm.time.values.flat[0])
    return xr.concat(
        chronological,
        dim=arrays[0].getm.time.dims[0],
        coords="minimal",
        combine_attrs="drop_conflicts",
    )
def wrap(array: xr.DataArray, name: Optional[str] = None) -> xr.DataArray:
    """Return a copy of a DataArray whose data is a :class:`Wrap` around the
    original variable, so that operators applied to it are evaluated lazily.

    Args:
        array: array to wrap
        name: name for the wrapped array. If None, the name of ``array`` is
            used, or "wrapped_array" if that is not set.
    """
    if name is None:
        name = array.name or "wrapped_array"
    lazy = Wrap(array.variable, name=name)
    result = xr.DataArray(
        lazy,
        dims=array.dims,
        coords=array.coords,
        attrs=array.attrs,
        name=lazy.name,
    )
    result.encoding.update(array.encoding)
    return result
class LazyArray(numpy.lib.mixins.NDArrayOperatorsMixin):
    """Base class for array-like objects that are evaluated on demand.

    Subclasses implement ``__array__`` (and optionally ``__getitem__``) to
    produce actual values. Arithmetic and NumPy ufuncs applied to a LazyArray
    produce new lazy objects (:class:`UFunc`) rather than values.
    """

    def __init__(self, shape: Iterable[int], dtype: npt.DTypeLike, name: str):
        self.shape = tuple(shape)
        self.ndim = len(self.shape)
        self.dtype = np.dtype(dtype)
        self.name = name

    @property
    def size(self) -> int:
        """Total number of elements."""
        return int(np.prod(self.shape))

    def update(self, time: cftime.datetime, numtime: np.longdouble) -> bool:
        """Update the array to the specified time. Returns whether the value
        changed; the base implementation does nothing and returns False."""
        return False

    def astype(self, dtype, **kwargs) -> np.ndarray:
        """Return the array itself if it already has the requested data type,
        and its evaluated values converted to that type otherwise."""
        return self if dtype == self.dtype else self.__array__(dtype)

    def __array_function__(self, func, types, args, kwargs):
        # Functions that can be answered without evaluating the array
        if func == np.result_type:
            return np.result_type(
                *tuple(x.dtype if isinstance(x, LazyArray) else x for x in args)
            )
        if func == np.concatenate:
            return Concatenate(*args, **kwargs)
        if func == np.ndim:
            return self.ndim
        if func == np.shape:
            return self.shape
        if func == np.size:
            return self.size

        # Any other function: evaluate all lazy arguments and call it
        def evaluate(x):
            return np.asarray(x) if isinstance(x, LazyArray) else x

        args = tuple(evaluate(x) for x in args)
        kwargs = {k: evaluate(v) for k, v in kwargs.items()}
        return func(*args, **kwargs)

    def __array_ufunc__(self, ufunc, method: str, *args, **kwargs):
        if method != "__call__" or "out" in kwargs:
            return NotImplemented
        COMPATIBLE_TYPES = (np.ndarray, numbers.Number, LazyArray, xr.Variable)
        if not all(isinstance(x, COMPATIBLE_TYPES) for x in args):
            return NotImplemented
        return UFunc(ufunc, method, *args, **kwargs)

    def __array__(self, dtype=None, copy=None) -> np.ndarray:
        raise NotImplementedError

    def __getitem__(self, slices) -> np.ndarray:
        return self.__array__()[slices]

    def is_time_varying(self) -> bool:
        """Whether the value of the array can change over time."""
        return False

    def _finalize_slices(self, slices: tuple):
        """Expand the first Ellipsis (if any) in a tuple of indices, so that
        the result has exactly one entry per dimension."""
        assert isinstance(slices, tuple)
        position = next((i for i, s in enumerate(slices) if s is Ellipsis), None)
        if position is not None:
            fill = (slice(None),) * (self.ndim + 1 - len(slices))
            slices = slices[:position] + fill + slices[position + 1 :]
        assert len(slices) == self.ndim
        return slices
class Operator(LazyArray):
    """Lazy array whose value is computed by :meth:`apply` from one or more
    positional inputs and optional keyword arguments.

    Indexing an operator forwards the requested slices to its inputs for the
    dimensions listed in ``passthrough``; slices for all other dimensions are
    applied to the result of :meth:`apply`.
    """

    # Name used in the auto-generated array name; if None, the class name is used
    _operator_name = None

    def __init__(
        self,
        *args,
        passthrough=(),
        dtype: npt.DTypeLike = None,
        shape: Optional[tuple[int]] = None,
        name: Optional[str] = None,
        kwslice: Iterable[str] = (),
        **kwargs,
    ):
        """
        Args:
            *args: inputs (NumPy arrays, numbers, lazy arrays or xarray
                variables)
            passthrough: dimensions for which slices are forwarded to the
                inputs: True (all dimensions), an iterable of dimension
                indices, or a dictionary mapping dimension indices of this
                operator to dimension indices of its inputs
            dtype: data type of the result. If None, it is inferred from the
                inputs with :func:`numpy.result_type`
            shape: shape of the result. If None, it is the shape obtained by
                broadcasting all inputs against each other
            name: name of the result. If None, it is generated from the
                operator name and the names of its arguments
            kwslice: names of keyword arguments that are broadcast and sliced
                along with the positional inputs. Only used if ``shape`` is
                None.
            **kwargs: keyword arguments passed to :meth:`apply`
        """

        def _repr(a) -> str:
            if isinstance(a, LazyArray):
                return a.name
            elif np.ndim(a) == 0:
                return str(a)
            # Other datatype, typically xarray.Variable.
            # Do not call object's custom str/repr, as that will cause evaluation
            # (e.g. read from file) of the entire array
            return f"{type(a).__name__}(shape={np.shape(a)}, dtype={np.result_type(a)})"

        # Unpack unnamed arguments
        self.args = []
        self.arg_names = []
        for arg in args:
            assert isinstance(
                arg, (np.ndarray, numbers.Number, LazyArray, xr.Variable)
            ), f"Argument has unknown type {type(arg)}"

            # Unpack to LazyArray if possible
            if isinstance(arg, xr.Variable) and isinstance(arg._data, LazyArray):
                arg = arg._data

            # The name is taken before unwrapping, so a Wrap contributes its own name
            self.arg_names.append(_repr(arg))

            # If this is a Wrap, unwrap
            # (the wrapping was only for ufunc support)
            if isinstance(arg, Wrap):
                arg = arg._source

            self.args.append(arg)
        self._lazy_args = [arg for arg in self.args if isinstance(arg, LazyArray)]

        # Store keyword arguments as-is (no unpacking)
        self.kwargs = kwargs
        kwarg_names = {k: _repr(v) for k, v in kwargs.items()}

        # Infer shape from positional arguments if not provided
        self._sliced_kwargs = []
        if shape is None:
            shapes = [np.shape(input) for input in args]
            shape = np.broadcast_shapes(*shapes)

            def broadcast(a):
                # Scalars are left alone; they broadcast implicitly in apply
                if np.ndim(a) != 0 and np.shape(a) != shape:
                    return np.broadcast_to(a, shape)
                return a

            self.args = tuple(map(broadcast, self.args))
            for k in kwslice:
                if k in self.kwargs:
                    self.kwargs[k] = broadcast(self.kwargs[k])
                    self._sliced_kwargs.append(k)

        if dtype is None:
            dtype = np.result_type(*self.args)

        # Process dimensions for which we can passthrough slices to inputs
        # This can be True (= all dimensions), an iterable, or a dictionary mapping
        # sliced dimensions to input dimensions (if the current operator adds or
        # removes dimensions)
        if passthrough is True:
            passthrough = range(len(shape))
        if not isinstance(passthrough, dict):
            passthrough = {i: i for i in passthrough}
        self.passthrough = passthrough
        assert all([isinstance(dim, int) for dim in self.passthrough]), (
            f"Invalid passthrough: {self.passthrough}."
            " All entries should be of type int"
        )

        # Determine which arguments get slices passed through
        self._sliced_args = [np.ndim(a) != 0 for a in self.args]

        # Generate a name for the variable if not provided
        if name is None:
            operator_name = self._operator_name or self.__class__.__name__
            strargs = ", ".join(self.arg_names)
            strkwargs = "".join(f", {k}={v}" for (k, v) in kwarg_names.items())
            name = f"{operator_name}({strargs}{strkwargs})"

        super().__init__(shape, dtype, name)

    def update(self, *args) -> bool:
        """Update all lazy inputs; returns True if any of them changed."""
        # A list (not a generator) is used so that every input is updated,
        # even after one of them has already returned True.
        return any([arg.update(*args) for arg in self._lazy_args])

    def is_time_varying(self) -> bool:
        """Whether any of the lazy inputs is time-varying."""
        return any(input.is_time_varying() for input in self._lazy_args)

    def __getitem__(self, slices) -> np.ndarray:
        # preslices are applied to the inputs, postslices to the result of apply
        preslices, postslices = [], []
        for i, slc in enumerate(self._finalize_slices(slices)):
            if i in self.passthrough:
                preslices.append(slc)
                # An integer index removes the dimension from the inputs,
                # so the result has no corresponding dimension left to slice
                if not isinstance(slc, (int, np.integer)):
                    postslices.append(slice(None))
            else:
                preslices.append(slice(None))
                postslices.append(slc)
        preslices = tuple(preslices)
        args = [
            np.asarray(arg[preslices] if s else arg)
            for arg, s in zip(self.args, self._sliced_args)
        ]
        kwargs = self.kwargs
        if self._sliced_kwargs:
            kw_override = {
                k: np.asarray(kwargs[k][preslices]) for k in self._sliced_kwargs
            }
            kwargs = {**kwargs, **kw_override}
        return self.apply(*args, **kwargs)[tuple(postslices)]

    def apply(self, *args: np.ndarray, dtype=None, **kwargs) -> np.ndarray:
        """Compute the result from evaluated inputs. To be implemented by
        subclasses."""
        raise NotImplementedError

    def __array__(self, dtype=None, copy=None) -> np.ndarray:
        # Evaluate all inputs in full and apply the operator
        args = map(np.asarray, self.args)
        kwargs = self.kwargs
        if self._sliced_kwargs:
            kw_override = {k: np.asarray(kwargs[k]) for k in self._sliced_kwargs}
            kwargs = {**kwargs, **kw_override}
        return self.apply(*args, dtype=dtype, **kwargs)
class UnaryOperator(Operator):
    """Operator with a single positional input, which is made available as
    ``_source`` (with its name in ``_source_name``)."""

    def __init__(self, arg, **kwargs):
        super().__init__(arg, **kwargs)
        self._source, self._source_name = self.args[0], self.arg_names[0]
class UFunc(Operator):
    """Lazy application of a NumPy universal function to one or more inputs.

    Slices are passed through to the inputs for all dimensions; the ``where``
    keyword argument, if provided, is sliced along with the inputs.
    """

    def __init__(self, ufunc, method: str, *args, **kwargs):
        # The operator name must be set before the base class generates
        # the name of the result from it
        self._operator_name = ufunc.__name__
        self.ufunc = getattr(ufunc, method)
        super().__init__(*args, passthrough=True, kwslice=("where",), **kwargs)

    def apply(self, *args: np.ndarray, **kwargs) -> np.ndarray:
        """Call the universal function on the evaluated inputs."""
        return self.ufunc(*args, **kwargs)
class Wrap(UnaryOperator):
    """Lazy array that returns the values of an :class:`xarray.Variable`
    unmodified. Slices are passed through to the variable for all dimensions.
    """

    def __init__(self, source: xr.Variable, name: str):
        assert isinstance(source, xr.Variable)
        super().__init__(source, name=name, passthrough=True)

    def apply(self, source: np.ndarray, dtype=None) -> np.ndarray:
        """Return the evaluated source as-is."""
        return source

    def update(self, *args) -> bool:
        """The wrapped variable does not change in time: always False."""
        return False
class Slice(UnaryOperator):
    """Lazy array assembled from one or more slices of a source array.

    ``_slices`` holds (source slice, target slice) pairs, to be filled in
    after construction: the values at the source slice of the source array
    end up at the target slice of this array.
    """

    def __init__(self, source, shape: tuple[int], passthrough):
        super().__init__(source, shape=shape, passthrough=passthrough)
        self._slices = []

        # If True, the source slices are used to index the source directly.
        # If False, they are applied to values obtained from the source
        # (see __array__ and __getitem__ for what is retrieved in each case).
        self.passthrough_own_slices = True

    def __array__(self, dtype=None, copy=None) -> np.ndarray:
        data = np.empty(self.shape, dtype or self.dtype)
        for src_slice, tgt_slice in self._slices:
            if self.passthrough_own_slices:
                data[tgt_slice] = self._source[src_slice]
            else:
                data[tgt_slice] = np.asarray(self._source, dtype=dtype)[src_slice]
        return data

    def __getitem__(self, slices) -> np.ndarray:
        slices = self._finalize_slices(slices)

        # Determine the shape of the result. Passed-through dimensions indexed
        # with an integer disappear from the result (keep = False)
        shape = []
        keep = [
            not (i in self.passthrough and isinstance(s, (int, np.integer)))
            for i, s in enumerate(slices)
        ]
        for i, (l, s, k) in enumerate(zip(self.shape, slices, keep)):
            if not k:
                # This dimension will be sliced out
                continue
            assert isinstance(s, slice), (
                f"Dimension {i} has unsupported slice type {type(s)} with value {s!r}."
                f" Passthrough: {list(self.passthrough)}"
            )
            start, stop, step = s.indices(l)

            # Dimensions that are not passed through can only be taken in full
            assert i in self.passthrough or (start == 0 and stop == l and step == 1), (
                f"invalid slice for dimension {i} with length {l}:"
                f" {start}:{stop}:{step}"
            )
            shape.append(len(range(start, stop, step)))
        data = np.empty(shape, self.dtype)
        for src_slice, tgt_slice in self._slices:
            if self.passthrough_own_slices:
                # Forward our own slices to source array
                src_slice = list(src_slice)
            else:
                # Apply our own slices only after retrieving data from source array
                # Only passed-through slices provided as argument are passed to source
                mid_slice = tuple(itertools.compress(src_slice, keep))
                src_slice = [slice(None)] * len(src_slice)

            # Forward passed-through slices provided as argument to source array
            for iout, iin in self.passthrough.items():
                if slices[iout] != slice(None):
                    assert src_slice[iin] == slice(None), "Merging slices not supported"
                    src_slice[iin] = slices[iout]

            # Drop target slices for dimensions that were sliced out
            tgt_slice = tuple(itertools.compress(tgt_slice, keep))
            values = self._source[tuple(src_slice)]
            if not self.passthrough_own_slices:
                values = values[mid_slice]
            data[tgt_slice] = values
        return data
class Concatenate(Operator):
    """Lazy concatenation of arrays along an existing axis.

    When indexed, only the input arrays that overlap with the requested range
    along the concatenation axis are accessed.
    """

    def __init__(self, arrays, axis: int = 0, **kwargs):
        """
        Args:
            arrays: sequence of arrays to concatenate. Their shapes must match
                in all dimensions except ``axis``.
            axis: axis along which to concatenate. Negative values count from
                the last axis.
            **kwargs: keyword arguments passed to :class:`Operator`

        Raises:
            IndexError: if ``axis`` is out of bounds
        """
        shape = list(arrays[0].shape)

        # Normalize negative axes. Without this, the shape check below would
        # also compare the concatenation axis itself (and fail), and
        # __getitem__ would count preceding dimensions incorrectly.
        if not -len(shape) <= axis < len(shape):
            raise IndexError(
                f"axis {axis} is out of bounds for array of dimension {len(shape)}"
            )
        if axis < 0:
            axis += len(shape)

        for array in arrays[1:]:
            shape[axis] += array.shape[axis]
            assert all(array.shape[i] == l for i, l in enumerate(shape) if i != axis)
        self.axis = axis
        super().__init__(*arrays, shape=shape, **kwargs)

    def __array__(self, dtype=None, copy=None) -> np.ndarray:
        arrays = [np.asarray(array, dtype=dtype) for array in self.args]
        return np.concatenate(arrays, axis=self.axis)

    def __getitem__(self, slices) -> np.ndarray:
        slices = list(self._finalize_slices(slices))
        axis = self.axis
        if isinstance(slices[axis], (int, np.integer)):
            # Single index along the concatenation axis:
            # find the one array that contains it
            if slices[axis] < 0:
                slices[axis] += self.shape[axis]
            assert slices[axis] >= 0
            for array in self.args:
                if slices[axis] < array.shape[axis]:
                    return np.asarray(array[tuple(slices)])
                slices[axis] -= array.shape[axis]
            raise IndexError

        assert isinstance(
            slices[axis], slice
        ), f"Invalid slice for concatenation axis {axis}: {slices[axis]!r}"

        outaxis = axis
        for s in slices[:axis]:
            if isinstance(s, (int, np.integer)):
                # This dimension precedes the axis over which we concatenate
                # and it will be sliced out from the source arrays
                # Therefore, the axis over which we concatenate decreases by 1
                outaxis -= 1

        # start and stop are kept relative to the beginning of the current array
        start, stop, step = slices[axis].indices(self.shape[axis])
        assert step > 0
        arrays = []
        for array in self.args:
            if start < array.shape[axis]:
                slices[axis] = slice(start, stop, step)
                current = np.asarray(array[tuple(slices)])

                # Position (relative to the current array) of the first
                # element to take from the next array
                start += current.shape[outaxis] * step
                arrays.append(current)
            start -= array.shape[axis]
            stop -= array.shape[axis]
            if stop <= 0:
                break

        if not arrays:
            # The requested range is empty and did not touch any array (e.g.,
            # it starts at or beyond the end). np.concatenate cannot handle an
            # empty list, so take an empty section of the first array instead;
            # that yields the correct shape for all other dimensions.
            slices[axis] = slice(0, 0)
            return np.asarray(self.args[0][tuple(slices)])
        return np.concatenate(arrays, axis=outaxis)
def limit_region(
    source: xr.DataArray,
    minlon: float,
    maxlon: float,
    minlat: float,
    maxlat: float,
    periodic_lon: bool = False,
    verbose: bool = False,
    require_2d: bool = True,
) -> xr.DataArray:
    """Restrict an array to the part that spans a longitude-latitude rectangle.

    The result is evaluated lazily (it is backed by a :class:`Slice`).

    Args:
        source: array to restrict. Its longitude and latitude coordinates
            (as identified by the ``getm`` accessor) must be 1D.
        minlon: minimum longitude to include
        maxlon: maximum longitude to include
        minlat: minimum latitude to include
        maxlat: maximum latitude to include
        periodic_lon: whether the requested longitude range is allowed to
            extend beyond the range of the source. If so, values from the
            opposite end of the source are used, with their longitude shifted
            by 360.
        verbose: whether to print diagnostic information
        require_2d: whether to expand the selected region to at least two
            points in each of the longitude and latitude dimensions

    Raises:
        Exception: if the requested ranges are not finite, if a minimum
            exceeds the corresponding maximum, or if the source longitude or
            latitude is not 1D
    """
    if not np.isfinite(minlon) or not np.isfinite(maxlon):
        raise Exception(f"Longitude range {minlon} - {maxlon} is not valid")
    if not np.isfinite(minlat) or not np.isfinite(maxlat):
        raise Exception(f"Latitude range {minlat} - {maxlat} is not valid")
    if minlon > maxlon:
        raise Exception(
            f"Invalid longitude range: maximum {maxlon} must be >= minimum {minlon}."
        )
    if minlat > maxlat:
        raise Exception(
            f"Invalid latitude range: maximum {maxlat} must be >= minimum {minlat}."
        )
    source_lon, source_lat = source.getm.longitude, source.getm.latitude
    if source_lon.ndim != 1:
        raise Exception(f"Source longitude must be 1D but has shape {source_lon.shape}")
    if source_lat.ndim != 1:
        raise Exception(f"Source latitude must be 1D but has shape {source_lat.shape}")

    # Index range [imin, imax) that brackets the requested longitude range
    # (searchsorted assumes source longitude is in ascending order)
    imin = source_lon.values.searchsorted(minlon, side="right") - 1
    imax = source_lon.values.searchsorted(maxlon, side="left") + 1

    # Index range [jmin, jmax) that brackets the requested latitude range
    if source_lat.values[1] < source_lat.values[0]:
        # Latitude is in descending order: search the reversed array and
        # convert the result to indices in the original array
        jmin = (
            source_lat.size
            - source_lat.values[::-1].searchsorted(maxlat, side="left")
            - 1
        )
        jmax = (
            source_lat.size
            - source_lat.values[::-1].searchsorted(minlat, side="right")
            + 1
        )
    else:
        jmin = source_lat.values.searchsorted(minlat, side="right") - 1
        jmax = source_lat.values.searchsorted(maxlat, side="left") + 1
    if verbose:
        print(imin, imax, source_lon.values.size, jmin, jmax, source_lat.values.size)
    assert (imin >= 0 and imax <= source_lon.values.size) or periodic_lon, (
        f"Requested longitude section {minlon} - {maxlon} is not fully covered"
        f" by available range {source_lon.values[0]} - {source_lon.values[-1]}"
    )
    assert jmin >= 0 and jmax <= source_lat.values.size, (
        f"Requested latitude section {minlat} - {maxlat} is not fully covered"
        f" by available range {source_lat.values[0]} - {source_lat.values[-1]}"
    )

    # If a single point was selected, add a neighbour to obtain two points
    if require_2d and jmax - jmin == 1:
        jmin, jmax = (jmin, jmax + 1) if jmin == 0 else (jmin - 1, jmax)
    if require_2d and imax - imin == 1:
        imin, imax = (imin, imax + 1) if imin == 0 else (imin - 1, imax)

    # Determine whether values from the opposite end of the source are needed
    add_left = imin < 0
    # NOTE(review): imax == size still lies within the source (it is accepted
    # by the non-periodic assertion above), yet it sets add_right here.
    # Confirm whether ">" was intended.
    add_right = imax >= source_lon.values.size
    imin = max(imin, 0)
    imax = min(imax, source_lon.values.size)

    ilondim = source.dims.index(source_lon.dims[0])
    ilatdim = source.dims.index(source_lat.dims[0])
    shape = list(source.shape)
    shape[ilondim] = imax - imin
    shape[ilatdim] = jmax - jmin

    # Slices for the part of the source that lies within its own longitude
    # range, and the position of that part in the result
    center_source = tuple(
        [
            {ilondim: slice(imin, imax), ilatdim: slice(jmin, jmax)}.get(i, slice(None))
            for i in range(len(shape))
        ]
    )
    center_target = [
        {ilondim: slice(0, imax - imin), ilatdim: slice(0, jmax - jmin)}.get(
            i, slice(None)
        )
        for i in range(len(shape))
    ]
    target_lon = source_lon[center_source[ilondim]]
    target_lat = source_lat[center_source[ilatdim]]

    # Whether the first and last source longitude are 360 apart. If so, the
    # duplicate point is skipped when wrapping around.
    overlap = abs(source_lon.values[-1] - source_lon.values[0] - 360.0) < 1e-5
    if verbose:
        print(
            f"periodic longitude? {periodic_lon} Overlap?"
            f" {abs(source_lon.values[-1] - source_lon.values[0] - 360.0)} = {overlap}"
        )
    left_target = None
    right_target = None
    if add_left:
        # Periodic domain and we need to read beyond left boundary
        imin_left = source_lon.values.searchsorted(minlon + 360.0, side="right") - 1
        left_source = tuple(
            [
                {ilondim: slice(imin_left, -1 if overlap else None)}.get(i, s)
                for i, s in enumerate(center_source)
            ]
        )
        nleft = source_lon.values.size - imin_left + (-1 if overlap else 0)
        if verbose:
            print(f"adding {nleft} values on the left")
        shape[ilondim] += nleft
        left_target = tuple(
            [{ilondim: slice(0, nleft)}.get(i, s) for i, s in enumerate(center_target)]
        )

        # The central part shifts to the right to make room
        center_target[ilondim] = slice(nleft, nleft + imax - imin)
        target_lon = xr.concat(
            (source_lon[left_source[ilondim]] - 360.0, target_lon),
            source_lon.dims[0],
            combine_attrs="no_conflicts",
        )
    if add_right:
        # Periodic domain and we need to read beyond right boundary
        imax_right = source_lon.values.searchsorted(maxlon - 360.0, side="left") + 1
        right_source = tuple(
            [
                {ilondim: slice(1 if overlap else 0, imax_right)}.get(i, s)
                for i, s in enumerate(center_source)
            ]
        )
        nright = imax_right + (-1 if overlap else 0)
        if verbose:
            print(f"adding {nright} values on the right")
        shape[ilondim] += nright

        # The right part starts where the central part ends
        right_target = tuple(
            [
                {ilondim: slice(s.stop, None)}.get(i, s)
                for i, s in enumerate(center_target)
            ]
        )
        target_lon = xr.concat(
            (target_lon, source_lon[right_source[ilondim]] + 360.0),
            source_lon.dims[0],
            combine_attrs="no_conflicts",
        )
    center_target = tuple(center_target)
    shape = tuple(shape)
    if verbose:
        print(f"final shape: {shape}")

    # Slices for all dimensions other than longitude and latitude can be
    # passed through to the source
    lazyvar = Slice(
        _as_lazyarray(source),
        shape=shape,
        passthrough=[i for i in range(len(shape)) if i not in (ilondim, ilatdim)],
    )
    lazyvar._slices.append((center_source, center_target))
    if left_target:
        lazyvar._slices.append((left_source, left_target))
    if right_target:
        lazyvar._slices.append((right_source, right_target))

    # Replace the longitude and latitude coordinates by their restricted versions
    coords = dict(source.coords.items())
    coords[source_lon.name] = target_lon
    coords[source_lat.name] = target_lat
    return xr.DataArray(
        lazyvar, dims=source.dims, coords=coords, attrs=source.attrs, name=lazyvar.name
    )
def concatenate_slices(
    source: xr.DataArray, idim: int, slices: Iterable[slice], verbose=False
) -> xr.DataArray:
    """Lazily concatenate one or more sections of an array along one dimension.

    Args:
        source: array to take sections from
        idim: index of the dimension along which sections are taken and
            concatenated
        slices: sections to take along dimension ``idim``. These must have
            explicit ``start`` and ``stop``; their ``step`` is not taken into
            account when determining the size of the result.
        verbose: whether to print the shape of the result

    Returns:
        The concatenated sections. Coordinates that use the sliced dimension
        are dropped.
    """
    # Materialize: slices is iterated several times below,
    # which would silently exhaust a generator after the first pass
    slices = list(slices)

    assert idim < source.ndim
    assert all(isinstance(s, slice) for s in slices)
    shape = list(source.shape)
    shape[idim] = sum(s.stop - s.start for s in slices)
    shape = tuple(shape)
    if verbose:
        print(f"final shape: {shape}")

    # Map each source section to consecutive positions in the result
    istart = 0
    final_slices = []
    for s in slices:
        n = s.stop - s.start
        source_slice = [slice(None)] * source.ndim
        target_slice = [slice(None)] * source.ndim
        source_slice[idim] = s
        target_slice[idim] = slice(istart, istart + n)
        final_slices.append((tuple(source_slice), tuple(target_slice)))
        istart += n
    assert istart == shape[idim]

    lazyvar = Slice(
        _as_lazyarray(source),
        shape=shape,
        passthrough=[i for i in range(len(shape)) if i != idim],
    )
    lazyvar._slices.extend(final_slices)

    # Coordinates along the sliced dimension no longer apply
    sliced_dim = source.dims[idim]
    coords = {
        name: c for name, c in source.coords.items() if sliced_dim not in c.dims
    }
    return xr.DataArray(
        lazyvar, dims=source.dims, coords=coords, attrs=source.attrs, name=lazyvar.name
    )
class Transpose(UnaryOperator):
    """Lazy array with the dimensions of its source reordered."""

    def __init__(self, a, axes: Iterable[int], **kwargs):
        """
        Args:
            a: source array
            axes: for each dimension of the result, the index of the
                corresponding dimension of the source
            **kwargs: keyword arguments passed to :class:`Operator`
        """
        super().__init__(a, **kwargs)

        # Materialize, as axes is iterated several times
        # (this would exhaust a one-shot iterator)
        self.axes = tuple(axes)

        # Inverse permutation: for each dimension of the source,
        # the index of the corresponding dimension of the result
        self.oldaxes = list(self.axes)
        for inew, iold in enumerate(self.axes):
            self.oldaxes[iold] = inew

    def __array__(self, dtype=None, copy=None) -> np.ndarray:
        return np.asarray(self._source).transpose(self.axes)

    def __getitem__(self, slices) -> np.ndarray:
        newslices = list(self._finalize_slices(slices))
        finalslices = [slice(None)] * len(newslices)

        # Iterate over the finalized slices (one per dimension), not over the
        # slices provided, as the latter are misaligned with dimensions if
        # they contain an Ellipsis
        for inew, s in enumerate(newslices):
            if isinstance(s, (int, np.integer)):
                # Retrieve a length-1 section instead, so that the source keeps
                # all its dimensions for the transpose; the dimension is
                # dropped afterwards. For index -1, the section must be
                # open-ended, as slice(-1, 0) would be empty.
                newslices[inew] = slice(s, s + 1 or None)
                finalslices[inew] = 0
        oldslices = tuple(newslices[inew] for inew in self.oldaxes)
        return self._source[oldslices].transpose(self.axes)[tuple(finalslices)]
def transpose(
    source: xr.DataArray, axes: Optional[Iterable[int]] = None
) -> xr.DataArray:
    """Lazily reorder the dimensions of an array.

    Args:
        source: array to transpose
        axes: for each dimension of the result, the index of the corresponding
            dimension of ``source``. If None, the order of dimensions is
            reversed.

    Returns:
        The transposed array, with the same coordinates as ``source``.
    """
    # Materialize, as axes is iterated more than once
    axes = tuple(range(source.ndim)[::-1] if axes is None else axes)
    dims = [source.dims[i] for i in axes]
    shape = [source.shape[i] for i in axes]
    lazyvar = Transpose(
        _as_lazyarray(source),
        axes,
        shape=shape,
        passthrough=range(source.ndim),
        dtype=source.dtype,
    )

    # Coordinates are passed on unchanged. Previously, multi-dimensional
    # coordinates were first transposed recursively, but the result was always
    # overwritten by the original coordinate, so that work was wasted (and it
    # presumably could recurse without bound, as a coordinate typically lists
    # itself among its own coordinates). Coordinates are matched to the data
    # by dimension name, so their dimension order need not follow the data's.
    coords = dict(source.coords.items())
    return xr.DataArray(
        lazyvar,
        dims=dims,
        coords=coords,
        attrs=source.attrs,
        name=lazyvar.name,
    )
def isel(source: xr.DataArray, **indices) -> xr.DataArray:
    """Index named dimensions with integers, slice objects or integer arrays

    Args:
        source: array to index
        **indices: index per dimension name. Integers (Python or NumPy) remove
            the dimension, slices restrict it, and anything else is converted
            to an integer array and used as advanced index. Dimensions with
            advanced indices must be adjacent in ``source``.

    Returns:
        The lazily indexed array. Coordinates that use any of the indexed
        dimensions are dropped.
    """
    advanced_indices = []
    for dim in list(indices):
        assert dim in source.dims, (
            f"indexed dimension {dim} not used by source,"
            f" which has dimensions {source.dims}"
        )
        index = indices[dim]
        if isinstance(index, np.integer):
            # NumPy integer scalars behave like Python integers. Without this,
            # they would be treated as advanced indices and fail upon
            # conversion to a one-dimensional variable.
            indices[dim] = int(index)
        elif not isinstance(index, (int, slice)):
            advanced_indices.append(source.dims.index(dim))
            indices[dim] = xr.Variable(
                ["__newdim"], np.asarray(index, dtype=np.intp)
            )

    # Final slices per dimension
    slices = tuple([indices.get(dim, slice(None)) for dim in source.dims])

    # Determine final shape
    shape = []
    dims = []
    passthrough = {}
    advanced_added = False
    for i, (dim, slc, l) in enumerate(zip(source.dims, slices, source.shape)):
        if i not in advanced_indices:
            # Slice is integer or slice object. if integer, it will be sliced out
            # so it does not contribute to the final shape
            if isinstance(slc, slice):
                start, stop, stride = slc.indices(l)
                dims.append(dim)
                passthrough[len(shape)] = i
                shape.append(len(range(start, stop, stride)))
        elif not advanced_added:
            # First advanced slice. Add the shape produced by the broadcast
            # combination of advanced indices
            assert max(advanced_indices) - min(advanced_indices) + 1 == len(
                advanced_indices
            ), "advanced indices must be side-by-side for now"
            advanced_shapes = [indices[source.dims[j]].shape for j in advanced_indices]
            for length in np.broadcast_shapes(*advanced_shapes):
                dims.append(f"dim_{len(shape)}")
                shape.append(length)
            advanced_added = True

    lazyvar = Slice(_as_lazyarray(source), shape=shape, passthrough=passthrough)
    lazyvar._slices.append((slices, (slice(None),) * len(shape)))

    # Coordinates along any of the indexed dimensions no longer apply
    coords = {}
    for name, c in source.coords.items():
        if all(dim not in c.dims for dim in indices):
            coords[name] = c
    return xr.DataArray(
        lazyvar, dims=dims, coords=coords, attrs=source.attrs, name=lazyvar.name
    )
def horizontal_interpolation(
    source: xr.DataArray,
    x: xr.DataArray,
    y: xr.DataArray,
    *,
    mask: Optional[npt.ArrayLike] = None,
    xp: Union[str, xr.DataArray, None] = None,
    yp: Union[str, xr.DataArray, None] = None,
) -> xr.DataArray:
    """Two-dimensional linear interpolation

    For target coordinates that fall into cells with one or more masked corners,
    nearest-neighbor interpolation is used to find the nearest unmasked point.

    Args:
        source: source variable
        x: x-coordinates at which to evaluate the interpolated values
        y: y-coordinates at which to evaluate the interpolated values
        mask: mask for the source variable. If None, all points of the source
            variable are assumed to be valid.
        xp: x-coordinates of the source variable, or the name of the coordinate
            of the source variable that contains them. If None, the longitudes
            of the source variable are used.
        yp: y-coordinates of the source variable, or the name of the coordinate
            of the source variable that contains them. If None, the latitudes
            of the source variable are used.

    Returns:
        The lazily interpolated variable, in which the x and y dimensions of
        the source are replaced by the dimensions of the (broadcast) target
        coordinates.
    """
    # Resolve source coordinates
    if isinstance(xp, str):
        xp = source.coords[xp]
    elif xp is None:
        xp = source.getm.longitude
        if xp is None:
            raise Exception(
                f"Variable {source.name} does not have a valid longitude coordinate."
            )
    if isinstance(yp, str):
        yp = source.coords[yp]
    elif yp is None:
        yp = source.getm.latitude
        if yp is None:
            raise Exception(
                f"Variable {source.name} does not have a valid latitude coordinate."
            )

    assert xp.ndim == 1
    assert yp.ndim == 1
    assert np.isfinite(x).all(), f"Some target x-coordinates are non-finite: {x}"
    assert np.isfinite(y).all(), f"Some target y-coordinates are non-finite: {y}"
    x, y = np.broadcast_arrays(x, y)

    # Position of the horizontal dimensions within the source
    ixdim = source.dims.index(xp.dims[0])
    iydim = source.dims.index(yp.dims[0])
    assert abs(ixdim - iydim) == 1, "x and y dimensions must be distinct and adjacent"
    ifirst, ilast = min(ixdim, iydim), max(ixdim, iydim)
    ntrailing = source.ndim - ilast - 1

    # Names of the horizontal dimensions of the result (y first, then x)
    dimensions = {0: (), 1: (xp.dims[0],), 2: (yp.dims[0], xp.dims[-1])}[x.ndim]
    shape = source.shape[:ifirst] + x.shape + source.shape[ilast + 1 :]

    kwargs = {"ndim_trailing": ntrailing, "mask": mask}
    if ixdim > iydim:
        # Dimension order: y first, then x
        ip = pygetm.util.interpolate.Linear2DGridInterpolator(y, x, yp, xp, **kwargs)
    else:
        # Dimension order: x first, then y
        dimensions = dimensions[::-1]
        ip = pygetm.util.interpolate.Linear2DGridInterpolator(x, y, xp, yp, **kwargs)

    # Coordinates for the interpolated variable.
    # A multi-dimensional coordinate cannot carry the name of a dimension,
    # so in that case an underscore is appended to its name.
    x_name, y_name = xp.name, yp.name
    if x_name in dimensions and x.ndim > 1:
        x_name = x_name + "_"
    if y_name in dimensions and y.ndim > 1:
        y_name = y_name + "_"
    replaced = {xp.name, yp.name}
    coords = {k: v for k, v in source.coords.items() if k not in replaced}
    coords[x_name] = xr.DataArray(x, dims=dimensions, name=x_name, attrs=xp.attrs)
    coords[y_name] = xr.DataArray(y, dims=dimensions, name=y_name, attrs=yp.attrs)

    # Dimensions of the interpolated variable
    dims = source.dims[:ifirst] + dimensions + source.dims[ilast + 1 :]

    lazyvar = HorizontalInterpolation(
        ip, _as_lazyarray(source), shape, ifirst, ntrailing
    )
    return xr.DataArray(
        lazyvar, dims=dims, coords=coords, attrs=source.attrs, name=lazyvar.name
    )
class HorizontalInterpolation(UnaryOperator):
    """Lazy horizontal interpolation of a source array.

    The first ``npre`` and last ``npost`` dimensions of the array are not
    interpolated; slices for those are forwarded to the source.
    """

    def __init__(
        self,
        ip: pygetm.util.interpolate.Linear2DGridInterpolator,
        source: LazyArray,
        shape: Iterable[int],
        npre: int,
        npost: int,
        **kwargs,
    ):
        """
        Args:
            ip: interpolator to apply to values of the source
            source: array to interpolate
            shape: shape of the interpolated array
            npre: number of dimensions preceding the interpolated dimensions
            npost: number of dimensions following the interpolated dimensions
            **kwargs: keyword arguments passed to :class:`Operator`
        """
        super().__init__(source, shape=shape, dtype=float, **kwargs)
        self._ip = ip
        self.npre = npre
        self.npost = npost

    def __array__(self, dtype=None, copy=None) -> np.ndarray:
        return self._ip(np.asarray(self._source))

    def __getitem__(self, slices) -> np.ndarray:
        first_trailing = self.ndim - self.npost
        leading, trailing = [], []
        ndropped = 0
        for i, s in enumerate(self._finalize_slices(slices)):
            if i < self.npre:
                # prefixed dimension
                leading.append(s)
            elif i >= first_trailing:
                # trailing dimension
                if isinstance(s, (int, np.integer)):
                    ndropped += 1
                trailing.append(s)
            else:
                # interpolated dimension: can only be taken in full
                assert (
                    isinstance(s, slice)
                    and (s.start is None or s.start == 0)
                    and (s.stop is None or s.stop == self.shape[i])
                    and s.step is None
                ), repr(s)

        source = np.asarray(self._source[(*leading, Ellipsis, *trailing)])

        # Append a length-1 dimension for every trailing dimension that was
        # removed by an integer index; these are removed again from the result
        source.shape = source.shape + (1,) * ndropped
        result = self._ip(source)
        return result[(Ellipsis,) + (0,) * ndropped]
def vertical_interpolation(
    source: xr.DataArray, target_z: npt.ArrayLike, itargetdim: int = 0
) -> xr.DataArray:
    """Lazily interpolate an array to different vertical coordinates.

    Args:
        source: array to interpolate. Its vertical coordinate (as identified
            by the ``getm`` accessor) must be 1D.
        target_z: vertical coordinates to interpolate to
        itargetdim: index of the vertical dimension in ``target_z``. All other
            dimensions of ``target_z`` are matched to dimensions of ``source``
            by their length.

    Raises:
        Exception: if the source has no vertical coordinate
    """
    source_z = source.getm.z
    target_z = np.asarray(target_z)
    if source_z is None:
        raise Exception(
            f"Variable {source.name} does not have a valid depth coordinate."
        )
    assert source_z.ndim == 1
    izdim = source.dims.index(source_z.dims[0])

    # Match each dimension of target_z to a dimension of the source
    target2sourcedim = {}
    cursor = 0
    for itarget, length in enumerate(target_z.shape):
        if itarget == itargetdim:
            cursor = izdim
        else:
            # Skip source dimensions until one with matching length is found
            while not (
                cursor == izdim
                or cursor >= source.ndim
                or length == source.shape[cursor]
            ):
                cursor += 1
            assert cursor != izdim, (
                f"Dimension with length {length} should precede depth dimension {izdim}"
                f" in {source.name}, which has shape {source.shape}"
            )
            assert cursor < source.ndim, (
                f"Dimension with length {length} expected after depth dimension {izdim}"
                f" in {source.name}, which has shape {source.shape}"
            )
        target2sourcedim[itarget] = cursor
        cursor += 1

    # The original vertical coordinate is replaced by the target coordinate,
    # stored under the original name with an underscore appended
    coords = {}
    for coord_name, coord in source.coords.items():
        if coord_name == source_z.name:
            target_dims = [
                source.dims[target2sourcedim[i]] for i in range(target_z.ndim)
            ]
            coords[coord_name + "_"] = (target_dims, target_z)
        else:
            coords[coord_name] = coord

    lazyvar = VerticalInterpolation(
        _as_lazyarray(source), target_z, izdim, source_z.values, itargetdim
    )
    return xr.DataArray(
        lazyvar, dims=source.dims, coords=coords, attrs=source.attrs, name=lazyvar.name
    )
class VerticalInterpolation(UnaryOperator):
    """Lazy interpolation of a source array along its vertical dimension.

    Slices for all dimensions except the vertical are forwarded to the source.
    """

    def __init__(
        self,
        source: LazyArray,
        z: np.ndarray,
        izdim: int,
        source_z: np.ndarray,
        axis: int = 0,
        **kwargs,
    ):
        """
        Args:
            source: array to interpolate
            z: vertical coordinates to interpolate to
            izdim: index of the vertical dimension of ``source``
            source_z: vertical coordinates of ``source``. If none of these are
                negative, their sign is flipped.
            axis: index of the vertical dimension of ``z``
            **kwargs: keyword arguments passed to :class:`Operator`
        """
        self.izdim = izdim
        target_shape = list(source.shape)
        target_shape[izdim] = z.shape[axis]
        super().__init__(
            source,
            shape=target_shape,
            dtype=float,
            passthrough=[idim for idim in range(source.ndim) if idim != izdim],
            **kwargs,
        )
        self.z = z
        self.axis = axis
        self.source_z = -source_z if (source_z >= 0.0).all() else source_z

    def apply(self, source: np.ndarray, dtype=None) -> np.ndarray:
        """Interpolate evaluated source values to the target coordinates."""
        return pygetm.util.interpolate.interp_1d(
            self.z, self.source_z, source, axis=self.axis
        )
def temporal_interpolation(
    source: xr.DataArray,
    climatology: bool = False,
    comm: MPI.Comm = MPI.COMM_SELF,
    logger: Optional[logging.Logger] = None,
) -> xr.DataArray:
    """Lazily interpolate an array in time.

    The result lacks the time dimension of ``source``; its values are those of
    a :class:`TemporalInterpolation` object.

    Args:
        source: array to interpolate. It must have a time coordinate (as
            identified by the ``getm`` accessor).
        climatology: passed to :class:`TemporalInterpolation`
        comm: passed to :class:`TemporalInterpolation`
        logger: passed to :class:`TemporalInterpolation`
    """
    time_coord = source.getm.time
    assert time_coord is not None, "No time coordinate found"
    time_dim = time_coord.dims[0]

    lazyvar = TemporalInterpolation(
        _as_lazyarray(source),
        source.dims.index(time_dim),
        time_coord.values,
        climatology,
        comm=comm,
        logger=logger,
    )

    # The time dimension disappears; the time coordinate is replaced by the
    # one maintained by the interpolator, and all other coordinates that
    # depend on time are dropped
    dims = [d for i, d in enumerate(source.dims) if i != lazyvar._itimedim]
    coords = {time_dim: lazyvar._timecoord}
    coords.update(
        {n: c for n, c in source.coords.items() if time_dim not in c.dims}
    )
    return xr.DataArray(
        lazyvar, dims=dims, coords=coords, attrs=source.attrs, name=lazyvar.name
    )
def _as_lazyarray(array: xr.DataArray) -> LazyArray:
    """Return the :class:`LazyArray` underlying ``array``.

    If the data of ``array`` already is a :class:`LazyArray`, it is returned
    as is. Otherwise the variable of ``array`` is wrapped in a :class:`Wrap`,
    named after the array (or the repr of its data if it has no name) and
    the source recorded in its encoding, if any.
    """
    data = array.variable._data
    if isinstance(data, LazyArray):
        return data

    name = array.name if array.name is not None else object.__repr__(data)
    if "source" in array.encoding:
        name = name + " from " + array.encoding["source"]
    return Wrap(array.variable, name=name)
class Cache(UnaryOperator):
    """Operator that reads ``n`` consecutive entries along one dimension of
    its source at once, and serves subsequent integer-indexed requests along
    that dimension from memory for as long as they fall within that window.
    """

    def __init__(self, source: LazyArray, idim: int, n: int):
        """
        Args:
            source: lazy array to read from
            idim: index of the dimension along which to cache
            n: number of entries along ``idim`` to read at once
        """
        super().__init__(source)
        self._cache: Optional[np.ndarray] = None
        self.istart = 0
        self.n = n
        self.idim = idim

    def __getitem__(self, slices) -> np.ndarray:
        key = list(self._finalize_slices(slices))
        index = key[self.idim]

        # Only requests for a single index along the cached dimension
        # go through the cache; anything else is read directly.
        if not isinstance(index, (int, np.integer)):
            return self._source[tuple(key)]

        cache_hit = (
            self._cache is not None
            and self.istart <= index < self.istart + self.n
        )
        if not cache_hit:
            # Load a new window of n entries that starts at the requested index
            # NOTE(review): the window is read with the slices of this request
            # for all other dimensions, and later hits index into it with their
            # own slices. This appears to assume callers always use the same
            # slices for the non-cached dimensions - confirm.
            self.istart = index
            key[self.idim] = slice(index, index + self.n)
            self._cache = self._source[tuple(key)]

        # Translate the requested index to a position within the window
        key[self.idim] = index - self.istart
        return self._cache[tuple(key)]

    def __array__(self, dtype=None, copy=None) -> np.ndarray:
        """Return the complete source as array, bypassing the cache."""
        return np.asarray(self._source, dtype=dtype)
class TemporalInterpolation(UnaryOperator):
    """Lazy array that interpolates its source linearly in time.

    The time dimension of the source is removed from the shape of this array.
    :meth:`update` moves the array to a new time: it reads the source records
    that bracket that time and stores the interpolated values in an internal
    buffer. That buffer is what :meth:`__array__` and :meth:`__getitem__`
    return.
    """

    __slots__ = (
        "_current",
        "_itimedim",
        "_numnow",
        "_numnext",
        "_slope",
        "_inext",
        "_next",
        "_slices",
        "climatology",
        "_year",
        "_timevalues",
    )

    # Maximum memory (in MB) that a climatology may take up for its records
    # to be kept in memory after they have been read. 0 disables this cache.
    MAX_CACHE_SIZE = 0

    # Number of consecutive time records to read from the source at once.
    # Values above 1 cause the source to be wrapped in a Cache.
    NTIME_CACHE = 1

    def __init__(
        self,
        source: LazyArray,
        itimedim: int,
        times: npt.ArrayLike,
        climatology: bool,
        dtype: npt.DTypeLike = float,
        comm: MPI.Comm = MPI.COMM_SELF,
        logger: Optional[logging.Logger] = None,
        **kwargs,
    ):
        """
        Args:
            source: lazy array to interpolate in time
            itimedim: index of the time dimension of ``source``
            times: time associated with each index along the time dimension
            climatology: whether ``source`` describes a single year that is to
                be reused for every year. All ``times`` must then fall within
                the same calendar year.
            dtype: data type of the interpolated values
            comm: MPI communicator used to take the decision whether to cache
                a climatology in memory collectively
            logger: logger used to report that caching decision
            **kwargs: additional keyword arguments passed to the base class

        Raises:
            Exception: if the time dimension has length 1 or less, or if
                ``climatology`` is set while ``times`` span multiple years
        """
        if self.NTIME_CACHE > 1:
            source = Cache(source, itimedim, self.NTIME_CACHE)
        shape = list(source.shape)
        self._itimedim = itimedim
        ntime = shape[self._itimedim]
        if ntime <= 1:
            raise Exception(
                f"Cannot interpolate {source.name} in time because"
                f" its time dimension has length {ntime}."
            )
        # The interpolated array has no time dimension
        shape.pop(self._itimedim)
        super().__init__(source, shape=shape, dtype=dtype, **kwargs)

        # Buffer that holds the values interpolated to the current time
        self._current = np.empty(shape, dtype=self.dtype)

        self.times = np.asarray(times)

        # Scalar time coordinate; its values are overwritten in-place
        # by update (through _timevalues).
        self._timecoord = xr.DataArray(self.times[0])
        self._timevalues = self._timecoord.values

        # Numeric time of the last update (None: update has not been called yet)
        self._numnow = None

        # Numeric time, time index and values of the right bound of the current
        # time window, and the rate of change of the values within that window
        self._numnext = 0.0
        self._slope = 0.0
        self._inext = -1
        self._next = 0.0

        # Slices used for reading a single time record from the source;
        # the entry for the time dimension is set by _move_to_next
        self._slices: list[Union[int, slice]] = [slice(None)] * source.ndim

        self.climatology = climatology
        self._year = self.times[0].year

        # Records kept in memory (time index -> values), only filled
        # if _use_cache is True
        self._cache = {}
        self._use_cache = False

        if climatology and not all(time.year == self._year for time in self.times):
            raise Exception(
                f"{self._source_name} cannot be used as climatology because"
                " it spans more than one calendar year"
                f" ({self.times[0]} - {self.times[-1]})."
            )

        if climatology and self.MAX_CACHE_SIZE > 0:
            # Memory needed to hold the complete source, in MB
            memory = np.asarray(source.size * (self.dtype.itemsize / 1024 / 1024))

            # Ensure all subdomains take same caching decision
            # In part because performance generally does not improve until all
            # subdomains cache, but foremost because underlying LazyArrays may
            # use collective MPI operations in which all subdomains MUST participate
            comm.Allreduce(MPI.IN_PLACE, memory, MPI.MAX)
            self._use_cache = memory < self.MAX_CACHE_SIZE
            if logger:
                if self._use_cache:
                    logger.info(
                        f"{source.name} is a climatology that will be cached in memory."
                        f" This requires {memory:.1f} MB, which is less than the"
                        f" maximum allowed cache size of {self.MAX_CACHE_SIZE} MB"
                    )
                else:
                    logger.info(
                        f"Not caching {source.name} because this would require"
                        f" {memory:.1f} MB, which is more than the maximum allowed"
                        f" cache size of {self.MAX_CACHE_SIZE} MB"
                    )

    def __array__(self, dtype=None, copy=None) -> np.ndarray:
        """Return the buffer with values at the time of the last update.
        Arguments ``dtype`` and ``copy`` are not used.
        """
        return self._current

    def __getitem__(self, slices) -> np.ndarray:
        """Return part of the buffer with values at the time of the last update."""
        return self._current[slices]

    def is_time_varying(self) -> bool:
        """Return ``True``: this array changes whenever :meth:`update` is called
        with a new time.
        """
        return True

    def update(
        self, time: cftime.datetime, numtime: Optional[np.longdouble] = None
    ) -> bool:
        """Interpolate to the specified time.

        Args:
            time: time to interpolate to
            numtime: numeric representation of ``time``. If not provided, it is
                set to ``time.toordinal(fractional=True)``.

        Returns:
            ``False`` if ``numtime`` equals that of the previous call, in which
            case nothing is done; ``True`` otherwise.
        """
        if numtime is None:
            numtime = time.toordinal(fractional=True)

        if self._numnow is None:
            # First call to update
            self._start(time)
        elif numtime <= self._numnow:
            # Subsequent call to update, but time has not increased
            # If equal to previous time, we are done. If smaller, rewind
            if numtime == self._numnow:
                return False
            self._start(time)

        # Advance until the right bound of the time window
        # lies at or beyond the requested time
        while self._numnext < numtime:
            self._move_to_next(time)

        # Do linear interpolation
        if numtime == self._numnext:
            self._current[...] = self._next
        else:
            # current = next + slope * (numtime - numnext), computed in-place.
            # Here numtime < numnext, so this steps back from the right bound.
            np.multiply(self._slope, numtime - self._numnext, out=self._current)
            self._current += self._next

        # Save current time
        self._numnow = numtime
        self._timevalues[...] = time
        return True

    def _start(self, time: cftime.datetime):
        """Load the time window that encompasses ``time``, starting from scratch.

        Raises:
            Exception: if the calendar of ``time`` cannot be converted to the
                calendar of the time series, or if - for a non-climatology -
                the time series starts after ``time``
        """
        if time.calendar != self.times[0].calendar:
            try:
                time = time.change_calendar(self.times[0].calendar)
            except ValueError:
                raise Exception(
                    f"Calendar {self.times[0].calendar} used by {self._source_name}"
                    f" is incompatible with simulation calendar {time.calendar}."
                )

        # Find the time index (_inext) just before the left bound of the window we need.
        # We will subsequently call _move_to_next twice to load the actual window.
        # If a time in the series exactly matches the requested time, this will be the
        # right bound of the time window (hence side="left" below), except if that
        # is the very first time in the time series. Then it will be the left bound.
        if self.climatology:
            # Map the requested time to the year of the climatology
            clim_time = time.replace(year=self._year)
            self._inext = self.times.searchsorted(clim_time, side="left") - 2
            self._year = time.year
            if self._inext < -1:
                # The left bound of the window is the last record
                # of the preceding year
                self._inext += self.times.size
                self._year -= 1
        else:
            # Make sure the time series does not start after the requested time.
            if time < self.times[0]:
                raise Exception(
                    f"Cannot interpolate {self._source_name} to value at {time},"
                    f" because time series starts only at {self.times[0]}."
                )
            self._inext = max(-1, self.times.searchsorted(time, side="left") - 2)

        # Load left and right bound of the window encompassing the current time.
        # After that, all information for linear interpolation (_next, _slope)
        # is available.
        self._move_to_next(time)
        self._move_to_next(time)

    def _move_to_next(self, time: cftime.datetime):
        """Shift the time window forward by one record.

        The current right bound becomes the left bound; the next record is read
        from the source (or taken from the in-memory cache) and becomes the new
        right bound. ``time`` is used only for the error message.

        Raises:
            Exception: if the end of a non-climatological time series is reached
        """
        # Move to next record
        self._inext += 1
        if self._inext == self.times.size:
            if self.climatology:
                # Wrap around to the first record, which then applies
                # to the next year
                self._inext = 0
                self._year += 1
            else:
                raise Exception(
                    f"Cannot interpolate {self._source_name} to value at {time}"
                    f" because end of time series was reached ({self.times[-1]})."
                )

        # Values and numeric time of the left bound of the new window
        old = self._next
        numold = self._numnext

        if self._inext in self._cache:
            self._next = self._cache[self._inext]
        else:
            self._slices[self._itimedim] = self._inext
            self._next = np.asarray(self._source[tuple(self._slices)], dtype=self.dtype)
            if self._use_cache:
                self._cache[self._inext] = self._next

        next_time = self.times[self._inext]
        if self.climatology:
            next_time = next_time.replace(year=self._year)
        self._numnext = next_time.toordinal(fractional=True)

        # Rate of change between the left and right bound of the window
        self._slope = (self._next - old) / (self._numnext - numold)
def slicespec2string(s: Union[tuple, slice, int]) -> str:
    """Format an index specification the way it would be written between
    square brackets.

    Slices are written as ``start:stop`` or - if they have a step other than
    1 - as ``start:stop:step``, with empty strings for absent bounds. The
    items of a tuple are formatted individually and joined with commas.
    Anything else is represented by its ``repr``.
    """
    if isinstance(s, tuple):
        return ",".join(slicespec2string(item) for item in s)
    if not isinstance(s, slice):
        return repr(s)

    parts = ["" if bound is None else format(bound) for bound in (s.start, s.stop)]
    if s.step not in (None, 1):
        parts.append(format(s.step))
    return ":".join(parts)
def debug_nc_reads(logger: Optional[logging.Logger] = None):
    """Hook into :mod:`xarray` so that every read from a NetCDF file
    is written to the log.

    Args:
        logger: logger to write debug messages to. If not provided, logger
            ``pygetm.input`` is used, with its level set to ``DEBUG``.
    """
    from xarray.backends import netCDF4_ as nc_backend

    if logger is None:
        logger = logging.getLogger("pygetm.input")
        logger.setLevel(logging.DEBUG)

    class NetCDF4ArrayWrapper2(nc_backend.NetCDF4ArrayWrapper):
        __slots__ = ()

        def _getitem(self, key):
            # Report which part of which variable is about to be read
            spec = slicespec2string(key)
            logger.debug(
                f"Reading {self.variable_name}[{spec}]"
                f" from {self.datastore._filename}"
            )
            return super()._getitem(key)

    # Replace the wrapper class that xarray looks up in its netCDF4 backend
    nc_backend.NetCDF4ArrayWrapper = NetCDF4ArrayWrapper2
class OnGrid(enum.Enum):
    """Extent to which the grid of an input matches the grid of the array
    that the input is assigned to.
    """

    #: Grids do not match. Spatially explicit data will require horizontal
    #: and - if vertically resolved - vertical interpolation.
    NONE = 1

    #: Horizontal grid matches, but vertical does not.
    #: Vertically resolved data will require vertical interpolation.
    HORIZONTAL = 2

    #: Horizontal and vertical grids match
    ALL = 3
class InputManager:
    """Links model arrays to their inputs and keeps the arrays that are linked
    to time-varying inputs up to date.
    """

    def __init__(self, logger: logging.Logger):
        """
        Args:
            logger: logger for messages about the inputs that are assigned
        """
        # (array name, time-varying source, target values) for every
        # time-varying input registered through add
        self._all_fields: list[tuple[str, LazyArray, np.ndarray]] = []

        # Subset of the above that is updated also when update is called
        # with macro=False
        self._micro_fields: list[tuple[str, LazyArray, np.ndarray]] = []

        self.logger = logger

    def debug_nc_reads(self):
        """Hook into :mod:`xarray` so that every read from a NetCDF file
        is written to the log.
        """
        nc_logger = self.logger.getChild("nc")
        nc_logger.setLevel(logging.DEBUG)
        debug_nc_reads(nc_logger)

    def add(
        self,
        array: pygetm.core.Array,
        value: Union[numbers.Number, np.ndarray, xr.DataArray, LazyArray],
        periodic_lon: bool = True,
        on_grid: Union[bool, OnGrid] = False,
        include_halos: Optional[bool] = None,
        climatology: bool = False,
        mask: bool = False,
        updater_collection: Optional[list] = None,
    ):
        """Link an array to the provided input. If this input is constant in
        time, the value of the array will be set immediately.

        Args:
            array: array to assign a value to
            value: input to assign. If this is time-dependent, the combination
                of the array and its linked input will be registered; the array
                will then be updated to the current time whenever
                :meth:`update` is called.
            periodic_lon: whether this input covers all longitudes (i.e., the
                entire globe in the horizontal) and therefore has a periodic
                boundary. This enables efficient spatial interpolation across
                longitude bounds of the input, for instance, accessing
                10 degrees West to 5 degrees East for an input that spans
                0 to 360 degrees East.
            on_grid: whether the input is defined on the same grid
                (horizontal-only, or both horizontal and vertical) as the array
                that is being assigned to. If this is ``False``, the value will
                be spatially interpolated to the array grid. ``True`` is
                equivalent to :attr:`OnGrid.HORIZONTAL`.
            include_halos: whether to also update the halos of the array. If
                not provided, this defaults to ``True`` if the array has
                attributes ``_require_halos`` or ``_part_of_state``; otherwise
                it defaults to ``False``.
            climatology: whether the input describes a single climatological
                year (at any temporal resolution, e.g., monthly, daily) that is
                representative for any true year. This argument is relevant
                only if the provided input is time-varying. It also requires
                that the input does not span more than one year.
            mask: whether to set the array to its
                :attr:`pygetm.core.Array.fill_value` in all masked points.
                If not provided, only missing values in the input (NaNs) will
                be set to the fill value. This currently only has an effect
                when the input is non time-varying.
            updater_collection: if provided and the input is time-varying, the
                tuple (array name, time-varying source, target values) is
                appended to this list instead of being registered with the
                input manager. These tuples can be passed to :meth:`update`
                through its ``fields`` argument.

        Raises:
            Exception: if the shape of the input is incompatible with the
                model grid or its open boundaries
            AssertionError: if ``value`` is neither numeric nor a
                :class:`xarray.DataArray`, or if the shape of the processed
                input cannot be broadcast to that of the target
        """
        if array.all_values is None or array.all_values.size == 0:
            # The target variable does not contain data. Typically this is because
            # it specifies information on the open boundaries,
            # of which the current (sub)domain does not have any.
            self.logger.warning(
                f"Ignoring assignment to array {array.name}"
                " because it has no associated data."
            )
            return

        if include_halos is None:
            include_halos = array.attrs.get("_require_halos", False) or array.attrs.get(
                "_part_of_state", False
            )
        if not isinstance(on_grid, OnGrid):
            on_grid = OnGrid.HORIZONTAL if on_grid else OnGrid.NONE

        grid = array.grid

        # Obtain active area of local subdomain (including halos if
        # include_halos is True) and the corresponding slice in the global domain
        # (always excluding halos)
        (
            local_slice,
            global_slice,
            local_shape,
            global_shape,
        ) = grid.tiling.subdomain2slices(
            exclude_halos=not include_halos, halox_sub=grid.halox, haloy_sub=grid.haloy
        )

        def _map_to_grid():
            # The input is already on-grid.
            # Give custom grid mappers the first chance to map it; the first
            # one that returns a result wins.
            for grid_mapper in grid.input_grid_mappers:
                mapped_value = grid_mapper(value)
                if mapped_value is not None:
                    return mapped_value

            # default grid mapping: from global domain to subdomain
            if value.shape[-2:] == global_shape:
                slc = global_slice
            elif value.shape[-2:] == local_shape:
                slc = local_slice
            else:
                raise Exception(
                    f"{array.name}: trailing shape of values {value.shape[-2:]}"
                    f" should match that of global domain {global_shape}"
                    f" or local subdomain {local_shape}"
                )
            if isinstance(value, xr.DataArray):
                # lazy slice
                indices = {value.dims[-2]: slc[-2], value.dims[-1]: slc[-1]}
                return isel(value, **indices)
            # immediate slice
            return value[slc]

        if isinstance(value, (numbers.Number, np.ndarray)):
            if not array.on_boundary and array.ndim != 0 and on_grid != OnGrid.NONE:
                value = _map_to_grid()

            # Constant-in-time fill value. Set it, then forget about the array
            # as it will not require further updating.
            array.fill(value)
            return

        assert isinstance(value, xr.DataArray), (
            "If value is not numeric, it should be an xarray.DataArray,"
            f" but it is {value!r}."
        )

        target_slice = (Ellipsis,)
        source_lon, source_lat = value.getm.longitude, value.getm.latitude
        if array.on_boundary:
            # Open boundary information. This can either be specified for the global
            # domain (e.g., when read from NetCDF), or for only the open boundary
            # points that fall within the local subdomain. Determine which of these.
            if value.ndim >= 2 and value.shape[-2:] == global_shape:
                # on-grid data for the global domain:
                # extract data at open boundary points
                i_bnd = grid.open_boundaries.i_glob
                j_bnd = grid.open_boundaries.j_glob
                value = isel(value, **{value.dims[-1]: i_bnd, value.dims[-2]: j_bnd})
            elif np.ndim(source_lon) > 0 and np.ndim(source_lat) > 0:
                # Spatially explicit input:
                # interpolate horizontally to open boundary coordinates
                if source_lon.ndim != 1:
                    raise Exception(
                        f"Unsuitable shape {source_lon.shape} of longitude coordinate"
                        f" {source_lon.name}. Off-grid boundary information can be used"
                        f" only if its longitude is 1D."
                    )
                if source_lat.ndim != 1:
                    raise Exception(
                        f"Unsuitable shape {source_lat.shape} of latitude coordinate"
                        f" {source_lat.name}. Off-grid boundary information can be used"
                        f" only if its latitude is 1D."
                    )
                ilondim = value.dims.index(source_lon.dims[0])
                ilatdim = value.dims.index(source_lat.dims[0])
                if ilondim != ilatdim:
                    # Longitude and latitude are separate dimensions:
                    # interpolate from that grid to the open boundary points
                    lon_bnd = grid.open_boundaries.lon.all_values
                    lat_bnd = grid.open_boundaries.lat.all_values
                    value = limit_region(
                        value,
                        lon_bnd.min(),
                        lon_bnd.max(),
                        lat_bnd.min(),
                        lat_bnd.max(),
                        periodic_lon=periodic_lon,
                    )
                    ip_mask = None
                    if value.getm.time is not None and not array.z:
                        # Depth-independent and time-varying: treat points that
                        # are NaN at the first time as masked when interpolating
                        itimedim = value.dims.index(value.getm.time.dims[0])
                        slc: list[Union[int, slice]] = [slice(None)] * value.ndim
                        slc[itimedim] = 0
                        ip_mask = np.isnan(value[tuple(slc)])
                    value = pygetm.input.horizontal_interpolation(
                        value, lon_bnd, lat_bnd, mask=ip_mask
                    )

            if array.z and value.getm.z is not None and value.getm.z.ndim == 1:
                # Source and target arrays are depth-explicit and the source depth is 1D
                # Ensure it is the last (fastest varying) dimension
                izdim = value.dims.index(value.getm.z.dims[0])
                if izdim != value.ndim - 1:
                    axes = list(range(value.ndim))
                    axes.append(axes.pop(izdim))
                    value = transpose(value, axes)

            # Index of the dimension that enumerates open boundary points
            idim = value.ndim - (2 if array.z else 1)
            if value.shape[idim] == grid.open_boundaries.np_glob:
                # The source array covers all open boundaries (global domain).
                # If the subdomain only has a subset of those, slice out only the points
                # that fall within the current subdomain
                local_to_global = grid.open_boundaries.local_to_global
                if local_to_global:
                    value = concatenate_slices(value, idim, local_to_global)
            elif value.shape[idim] != grid.open_boundaries.np:
                raise Exception(
                    f"Extent of dimension {idim} of {value.name} is not compatible with"
                    f" open boundaries. It should have length"
                    f" {grid.open_boundaries.np_glob} (number of open boundary"
                    f" points in the global domain) or {grid.open_boundaries.np}"
                    f" (number of open boundary points in the current subdomain)."
                    f" Its actual extent is {value.shape[idim]}."
                )
        elif array.ndim != 0:
            # The target is a normal 2D (horizontal-only) or 3D (depth-explicit) array
            # The source data can either be on the native model grid, or at an
            # arbitrary lon, lat grid. In the latter case, we interpolate in space.
            assert array.all_values.shape[-2:] == local_shape
            target_slice = local_slice
            if on_grid != OnGrid.NONE:
                # the input is already on-grid
                value = _map_to_grid()
            elif np.ndim(source_lon) == 0 and np.ndim(source_lat) == 0:
                # time series for single location
                value = value.expand_dims(("y", "x"), (value.ndim, value.ndim + 1))
            else:
                # interpolate horizontally to local array INCLUDING halos
                lon = grid.lon.all_values[target_slice]
                lat = grid.lat.all_values[target_slice]
                assert not np.isnan(lon).any()
                assert not np.isnan(lat).any()
                value = limit_region(
                    value,
                    lon.min(),
                    lon.max(),
                    lat.min(),
                    lat.max(),
                    periodic_lon=periodic_lon,
                )
                value = horizontal_interpolation(value, lon, lat)

        if value.getm.time is not None:
            # The source data is time-dependent; during the simulation it will be
            # interpolated in time.
            if value.getm.time.size > 1:
                value = temporal_interpolation(
                    value,
                    climatology=climatology,
                    comm=grid.tiling.comm,
                    logger=self.logger,
                )
            elif value.getm.time.dims:
                # A time dimension of length 1: drop it by selecting its only entry
                time = value.getm.time.values.flat[0]
                self.logger.warning(
                    f"{array.name} is set to {value.name}, which has only one time"
                    f" point {time}. The value from this time will be used now."
                    f" {array.name} will not be further updated by the input manager"
                    " at runtime."
                )
                itimedim = value.dims.index(value.getm.time.dims[0])
                slc = [slice(None)] * value.ndim
                slc[itimedim] = 0
                value = value[tuple(slc)]

        if array.z and on_grid != OnGrid.ALL:
            # The target is a depth-explicit array.
            # The source must be defined on z coordinates
            # and interpolated to our [time-varying] depths
            coord_source = grid.open_boundaries if array.on_boundary else grid
            z_coordinate = coord_source.zc if array.z == CENTERS else coord_source.zf
            z_coordinate.saved = True
            value = vertical_interpolation(
                value,
                z_coordinate.all_values[target_slice],
                itargetdim=1 if array.on_boundary else 0,
            )

        target = array.all_values[target_slice]
        try:
            np.broadcast_shapes(value.shape, target.shape)
        except ValueError:
            # Raised explicitly rather than through an assert statement, so that
            # the check is not skipped when Python runs with optimization (-O)
            raise AssertionError(
                f"Source shape {value.shape} does not match target shape {target.shape}"
            )

        data = value.variable._data
        if isinstance(data, LazyArray) and data.is_time_varying():
            # Time-varying input: register it, so that the target can be
            # brought up to date at runtime.
            time_varying = array.attrs.get("_time_varying", TimeVarying.MICRO)
            suffix = " on macrotimestep" if time_varying == TimeVarying.MACRO else ""
            self.logger.info(
                f"{array.name} will be updated dynamically from {data.name}{suffix}"
            )
            info = (array.name, data, target)
            if updater_collection is not None:
                updater_collection.append(info)
            else:
                self._all_fields.append(info)
                if time_varying == TimeVarying.MICRO:
                    self._micro_fields.append(info)
        else:
            # Time-invariant input: assign it now
            target[...] = value
            finite = np.isfinite(target)
            unmasked = ~array.all_mask[target_slice]
            if array.fill_value is not None:
                # Fill masked points with the fill value. Either we fill all masked
                # points (if mask=True) or only those that are not finite.
                keep_values = unmasked if mask else unmasked | finite
                target[~keep_values] = array.fill_value
            if not finite.all(where=unmasked):
                n_unmasked = unmasked.sum()
                n_bad = n_unmasked - finite.sum(where=unmasked)
                self.logger.warning(
                    f"{array.name} is set to {value.name}, which is not finite"
                    f" (e.g., NaN) in {n_bad} of {n_unmasked} unmasked points."
                )
            minval = target.min(where=unmasked, initial=np.inf)
            maxval = target.max(where=unmasked, initial=-np.inf)
            self.logger.info(
                f"{array.name} is set to time-invariant {value.name}"
                f" (minimum: {minval}, maximum: {maxval})"
            )

    def update(
        self, time: cftime.datetime, macro: bool = True, fields: Optional[list] = None
    ):
        """Update all arrays linked to time-dependent inputs to the current time.

        Args:
            time: current time
            macro: whether to also update arrays that were marked as only
                relevant for the macro (3D) time step
            fields: tuples (array name, time-varying source, target values) to
                update, for instance, a list that was provided to :meth:`add`
                as ``updater_collection``. If not provided, the inputs
                registered with this manager are updated, as selected by
                ``macro``.
        """
        # Convert to numeric time once, rather than separately in every source
        numtime = time.toordinal(fractional=True)
        if fields is None:
            fields = self._all_fields if macro else self._micro_fields
        for name, source, target in fields:
            self.logger.debug(f"updating {name}")
            source.update(time, numtime)
            target[...] = source