import numbers
import operator
from typing import Optional, Union, Literal, Any, TYPE_CHECKING
import logging
import functools
from collections.abc import Iterable, Mapping, Sequence, Callable
import numpy as np
import numpy.lib.mixins
import numpy.typing as npt
import xarray as xr
from . import _pygetm
from . import parallel
from .constants import CENTERS, INTERFACES, FILL_VALUE, CoordinateType, CellType
if TYPE_CHECKING:
import pygetm.open_boundaries
import pygetm.rivers
def _noop(*args, **kwargs):
pass
class Rotator:
    """Rotate velocity vectors between geocentric and model (grid) axes."""

    __slots__ = ("_sin", "_cos")

    def __init__(self, rotation: npt.ArrayLike):
        """Rotator for velocity fields (geocentric to model-centric and vice versa)

        Args:
            rotation: clockwise rotation of grid relative to true North (radians).
                May be a scalar or an array; arrays must broadcast against the
                velocity components passed to :meth:`__call__`.
        """
        rotation = np.asarray(rotation)
        self._sin = np.sin(rotation)
        # np.cos returns a NumPy *scalar* for 0-d input, which does not support
        # item assignment; wrap it in an array so the fix-up below also works
        # for a scalar rotation.
        self._cos = np.array(np.cos(rotation))
        # hardcode cos(0.5*pi)=0 to increase precision in 90 degree rotation tests
        self._cos[rotation == 0.5 * np.pi] = 0.0

    def __call__(
        self, u: npt.ArrayLike, v: npt.ArrayLike, to_grid: bool = True
    ) -> tuple[npt.ArrayLike, npt.ArrayLike]:
        """Rotate the vector (u, v).

        Args:
            u: x component in the source coordinate system
            v: y component in the source coordinate system
            to_grid: if True rotate clockwise (geocentric to grid), otherwise
                counterclockwise (grid to geocentric)

        Returns:
            the rotated components (u_new, v_new)
        """
        if to_grid:
            # clockwise
            u_new = u * self._cos + v * self._sin
            v_new = v * self._cos - u * self._sin
        else:
            # counterclockwise
            u_new = u * self._cos - v * self._sin
            v_new = u * self._sin + v * self._cos
        return u_new, v_new
class Locator:
    """Find the grid cell nearest to a given location."""

    def __init__(
        self,
        mask: np.ndarray,
        *,
        x: Optional[np.ndarray] = None,
        y: Optional[np.ndarray] = None,
        lon: Optional[np.ndarray] = None,
        lat: Optional[np.ndarray] = None,
    ):
        """
        Args:
            mask: 2D array with the cell type of every grid cell
            x: x coordinates of the cells (required for ``CoordinateType.XY``);
                presumably the same shape as ``mask`` - TODO confirm
            y: y coordinates of the cells (required for ``CoordinateType.XY``)
            lon: longitudes of the cells (required for ``CoordinateType.LONLAT``)
            lat: latitudes of the cells (required for ``CoordinateType.LONLAT``)
        """
        assert mask.ndim == 2
        self.mask = mask
        self.x = x
        self.y = y
        self.lon = lon
        self.lat = lat

    def __call__(
        self,
        x: npt.ArrayLike,
        y: npt.ArrayLike,
        *,
        coordinate_type: CoordinateType,
        valid_cell_types: Optional[Iterable[CellType]] = None,
    ) -> tuple[int, int]:
        """Locate the grid cell nearest to the specified location(s).

        Distance is the Euclidean distance in the selected coordinate space
        (for ``LONLAT`` this is measured in degrees).

        Args:
            x: x coordinate of the location (scalar or array)
            y: y coordinate of the location (scalar or array; broadcast with x)
            coordinate_type: type of coordinates
                (LONLAT for spherical, XY for Cartesian coordinates; any other
                value interprets x and y as fractional cell indices)
            valid_cell_types: if given, only cells whose type in ``mask`` is one
                of these are candidates. If None, every cell is a candidate,
                including masked ones.

        Returns:
            (i, j): indices of the nearest cell; Python ints for scalar input,
            integer arrays of the broadcast shape of x and y otherwise
        """
        if valid_cell_types is not None:
            # Materialize once: the argument may be a one-shot iterator, and it
            # is consumed twice below (masking and the final sanity check).
            valid_cell_types = list(valid_cell_types)

        if coordinate_type == CoordinateType.LONLAT:
            assert self.lon is not None and self.lat is not None
            allx, ally = self.lon, self.lat
        elif coordinate_type == CoordinateType.XY:
            assert self.x is not None and self.y is not None
            allx, ally = self.x, self.y
        else:
            allx = np.arange(self.mask.shape[-1], dtype=float)[np.newaxis, :]
            ally = np.arange(self.mask.shape[-2], dtype=float)[:, np.newaxis]

        # Location is specified by x, y coordinate.
        # Look up nearest grid cell (restricted to valid cell types if given).
        x, y = np.broadcast_arrays(x, y)
        dist = (allx - x[..., np.newaxis, np.newaxis]) ** 2 + (
            ally - y[..., np.newaxis, np.newaxis]
        ) ** 2
        if valid_cell_types is not None:
            invalid = ~np.isin(self.mask, valid_cell_types)
            dist[..., invalid] = np.inf
        flat_dist = np.reshape(dist, x.shape + (-1,))
        idx = np.nanargmin(flat_dist, axis=-1)
        j, i = np.unravel_index(idx, self.mask.shape)
        assert (
            valid_cell_types is None or np.isin(self.mask[j, i], valid_cell_types).all()
        )
        if i.ndim == 0:
            return int(i), int(j)
        return i, j
class Grid(_pygetm.Grid):
    """Model grid at one staggered position (e.g. T, U, V or X points).

    Holds metric, bathymetry and depth arrays, plus references to the related
    staggered grids (``ugrid``, ``vgrid``, ``xgrid``).
    """

    # Arrays that describe the domain; freeze() makes them read-only.
    _domain_arrays = (
        "x",
        "y",
        "lon",
        "lat",
        "dx",
        "dy",
        "area",
        "cor",
        "H",
        "mask",
        "z0b_min",
    )
    # Inverse metrics, computed from dx, dy and area in freeze().
    _derived_metric_arrays = ("idx", "idy", "iarea")
    # Arrays whose all_values are flagged non-writeable by freeze().
    _readonly_arrays = _domain_arrays + _derived_metric_arrays
    # 2D arrays created by freeze() via create_array().
    _new_2d_arrays = ("D", "z0b", "alpha") + _derived_metric_arrays
    # 3D arrays created by freeze(); skipped by create_array() when nz is 0/None.
    _new_3d_arrays = ("hn", "zc", "zf")
    # Arrays whose storage is wrapped through self.wrap() in create_array();
    # presumably memory shared with the compiled (Fortran) core - TODO confirm.
    _fortran_arrays = _domain_arrays + _new_2d_arrays + _new_3d_arrays
    # Keyword arguments passed to Array for each named grid array. In "attrs",
    # "_valid_at" lists cell types (in addition to ACTIVE) where values are
    # considered valid (see Array.all_mask); "_time_varying" is metadata that is
    # presumably consumed by output code outside this class - TODO confirm.
    _array_args = {
        "x": dict(
            units="m",
            long_name="x-coordinate",
            attrs=dict(
                standard_name="projection_x_coordinate", axis="X", _time_varying=False
            ),
        ),
        "y": dict(
            units="m",
            long_name="y-coordinate",
            attrs=dict(
                standard_name="projection_y_coordinate", axis="Y", _time_varying=False
            ),
        ),
        "lon": dict(
            units="degrees_east",
            long_name="longitude",
            attrs=dict(standard_name="longitude", axis="X", _time_varying=False),
        ),
        "lat": dict(
            units="degrees_north",
            long_name="latitude",
            attrs=dict(standard_name="latitude", axis="Y", _time_varying=False),
        ),
        "dx": dict(
            units="m",
            long_name="cell length in x-direction",
            attrs=dict(
                _time_varying=False, _valid_at=(CellType.BOUNDARY, CellType.EDGE_X)
            ),
        ),
        "dy": dict(
            units="m",
            long_name="cell length in y-direction",
            attrs=dict(
                _time_varying=False, _valid_at=(CellType.BOUNDARY, CellType.EDGE_Y)
            ),
        ),
        "idx": dict(
            units="m-1",
            long_name="inverse of cell length in x-direction",
            attrs=dict(_time_varying=False),
        ),
        "idy": dict(
            units="m-1",
            long_name="inverse of cell length in y-direction",
            attrs=dict(_time_varying=False),
        ),
        "H": dict(
            units="m",
            long_name="water depth at rest",
            attrs=dict(
                _time_varying=False,
                standard_name="sea_floor_depth_below_geopotential_datum",
                _valid_at=(
                    CellType.BOUNDARY,
                    CellType.MIRROR_INT,
                    CellType.MIRROR_EXT,
                ),
            ),
        ),
        "D": dict(
            units="m",
            long_name="water depth",
            attrs=dict(
                standard_name="sea_floor_depth_below_sea_surface",
                _valid_at=(
                    CellType.BOUNDARY,
                    CellType.MIRROR_INT,
                    CellType.MIRROR_EXT,
                    CellType.EDGE_X,  # must be non-zero for vel=transport/D
                    CellType.EDGE_Y,  # must be non-zero for vel=transport/D
                ),
            ),
        ),
        "mask": dict(
            long_name="mask",
            attrs=dict(
                _time_varying=False,
                _valid_at=(
                    CellType.BOUNDARY,
                    CellType.MIRROR_INT,
                    CellType.MIRROR_EXT,
                    CellType.EDGE_X,
                    CellType.EDGE_Y,
                ),
            ),
            fill_value=0,
        ),
        "area": dict(
            units="m2",
            long_name="cell area",
            attrs=dict(standard_name="cell_area", _time_varying=False),
        ),
        "iarea": dict(
            units="m-2",
            long_name="inverse of cell area",
            attrs=dict(_time_varying=False),
        ),
        "cor": dict(
            units="s-1",
            long_name="Coriolis parameter",
            attrs=dict(standard_name="coriolis_parameter", _time_varying=False),
        ),
        "hn": dict(
            units="m",
            long_name="cell thickness",
            attrs=dict(
                standard_name="cell_thickness",
                _valid_at=(
                    CellType.BOUNDARY,
                    CellType.MIRROR_INT,
                    CellType.MIRROR_EXT,
                    CellType.EDGE_X,  # must be non-zero for vel=transport/D
                    CellType.EDGE_Y,  # must be non-zero for vel=transport/D
                ),
            ),
        ),
        "zc": dict(
            units="m",
            long_name="height",
            attrs=dict(
                axis="Z", positive="up", standard_name="height_above_geopotential_datum"
            ),
        ),
        "zf": dict(
            units="m",
            long_name="interface height",
            attrs=dict(
                axis="Z", positive="up", standard_name="height_above_geopotential_datum"
            ),
        ),
        "z0b": dict(units="m", long_name="hydrodynamic bottom roughness"),
        "z0b_min": dict(
            units="m",
            long_name="minimum hydrodynamic bottom roughness",
            attrs=dict(_time_varying=False),
        ),
        "alpha": dict(units="1", long_name="dampening"),
        "rotation": dict(
            units="rad",
            long_name="grid rotation with respect to true North",
            fill_value=np.nan,
            attrs=dict(_time_varying=False),
        ),
    }
    # Private slot name ("_" + array name) for every entry in _array_args.
    _all_arrays = tuple(f"_{n}" for n in _array_args)
    __slots__ = _all_arrays + (
        "ioffset",
        "joffset",
        "postfix",
        "ugrid",
        "vgrid",
        "xgrid",
        "_rotator",
        "Dclip",
        "z",
        "zo",
        "zin",
        "zio",
        "ho",
        "hhalf",
        "open_boundaries",
        "input_manager",
        "default_output_transforms",
        "input_grid_mappers",
        "rivers",
        "overlap",
        "_interpolators",
        "_interior",
        "_edge_x",
        "_edge_y",
        "_land",
        "_land3d",
        "_land3d_if",
        "_water",
        "_water_nohalo",
        "horizontal_coordinates",
        "extra_output_coordinates",
        "_mirrors",
        "tiling",
        "fields",
        "mask3d",
        "bottom_indices",
        "_work",
        "_masks",
    )
    # Assigned after construction (not in __init__); annotated for type checkers.
    open_boundaries: "pygetm.open_boundaries.LocalOpenBoundaryCollection"
    rivers: "pygetm.rivers.LocalRiverCollection"
    def __init__(
        self,
        nx: int,
        ny: int,
        nz: Optional[int] = None,
        *,
        halox: int = 0,
        haloy: int = 0,
        postfix: str = "",
        ugrid: Optional["Grid"] = None,
        vgrid: Optional["Grid"] = None,
        xgrid: Optional["Grid"] = None,
        ioffset: int = 1,
        joffset: int = 1,
        overlap: int = 0,
        istart: int = 1,
        jstart: int = 1,
        fields: Optional[Mapping[str, "Array"]] = None,
        tiling: Optional[parallel.Tiling] = None,
    ):
        """Create a grid.

        Args:
            nx: number of interior points in x-direction (excluding halos)
            ny: number of interior points in y-direction (excluding halos)
            nz: number of vertical layers; None or 0 for a 2D-only grid
                (3D arrays are then not created)
            halox: halo width in x-direction
            haloy: halo width in y-direction
            postfix: suffix appended to the names of this grid's arrays
            ugrid: related grid used as target of x-gradients
            vgrid: related grid used as target of y-gradients
            xgrid: related grid, frozen together with this one
            ioffset: x-offset of this grid's points; differences in offsets
                between grids select the interpolation scheme
            joffset: y-offset of this grid's points (see ioffset)
            overlap: number of overlapping points between subdomains, passed on
                to halo updaters, gather and scatter
            istart: passed to the compiled base class
            jstart: passed to the compiled base class
            fields: mapping in which named arrays register themselves
            tiling: subdomain decomposition; may be None for a serial setup
        """
        super().__init__(nx, ny, nz or 0, halox, haloy, istart, jstart)
        self.postfix = postfix
        self.ioffset = ioffset
        self.joffset = joffset
        self.overlap = overlap
        self.ugrid = ugrid
        self.vgrid = vgrid
        self.xgrid = xgrid
        self.fields: Mapping[str, "Array"] = {} if fields is None else fields
        self.tiling = tiling
        # Index expression selecting interior (non-halo) points of any array
        # whose trailing dimensions are (y, x).
        self._interior = (Ellipsis, slice(haloy, haloy + ny), slice(halox, halox + nx))
        # Cache of interpolation functions, keyed by target grid.
        self._interpolators: dict["Grid", Callable[[np.ndarray, np.ndarray], None]] = {}
        # Per target grid: (source slice, target slice) used by Array.mirror.
        self._mirrors: dict[
            "Grid", Optional[tuple[tuple[slice, ...], tuple[slice, ...]]]
        ] = {}
        self.horizontal_coordinates: list["Array"] = []
        self.extra_output_coordinates = []
        # NaN-filled array on this grid; presumably a scratch buffer - TODO confirm.
        self._work = self.array(fill_value=np.nan)
        # Cache of read-only masks built by get_mask().
        self._masks: dict[frozenset[CellType], np.ndarray] = {}
        # Extra valid points OR'ed into masks that include EDGE_X/EDGE_Y
        # (see get_mask); False means no extra points.
        self._edge_x = False
        self._edge_y = False
[docs]
def create_array(self, name: str) -> Optional["Array"]:
assert name in self._array_args
if name in self._new_3d_arrays and not self.nz:
return None
kwargs = dict(fill_value=FILL_VALUE)
kwargs.update(self._array_args[name])
kwargs.setdefault("attrs", {}).setdefault("_valid_at", (CellType.BOUNDARY,))
if name in self._fortran_arrays:
array = Array(name=name + self.postfix, **kwargs)
self.wrap(array, name.encode("ascii"))
else:
array = Array.create(grid=self, name=name + self.postfix, **kwargs)
setattr(self, f"_{name}", array)
return array
[docs]
    def freeze(self):
        """Freeze all grid attributes. This will calculate derived metrics
        such as the inverse of cell height/width/area and initialize elevation
        and water depth. It subsequently makes most attributes read-only.

        Statement order matters: H must still be valid on land when zc/zf are
        initialized, and is masked only afterwards. Child grids (ugrid, vgrid,
        xgrid) are frozen last.
        """
        # Create the remaining 2D/3D arrays (3D only if the grid has layers).
        for name in self._new_2d_arrays + self._new_3d_arrays:
            self.create_array(name)
        # Arrays that were never set get an explicit None in their slot.
        for name in self._all_arrays:
            if not hasattr(self, name):
                setattr(self, name, None)
        # Inverse metrics; zero area/length silently yields inf.
        with np.errstate(divide="ignore"):
            self._iarea.all_values = 1.0 / self._area.all_values
            self._idx.all_values = 1.0 / self._dx.all_values
            self._idy.all_values = 1.0 / self._dy.all_values
        if self._rotation is not None:
            self._rotator = Rotator(self._rotation.all_values)
        else:
            self._rotator = None
        # True wherever the cell type is not ACTIVE/BOUNDARY/MIRROR_*.
        self._land = self.get_mask(
            (
                CellType.ACTIVE,
                CellType.BOUNDARY,
                CellType.MIRROR_INT,
                CellType.MIRROR_EXT,
            )
        )
        # True wherever the cell type is not UNRESOLVED.
        self._water = self.get_mask((CellType.UNRESOLVED,))
        # Same as _water, but False in the halos.
        self._water_nohalo = np.full_like(self._water, False)
        self._water_nohalo[self._interior] = self._water[self._interior]
        if self.nz:
            # Initialize center and interface depth coordinates to coincide with
            # bottom to ensure land points have a sensible value when plotting.
            # For this, H should be valid also on land, where it then represents
            # (negative) surface elevation. Thus, it needs to be done before H is
            # masked on land below.
            self.zc.all_values = -self.H.all_values
            self.zf.all_values = -self.H.all_values
        self.H.all_values[self.H.all_mask] = FILL_VALUE
        assert np.isfinite(self.H.all_values).all(where=~self.H.all_mask)
        self.z0b_min.all_values[self.z0b_min.all_mask] = FILL_VALUE
        for name in self._readonly_arrays:
            array = getattr(self, name, None)
            if array is not None:
                array.all_values.flags.writeable = False
        # Bottom roughness starts at its minimum value.
        self.z0b.all_values = self.z0b_min.all_values
        # Default water depth follows bathymetry (elevation=0)
        self.D.all_values[self._water] = self.H.all_values[self._water]
        # Dclip initially refers to the very same array as D.
        self.Dclip = self.D
        for child in (self.ugrid, self.vgrid, self.xgrid):
            if child:
                child.freeze()
[docs]
def close_flux_interfaces(self):
"""Mask U and V points that do not have two bordering wet T points"""
tdry = self.mask.all_values == CellType.UNRESOLVED
umask = tdry[:, :-1] | tdry[:, 1:]
vmask = tdry[:-1, :] | tdry[1:, :]
umask &= self.ugrid.mask.all_values[:, :-1] == CellType.ACTIVE
vmask &= self.vgrid.mask.all_values[:-1, :] == CellType.ACTIVE
self.ugrid.mask.all_values[:, :-1][umask] = CellType.UNRESOLVED
self.vgrid.mask.all_values[:-1, :][vmask] = CellType.UNRESOLVED
self.ugrid.mask.update_halos()
self.vgrid.mask.update_halos()
[docs]
def get_mask(
self, values: Iterable[CellType], z: Literal[None, CENTERS, INTERFACES] = None
) -> np.ndarray:
"""Get a Boolean mask that indicates where the cell type differs from
the specified values.
Args:
values: cell types that should not be masked
z: if CENTERS or INTERFACES, return a 3D mask for the specified
vertical coordinate type
"""
def _cache_mask(key, mask: np.ndarray):
mask.flags.writeable = False
self._masks[key] = mask
def _get_horizontal_mask(valid_mask_values) -> np.ndarray:
if valid_mask_values not in self._masks:
valid = np.isin(self.mask.all_values, list(valid_mask_values))
if CellType.EDGE_X in valid_mask_values:
valid |= self._edge_x
if CellType.EDGE_Y in valid_mask_values:
valid |= self._edge_y
_cache_mask(key=valid_mask_values, mask=~valid)
return self._masks[valid_mask_values]
def _make_3d_mask(horizontal_mask, z) -> np.ndarray:
if hasattr(self, "_land3d"):
att = "_land3d_if" if z == INTERFACES else "_land3d"
return horizontal_mask | getattr(self, att)
return horizontal_mask[np.newaxis, ...]
valid_mask_values = frozenset(values)
key = (valid_mask_values, z)
if key not in self._masks:
mask = _get_horizontal_mask(valid_mask_values)
if z:
mask = _make_3d_mask(mask, z)
_cache_mask(key=key, mask=mask)
return self._masks[key]
[docs]
def interpolator(self, target: "Grid") -> Callable[[np.ndarray, np.ndarray], None]:
ip = self._interpolators.get(target)
if ip:
return ip
def _assign(x, y, xslice, yslice):
y[yslice] = x[xslice]
# assert self.domain is target.domain
if self.ioffset == target.ioffset + 1 and self.joffset == target.joffset:
# from U to T
ip = functools.partial(_pygetm.interp_x, offset=1)
elif self.ioffset == target.ioffset - 1 and self.joffset == target.joffset:
# from T to U
ip = functools.partial(_pygetm.interp_x, offset=0)
elif self.joffset == target.joffset + 1 and self.ioffset == target.ioffset:
# from V to T
ip = functools.partial(_pygetm.interp_y, offset=1)
elif self.joffset == target.joffset - 1 and self.ioffset == target.ioffset:
# from T to V
ip = functools.partial(_pygetm.interp_y, offset=0)
elif self.ioffset == target.ioffset - 1 and self.joffset == target.joffset - 1:
# from X to T
ip = functools.partial(_pygetm.interp_xy, ioffset=0, joffset=0)
elif self.ioffset == target.ioffset + 1 and self.joffset == target.joffset + 1:
# from T to X
ip = functools.partial(_pygetm.interp_xy, ioffset=1, joffset=1)
elif self.ioffset == target.ioffset - 1 and self.joffset == target.joffset + 1:
# from V to U (i=-1 and j=0 undefined)
ip = functools.partial(_pygetm.interp_xy, ioffset=0, joffset=1)
elif self.ioffset == target.ioffset + 1 and self.joffset == target.joffset - 1:
# from U to V (i=0 and j=-1 undefined)
ip = functools.partial(_pygetm.interp_xy, ioffset=1, joffset=0)
elif self.ioffset == target.ioffset - 2 and self.joffset == target.joffset:
# from T to UU (no interpolation, just copy slice)
assert self.nx == target.nx and self.ny == target.ny
ip = functools.partial(
_assign,
xslice=(Ellipsis, slice(1, None)),
yslice=(Ellipsis, slice(0, -1)),
)
elif self.ioffset == target.ioffset and self.joffset == target.joffset - 2:
# from T to VV (no interpolation, just copy slice)
assert self.nx == target.nx and self.ny == target.ny
ip = functools.partial(
_assign,
xslice=(Ellipsis, slice(1, None), slice(None)),
yslice=(Ellipsis, slice(0, -1), slice(None)),
)
else:
raise NotImplementedError(
f"Cannot interpolate from grid type {self.postfix} "
f"to grid type {target.postfix}"
)
self._interpolators[target] = ip
return ip
@property
def gradient_x_calculator(self):
target_grid = self.ugrid
assert (self.ioffset + 1 - target_grid.ioffset) % 2 == 0
assert (self.joffset - target_grid.joffset) % 2 == 0
ioffset = (self.ioffset + 1 - target_grid.ioffset) // 2
joffset = (self.joffset - target_grid.joffset) // 2
assert ioffset in (0, 1)
assert joffset in (-1, 0, 1)
return target_grid, functools.partial(
_pygetm.gradient_x, self.idx.all_values, ioffset=ioffset, joffset=joffset
)
@property
def gradient_y_calculator(self):
target_grid = self.vgrid
assert (self.ioffset - target_grid.ioffset) % 2 == 0
assert (self.joffset + 1 - target_grid.joffset) % 2 == 0
ioffset = (self.ioffset - target_grid.ioffset) // 2
joffset = (self.joffset + 1 - target_grid.joffset) // 2
assert ioffset in (-1, 0, 1)
assert joffset in (0, 1)
return target_grid, functools.partial(
_pygetm.gradient_y, self.idy.all_values, ioffset=ioffset, joffset=joffset
)
[docs]
def rotate(
self, u: npt.ArrayLike, v: npt.ArrayLike, to_grid: bool = True
) -> tuple[npt.ArrayLike, npt.ArrayLike]:
"""Rotate a geocentric velocity field to the model coordinate system,
or a model velocity field to the geocentric coordinate system.
Args:
u: velocity in x-direction in source coordinate system
(eastward velocity if the source is a geocentric velocity field)
v: velocity in y-direction in source coordinate system
(northward velocity if the source is a geocentric velocity field)
to_grid: rotate from geocentric to model coordinate system, not vice versa
"""
if self._rotator is None:
return u, v
return self._rotator(u, v, to_grid=to_grid)
[docs]
def array(self, *args, **kwargs) -> "Array":
return Array.create(self, *args, **kwargs)
[docs]
def global_to_local(
self, i: int, j: int, *, include_halos: bool = False
) -> tuple[Optional[int], Optional[int]]:
"""Convert global indices (i, j) to local indices in the subdomain."""
xoffset = 0 if self.tiling is None else self.tiling.xoffset
yoffset = 0 if self.tiling is None else self.tiling.yoffset
halox = self.halox if include_halos else 0
haloy = self.haloy if include_halos else 0
i_loc = i - xoffset + halox
j_loc = j - yoffset + haloy
nx = self.nx + 2 * halox
ny = self.ny + 2 * halox
if i_loc < 0 or j_loc < 0 or i_loc >= nx or j_loc >= ny:
return None, None
return i_loc, j_loc
[docs]
def get_gather_info(
self,
shape: tuple[int, ...],
on_boundary: bool,
dtype: npt.DTypeLike,
fill_value,
):
def gather_serial(
locvalues, globvalues: Optional[np.ndarray] = None, globslice=()
):
if globvalues is not None:
globvalues[globslice] = locvalues
return globvalues
return locvalues
if not shape:
# scalar field (e.g., river flow or loading)
global_shape = ()
interior_slice = ()
elif on_boundary:
# boundary field
global_shape = (self.open_boundaries.np_glob,) + shape[1:]
interior_slice = (slice(None),) * len(shape)
else:
# field with trailing y, x dimensions
assert shape[-1] == self.nx_ and shape[-2] == self.ny_
nx = self.tiling.nx_glob + self.overlap
ny = self.tiling.ny_glob + self.overlap
global_shape = shape[:-2] + (ny, nx)
interior_slice = self._interior
if self.tiling.n == 1:
return gather_serial, interior_slice, global_shape
if not shape:
# scalar field (e.g., river flow or loading)
raise NotImplementedError()
elif on_boundary:
# boundary field
gatherer = parallel.GatherFromIndices(
self.tiling.comm,
self.open_boundaries.local_to_global_indices,
(self.open_boundaries.np_glob,),
shape[1:],
dtype,
fill_value=fill_value,
trailing_index=False,
)
else:
gatherer = parallel.Gather(
self.tiling,
shape[:-2] + (self.ny, self.nx),
dtype,
fill_value=fill_value,
overlap=self.overlap,
)
return gatherer, interior_slice, global_shape
def _get_dims(self, ndim: int, z: bool, on_boundary: bool = False) -> Iterable[str]:
if ndim > 0:
if on_boundary:
yield f"bdy{self.postfix}"
if z:
yield "zi" if z == INTERFACES else "z"
if not on_boundary:
yield f"y{self.postfix}"
yield f"x{self.postfix}"
def _get_coords(
self, ndim: int, z: bool, on_boundary: bool = False
) -> Iterable["Array"]:
if ndim >= 2 and not on_boundary:
for array in self.horizontal_coordinates:
yield array
if z:
z_src = self.open_boundaries if on_boundary else self
yield z_src.zf if z == INTERFACES else z_src.zc
def _add_array_properties(cls) -> None:
    """Expose every private array slot of *cls* (e.g. ``_H``) as a read-only
    property without the leading underscore (e.g. ``H``).

    The property docstring is the array's long name, followed by its units in
    parentheses when both are known. Running this inside a function keeps the
    loop variables out of the module namespace.
    """
    for membername in cls._all_arrays:
        name = membername[1:]
        info = cls._array_args.get(name, {})
        long_name = info.get("long_name")
        units = info.get("units")
        doc = ""
        if long_name:
            doc = long_name
            if units:
                doc += f" ({units})"
        setattr(cls, name, property(operator.attrgetter(membername), doc=doc))


_add_array_properties(Grid)
class Array(_pygetm.Array, numpy.lib.mixins.NDArrayOperatorsMixin):
    """Field defined on a :class:`Grid`.

    ``all_values`` holds the data including halos; ``values`` is the interior
    view (for gridded fields). NumPy arithmetic operators are provided by
    NDArrayOperatorsMixin and routed through ``__array_ufunc__``.
    """

    __slots__ = (
        "_xarray",  # cached xarray.DataArray (unmasked) from as_xarray
        "_scatter",  # lazily created parallel.Scatter
        "_gather",  # lazily created gather function
        "_name",
        "attrs",  # metadata; keys starting with "_" are internal
        "_fill_value",
        "saved",
        "_shape",
        "_ndim",
        "_size",
        "_dtype",
        "values",  # interior view of all_values (set in finish_initialization)
        "halo_updaters",  # lazily replaced proxy; see __init__
        "open_boundaries",
    )
    grid: Grid
def __init__(
self,
name: Optional[str] = None,
units: Optional[str] = None,
long_name: Optional[str] = None,
fill_value: Optional[Union[float, int]] = None,
shape: Optional[tuple[int, ...]] = None,
dtype: Optional[npt.DTypeLike] = None,
grid: Grid = None,
fabm_standard_name: Optional[str] = None,
attrs: Mapping[str, Any] = {},
):
_pygetm.Array.__init__(self, grid)
self._xarray: Optional[xr.DataArray] = None
self._scatter: Optional[parallel.Scatter] = None
self._gather: Optional[parallel.Gather] = None
assert (
fill_value is None or np.ndim(fill_value) == 0
), "fill_value must be a scalar value"
self._name = name
self.attrs: Mapping[str, Any] = attrs.copy()
if units:
self.attrs["units"] = units
if long_name:
self.attrs["long_name"] = long_name
if fabm_standard_name:
self.attrs.setdefault("_fabm_standard_names", set()).add(fabm_standard_name)
self._fill_value = (
fill_value
if fill_value is None or dtype is None
else np.array(fill_value, dtype=dtype)
)
self.saved = False #: to be set if this variable is requested for output
self._shape = shape
self._ndim = None if shape is None else len(shape)
self._size = None if shape is None else np.prod(shape)
self._dtype = dtype
self.values = None
class Proxy(Sequence[parallel.BaseHaloUpdater]):
def __init__(self, target: "Array", att: str):
self.target = target
self.att = att
def _create(self) -> Sequence[parallel.BaseHaloUpdater]:
updaters = parallel.create_halo_updaters(
self.target.grid.tiling,
self.target.all_values,
self.target.grid.halox,
self.target.grid.haloy,
overlap=self.target.grid.overlap,
)
setattr(self.target, self.att, updaters)
return updaters
def __getitem__(self, key: int) -> parallel.BaseHaloUpdater:
return self._create()[key]
def __len__(self) -> int:
return len(self._create())
self.halo_updaters: Sequence[parallel.BaseHaloUpdater] = Proxy(
self, "halo_updaters"
)
[docs]
def set_fabm_standard_name(self, fabm_standard_name):
self.attrs.setdefault("_fabm_standard_names", set()).add(fabm_standard_name)
fabm_standard_name = property(fset=set_fabm_standard_name)
[docs]
def mirror(self, target: Optional["Array"] = None):
target = target or self
m = self.grid._mirrors.get(target.grid)
if m is not None:
source_slice, target_slice = m
target.all_values[target_slice] = self.all_values[source_slice]
[docs]
def finish_initialization(self):
"""This is called by the underlying cython implementation after the array
receives a value (:attr:`all_values` is valid)
"""
assert self.grid is not None
self._dtype = self.all_values.dtype
self._ndim = self.all_values.ndim
if self._fill_value is not None:
# Cast fill value to dtype of the array
self._fill_value = np.array(self._fill_value, dtype=self._dtype)
if self.on_boundary or self._ndim == 0:
# boundary array or scalar
self.values = self.all_values
else:
self.values = self.all_values[self.grid._interior]
self._shape = self.values.shape
self._size = self.values.size
[docs]
def update_halos(self, group: parallel.Neighbor = parallel.Neighbor.ALL):
self.halo_updaters[group]()
[docs]
def compare_halos(self, group: parallel.Neighbor = parallel.Neighbor.ALL):
return self.halo_updaters[group].compare()
[docs]
def register(self):
assert self.grid is not None
if self._name is not None:
if self._name in self.grid.fields:
raise Exception(
f"A field with name {self._name!r} has already been registered"
" with the field manager."
)
self.grid.fields[self._name] = self
def __repr__(self) -> str:
return super().__repr__() + self.grid.postfix
[docs]
def scatter(self, global_data: Optional[np.ndarray]):
if self.grid.tiling.n == 1:
if self.grid.tiling.rank == 0:
self.values[...] = global_data
return
if self._scatter is None:
self._scatter = parallel.Scatter(
self.grid.tiling,
self.all_values,
halox=self.grid.halox,
haloy=self.grid.haloy,
share=self.grid.overlap,
fill_value=self._fill_value,
)
self._scatter(global_data)
[docs]
def gather(self, out: Optional[np.ndarray] = None) -> Optional[np.ndarray]:
if self._gather is None:
gatherer, interior_slice, _ = self.grid.get_gather_info(
self.all_values.shape, self.on_boundary, self.dtype, self._fill_value
)
self._gather = functools.partial(gatherer, self.all_values[interior_slice])
return self._gather(out)
[docs]
def allgather(self) -> np.ndarray:
if self.grid.tiling.n == 1:
return self.values
return self.grid.tiling.comm.Bcast(self.gather())
[docs]
def global_sum(
self,
reproducible: bool = False,
where: Optional["Array"] = None,
to_all: bool = False,
) -> Optional[np.ndarray]:
if reproducible:
assert not to_all
all = self.gather()
if where is not None:
where = where.gather()
if all is not None:
return all.sum(where=np._NoValue if where is None else where)
else:
local_sum = self.values.sum(
where=np._NoValue if where is None else where.values
)
tiling = self.grid.tiling
reduce = tiling.allreduce if to_all else tiling.reduce
return reduce(local_sum)
[docs]
def global_mean(
self, reproducible: bool = False, where: Optional["Array"] = None
) -> Optional[np.ndarray]:
sum = self.global_sum(reproducible=reproducible, where=where)
if where is not None:
count = where.global_sum()
else:
count = self.grid.tiling.reduce(self.values.size)
if sum is not None:
return sum / count
[docs]
@staticmethod
def create(
grid: Grid,
fill: Optional[npt.ArrayLike] = None,
z: Literal[None, True, False, CENTERS, INTERFACES] = None,
dtype: npt.DTypeLike = None,
on_boundary: bool = False,
register: bool = True,
**kwargs,
) -> "Array":
"""Create a new :class:`Array`
Args:
grid: grid associated with the new array
fill: value to set the new array to
z: vertical dimension of the new array.
``False`` for a 2D array, ``CENTERS`` (or ``True``) for an array
defined at the layer centers, ``INTERFACES`` for an array defined at
the layer interfaces. ``None`` to detect from ``fill``.
dtype: data type
on_boundary: whether to describe data along the open boundaries (1D),
instead of the 2D x-y model domain
register: whether to register the array as field available for output
**kwargs: additional keyword arguments passed to :class:`Array`
"""
ar = Array(grid=grid, **kwargs)
if fill is None and ar.fill_value is not None:
fill = ar.fill_value
if fill is not None:
fill = np.asarray(fill)
if z is None and not on_boundary:
if fill.ndim != 3:
z = False
elif fill.shape[0] == grid.nz_ + 1:
z = INTERFACES
else:
z = CENTERS
if dtype is None:
dtype = float if fill is None else fill.dtype
shape = [grid.open_boundaries.np] if on_boundary else [grid.ny_, grid.nx_]
if z:
nz = grid.nz_ + 1 if z == INTERFACES else grid.nz_
shape.insert(1 if on_boundary else 0, nz)
data = ar.allocate(shape, dtype)
if fill is not None:
data[...] = fill
ar.wrap_ndarray(data, on_boundary=on_boundary, register=register)
return ar
@property
def all_shape(self) -> tuple[int, ...]:
shape = (
[self.grid.open_boundaries.np]
if self.on_boundary
else [self.grid.ny_, self.grid.nx_]
)
if self.z:
nz = self.grid.nz_ + 1 if self.z == INTERFACES else self.grid.nz_
shape.insert(1 if self.on_boundary else 0, nz)
return tuple(shape)
[docs]
def fill(self, value):
"""Set array to specified value, while respecting the mask: masked points are
set to :attr:`fill_value`
"""
try:
self.all_values = value
except ValueError:
# Incorrect shape for all_values, try values (excl halos) instead
self.values[...] = value
self.update_halos()
if self.fill_value is not None:
self.all_values[self.all_mask] = self.fill_value
@property
def all_mask(self) -> np.ndarray:
"""Boolean array indicating invalid data points, including halos"""
if self._ndim == 0:
return np.array(False)
valid_mask_values = self.attrs.get("_valid_at", ()) + (CellType.ACTIVE,)
mask = self.grid.get_mask(valid_mask_values, self.z)
if self.on_boundary:
open_boundaries = self.grid.open_boundaries
mask = mask[..., open_boundaries.j, open_boundaries.i].T
return np.broadcast_to(mask, self.all_shape)
@property
def mask(self) -> np.ndarray:
"""Boolean array indicating invalid data points, excluding halos"""
if self.on_boundary or self._ndim == 0:
return self.all_mask
else:
return self.all_mask[self.grid._interior]
@property
def ma(self) -> np.ma.MaskedArray:
"""Masked array representation that combines the data and the mask associated
with the array's native grid
"""
return np.ma.array(self.values, mask=self.mask)
[docs]
def plot(self, mask: bool = True, **kwargs):
"""Plot the array with :meth:`xarray.DataArray.plot`
Args:
**kwargs: additional keyword arguments passed to
:meth:`xarray.DataArray.plot`
"""
kwargs.setdefault("shading", "auto")
if self.grid.horizontal_coordinates:
x, y = self.grid.horizontal_coordinates
kwargs.setdefault("x", x.name)
kwargs.setdefault("y", y.name)
return self.as_xarray(mask=mask).plot(**kwargs)
[docs]
def interp(
self,
target: Union["Array", Grid],
z: Literal[None, True, False, CENTERS, INTERFACES] = None,
) -> "Array":
"""Interpolate the array to another grid.
Args:
target: either the :class:`Array` that will hold the interpolated data,
or the :class:`~pygetm.core.Grid` to interpolate to. If a ``Grid`` is
provided, a new array will be created to hold the interpolated values.
"""
if not isinstance(target, Array):
# Target must be a grid; we need to create the array
target_z = z if z is not None else self.z
target = Array.create(target, dtype=self._dtype, z=target_z)
source_array = self.all_values
target_array = target.all_values
if self.grid is target.grid:
if self.z == INTERFACES and target.z == CENTERS:
# vertical interpolation from layer interfaces to layer centers
_pygetm.interp_z(source_array, target_array, offset=0)
elif self.z == CENTERS and target.z == INTERFACES:
# vertical interpolation from layer centers to layer interfaces
# (top and bottom interfaces will be left untouched)
_pygetm.interp_z(source_array, target_array, offset=1)
else:
if self._ndim == 2:
source_array = source_array[None, ...]
target_array = target_array[None, ...]
interpolate = self.grid.interpolator(target.grid)
interpolate(source_array, target_array)
return target
[docs]
def gradient_x(self, target: Optional["Array"] = None) -> "Array":
target_grid, calculator = self.grid.gradient_x_calculator
if target is None:
target = Array.create(target_grid, dtype=self._dtype, z=self.z)
calculator(self.all_values, target.all_values)
return target
[docs]
def gradient_y(self, target: Optional["Array"] = None) -> "Array":
target_grid, calculator = self.grid.gradient_y_calculator
if target is None:
target = Array.create(target_grid, dtype=self._dtype, z=self.z)
calculator(self.all_values, target.all_values)
return target
[docs]
def __array__(self, dtype: Optional[npt.DTypeLike] = None, copy=None) -> np.ndarray:
"""Return interior of the array as a NumPy array.
No copy will be made unless the requested data type differs from that
of the underlying array.
Args:
dtype: data type
"""
return np.asarray(self.values, dtype=dtype)
[docs]
def isel(self, *, z: int, **kwargs) -> "Array":
"""Select a single depth level. The data in the returned 2D :class:`Array`
will be a view of the relevant data of the original 3D array. Thus, changes
in one will affect the other.
"""
if self._ndim != 3:
raise NotImplementedError
if self.units is not None:
kwargs.setdefault("units", self.units)
if self.long_name is not None:
kwargs.setdefault("long_name", f"{self.long_name} @ k={z}")
kwargs["attrs"] = kwargs.get("attrs", {}).copy()
for att in ("_mask_output", "_valid_at"):
if att in self.attrs:
kwargs["attrs"][att] = self.attrs[att]
ar = Array(grid=self.grid, fill_value=self.fill_value, **kwargs)
ar.wrap_ndarray(self.all_values[z, ...])
return ar
[docs]
def __getitem__(self, key) -> np.ndarray:
"""Retrieve values from the interior of the array (excluding halos).
For access to the halos, use :attr:`all_values`.
"""
return self.values[key]
[docs]
def __setitem__(self, key, values):
"""Assign values to the interior of the array (excluding halos).
For access to the halos, use :attr:`all_values`.
"""
self.values[key] = values
@property
def shape(self) -> tuple[int, ...]:
"""Shape excluding halos"""
return self._shape
@property
def ndim(self) -> int:
"""Number of dimensions"""
return self._ndim
@property
def z(self):
"""Vertical dimension: ``False`` if the array has no vertical dimension,
``CENTERS`` for layer centers, ``INTERFACES`` for layer interfaces.
"""
if self._ndim != (2 if self.on_boundary else 3):
return False
nz = self._shape[1 if self.on_boundary else 0]
return INTERFACES if nz == self.grid.nz_ + 1 else CENTERS
@property
def size(self) -> int:
"""Total number of values, excluding halos"""
return self._size
@property
def dtype(self) -> npt.DTypeLike:
"""Data type"""
return self._dtype
@property
def name(self) -> Optional[str]:
"""Name"""
return self._name
@property
def units(self) -> Optional[str]:
"""Units"""
return self.attrs.get("units")
@property
def long_name(self) -> Optional[str]:
"""Long name"""
return self.attrs.get("long_name") or self.name
@property
def fill_value(self) -> Optional[Union[int, float]]:
"""Fill value"""
return self._fill_value
# Below based on https://np.org/devdocs/reference/generated/np.lib.mixins.NDArrayOperatorsMixin.html#np.lib.mixins.NDArrayOperatorsMixin
def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
if method != "__call__":
return NotImplemented
out = kwargs.get("out", ())
for x in inputs + out:
# Only support operations with instances of _HANDLED_TYPES.
# Use ArrayLike instead of type(self) for isinstance to
# allow subclasses that don't override __array_ufunc__ to
# handle ArrayLike objects.
if not isinstance(x, (np.ndarray, numbers.Number, Array)):
return NotImplemented
if isinstance(x, Array) and x.grid is not self.grid:
return NotImplemented
# Defer to the implementation of the ufunc on unwrapped values.
inputs = tuple(x.all_values if isinstance(x, Array) else x for x in inputs)
if out:
kwargs["out"] = tuple(
x.all_values if isinstance(x, Array) else x for x in out
)
result = getattr(ufunc, method)(*inputs, **kwargs)
if type(result) is tuple:
# multiple return values
return tuple(self.create(self.grid, x) for x in result)
elif method == "at":
# no return value
return None
else:
# one return value
return self.create(self.grid, result)
[docs]
def set(self, value: Union[float, np.ndarray, xr.DataArray], **kwargs):
"""Link this array to a field or value using
:attr:`~pygetm.domain.Domain.input_manager`, which will perform temporal and
spatial interpolation as required.
Args:
value: value to assign to this array. If it is time-dependent (if you pass
an instance of :class:`xarray.DataArray` with a time dimension),
the array's value will be updated during the simulation whenever
:meth:`pygetm.input.InputManager.update` is called.
**kwargs: keyword arguments passed to :meth:`pygetm.input.InputManager.add`
"""
self.grid.input_manager.add(self, value, **kwargs)
[docs]
def require_set(self, logger: Optional[logging.Logger] = None):
"""Assess whether all non-masked cells of this field have been set. If not, an
error message is written to the log and False is returned.
"""
valid = True
if self._fill_value is not None:
invalid = self.ma == self._fill_value
if invalid.any():
(logger or logging.getLogger()).error(
f"{self.name} is masked ({self._fill_value})"
f" in {invalid.sum()} active grid cells."
)
valid = False
return valid
[docs]
def as_xarray(self, mask: bool = False) -> xr.DataArray:
"""Return this array wrapped in an :class:`xarray.DataArray` that includes
coordinates and can be used for plotting
"""
if self._xarray is not None and not mask:
return self._xarray
coords = {}
possible_c = (self.grid.x, self.grid.y, self.grid.lon, self.grid.lat)
if not any(c is self for c in possible_c):
for c in possible_c:
if c is not None:
coords[c.name] = c.xarray
dims = self.grid._get_dims(self._ndim, self.z, self.on_boundary)
values = self.values if not mask else self.ma
_xarray = xr.DataArray(
values, coords=coords, dims=dims, attrs=self.attrs, name=self.name
)
if not mask:
self._xarray = _xarray
return _xarray
xarray = property(as_xarray)