from typing import (
Iterable,
MutableMapping,
Union,
Optional,
Mapping,
Literal,
Callable,
Any,
NamedTuple,
TypeVar,
)
import collections
import enum
import functools
import numpy as np
from numpy.typing import DTypeLike, ArrayLike
import pygetm.core
import pygetm._pygetm
import pygetm.util.interpolate
import pygetm.parallel
from pygetm.constants import (
CENTERS,
INTERFACES,
TimeVarying,
CoordinateType,
EdgeTreatment,
)
[docs]
class GridInfo(NamedTuple):
    """Grid, vertical position and boundary flag that locate a field."""

    grid: pygetm.core.Grid
    z: Optional[Literal[CENTERS, INTERFACES]]
    on_boundary: bool

    def _get_gather_info(
        self, source: "Base"
    ) -> tuple[Callable, tuple, Mapping[str, Any]]:
        """Ask the grid how values of ``source`` are gathered.

        Args:
            source: field whose shape, dtype and fill value are forwarded to
                the grid's ``get_gather_info``

        Returns:
            the gatherer, the local slice, and keyword arguments holding the
            global shape under the key ``shape``
        """
        info = self.grid.get_gather_info(
            shape=source.shape,
            on_boundary=self.on_boundary,
            dtype=source.dtype,
            fill_value=source.fill_value,
        )
        gatherer, local_slice, global_shape = info
        return gatherer, local_slice, {"shape": global_shape}

    def get_coords(self, ndim: int) -> Iterable["Base"]:
        """Yield the coordinate fields for an array with ``ndim`` dimensions.

        The grid's own coordinate arrays come first, each wrapped in a
        :class:`Field`, followed by the grid's ``extra_output_coordinates``.
        """
        grid_arrays = self.grid._get_coords(ndim, self.z, self.on_boundary)
        yield from (Field(coord_array) for coord_array in grid_arrays)
        yield from self.grid.extra_output_coordinates
# Type variable for the class on which Base.parameterize is invoked.
T = TypeVar("T")
[docs]
class Base:
    """Description of an output field: expression, shape, dimensions, data type,
    fill value, attributes and time variability.

    Subclasses supply the actual values through :meth:`get`.
    """

    __slots__ = (
        "expression",
        "dtype",
        "shape",
        "ndim",
        "dims",
        "fill_value",
        "attrs",
        "time_varying",
        "coordinates",
        "_global_values",
    )

    @classmethod
    def parameterize(cls: type[T], **kwargs) -> T:
        """Return a :func:`functools.partial` of this class with ``kwargs``
        pre-bound, so that it can later be called with the remaining arguments.
        """
        return functools.partial(cls, **kwargs)

    def __init__(
        self,
        expression: str,
        shape: Iterable[int],
        dims: Iterable[str],
        dtype: DTypeLike,
        fill_value=None,
        time_varying: Union[TimeVarying, Literal[False]] = TimeVarying.MICRO,
        attrs: Optional[Mapping[str, Any]] = None,
    ):
        """
        Args:
            expression: string that identifies the field
            shape: size of each dimension
            dims: name of each dimension; must have as many entries as ``shape``
            dtype: data type; converted with :class:`numpy.dtype`
            fill_value: value that marks missing data
            time_varying: how the field varies in time, or ``False``
            attrs: field attributes. The mapping is stored as is (not copied).
                If omitted, a new empty dict is created for this instance.
        """
        self.expression = expression
        self.shape = tuple(shape)
        self.ndim = len(self.shape)
        self.dims = tuple(dims)
        assert self.ndim == len(
            self.dims
        ), f"Expected {self.ndim} dimensions but got {self.dims}"
        self.dtype = np.dtype(dtype)
        self.fill_value = fill_value
        # A fresh dict per instance: subclasses modify attrs in place, which
        # must not leak into other instances through a shared default object.
        self.attrs = {} if attrs is None else attrs
        self.time_varying = time_varying
        self.coordinates: list[str] = []

    def get(
        self, out: Optional[ArrayLike] = None, slice_spec: tuple[int, ...] = ()
    ) -> ArrayLike:
        """Return the field values; to be implemented by subclasses.

        Args:
            out: optional array to receive the values
            slice_spec: index into ``out`` at which the values are stored
        """
        raise NotImplementedError

    @property
    def default_name(self) -> str:
        """Default output name; to be implemented by subclasses."""
        raise NotImplementedError

    def gather(self) -> "Base":
        """Return this field wrapped in a :class:`Gather` operator."""
        return Gather(self)

    @property
    def updatable(self) -> bool:
        """Whether the field needs to be updated between saves (``False`` here)."""
        return False

    @property
    def updater(self) -> Optional[Callable]:
        """Callable that performs the update; to be implemented by subclasses
        that are updatable."""
        raise NotImplementedError

    @property
    def grid_info(self) -> Optional[GridInfo]:
        """Grid information of the field (``None`` here)."""
        return None

    def _get_gather_info(
        self, source: "Base"
    ) -> tuple[Callable, tuple, Mapping[str, Any]]:
        """Delegate to :attr:`grid_info`, which therefore must not be ``None``."""
        return self.grid_info._get_gather_info(source)

    @property
    def mask(self) -> np.ndarray:
        """Boolean mask with the shape of the field (all ``False`` here)."""
        return np.broadcast_to(False, self.shape)

    @property
    def coords(self) -> Iterable["Base"]:
        """Coordinate fields as provided by :attr:`grid_info`, which therefore
        must not be ``None``."""
        yield from self.grid_info.get_coords(self.ndim)
[docs]
class WrappedArray(Base):
    """Field backed by a plain NumPy array; it is registered as not
    time-varying and has no coordinates."""

    __slots__ = ("_name", "values", "_global_field")

    def __init__(
        self,
        values: np.ndarray,
        name: str,
        dims: tuple[str, ...],
        global_field: Optional[Base] = None,
        **kwargs,
    ):
        """
        Args:
            values: array to wrap; it provides the shape and data type
            name: used both as expression and as default output name
            dims: dimension names
            global_field: field returned by :meth:`gather`; if not provided,
                the wrapped array acts as its own gathered field
            **kwargs: forwarded to :class:`Base`
        """
        shape, dtype = values.shape, values.dtype
        super().__init__(name, shape, dims, dtype, time_varying=False, **kwargs)
        self.values = values
        self._name = name
        self._global_field = global_field or self

    def gather(self) -> Base:
        """Return the global counterpart set at construction."""
        return self._global_field

    @property
    def default_name(self) -> str:
        """Name given at construction."""
        return self._name

    @property
    def coords(self) -> Iterable[Base]:
        """A wrapped array has no coordinates."""
        return ()

    def get(
        self, out: Optional[ArrayLike] = None, slice_spec: tuple[int, ...] = ()
    ) -> ArrayLike:
        """Return the wrapped array, or store it in ``out[slice_spec]`` and
        return ``out``."""
        if out is not None:
            out[slice_spec] = self.values
            return out
        return self.values
[docs]
class Updatable(enum.Enum):
    """Tells a field collection for which kind of update a field's updater
    must be registered."""

    # Updater is registered for macro as well as non-macro updates
    ALWAYS = 1

    # Updater is registered for macro updates only
    MACRO_ONLY = 2
[docs]
class FieldCollection(Mapping[str, Base]):
    """Read-only mapping from output name to output field.

    Fields are added with :meth:`request` and :meth:`require`. Updaters of
    updatable fields are called by :meth:`update`.
    """

    def __init__(
        self,
        available_fields: Mapping[str, pygetm.core.Array],
        default_dtype: Optional[DTypeLike] = None,
        sub: bool = False,
    ):
        """
        Args:
            available_fields: arrays that can be requested by name
            default_dtype: data type used for arrays of type float when
                :meth:`request` is called without ``dtype``
            sub: if set, requested fields are not wrapped by ``gather()``
        """
        self.fields: MutableMapping[str, Base] = collections.OrderedDict()
        self.expression2name: MutableMapping[str, str] = {}
        self.available_fields = available_fields
        self.default_dtype = default_dtype
        self.sub = sub
        # Maps the ``macro`` flag of update() to the updaters to call
        self._updaters: MutableMapping[bool, list] = {}

    def __getitem__(self, key: str) -> Base:
        return self.fields[key]

    def __iter__(self):
        return iter(self.fields)

    def __len__(self) -> int:
        return len(self.fields)

    def request(
        self,
        *fields: Union[str, pygetm.core.Array],
        output_name: Optional[str] = None,
        dtype: Optional[DTypeLike] = None,
        mask: Optional[bool] = None,
        time_average: bool = False,
        grid: Optional[pygetm.core.Grid] = None,
        z: Union[
            None, Literal[CENTERS], Literal[INTERFACES], float, Iterable[float]
        ] = None,
        generate_unique_name: bool = False,
        transforms: Iterable[type["UnivariateTransform"]] = (),
    ) -> tuple[str, ...]:
        """Add one or more arrays to this field collection.

        Args:
            *fields: names of arrays or array objects to add. When names are
                provided, they will be looked up in the field manager.
            output_name: name to use for this field. This can only be provided
                if a single field is requested.
            dtype: data type of the field to use. Array values will be cast to
                this data type whenever the field is saved. If not provided,
                arrays of type float use the default data type of the
                collection.
            mask: whether to explicitly set masked values to the array's fill
                value when saving. If this argument is not provided, masking
                behavior is determined by the array's _mask_output flag.
            time_average: whether to time-average the field
            grid: if provided, the field will be regridded to this grid before
                saving
            z: if provided, the field will be interpolated to these z levels
                before saving. This can be CENTERS or INTERFACES to indicate
                interpolation to the center or interfaces of the model grid,
                a 1D array of custom z levels, or a single float indicating
                the z level to interpolate to. Note that z levels are negative
                downward, with the bottom being -H.
            generate_unique_name: whether to generate a unique output name for
                requested fields if a field with the same name has previously
                been added to the collection. If this is not set and a field
                with this name was added previously, an exception will be
                raised.
            transforms: callables applied in order to each field, directly
                after it has been created and before time averaging,
                regridding and masking

        Returns:
            tuple with names of the newly added fields
        """
        if not fields:
            raise Exception(
                "One or more positional arguments must be provided"
                " to specify the field(s) requested."
            )

        # For backward compatibility: a tuple of names could be provided
        if len(fields) == 1 and isinstance(fields[0], tuple):
            fields = fields[0]

        arrays = []
        for field in fields:
            if isinstance(field, str):
                if field not in self.available_fields:
                    raise Exception(
                        f"Unknown field {field!r} requested."
                        f" Available: {', '.join(self.available_fields)}"
                    )
                arrays.append(self.available_fields[field])
            elif isinstance(field, pygetm.core.Array):
                arrays.append(field)
            else:
                raise Exception(
                    f"Incorrect field specification {field!r}."
                    " Expected a string or an object of type pygetm.core.Array."
                )

        if len(arrays) > 1 and output_name is not None:
            raise Exception(
                f"Trying to add multiple fields to {self!r}."
                " In this case, output_name cannot be specified."
            )

        # The transforms are applied to every array, so a one-shot iterable
        # must be materialized before the loop.
        transforms = tuple(transforms)

        names = []
        for array in arrays:
            name = output_name
            if name is None:
                if array.name is None:
                    raise Exception(
                        f"Trying to add an unnamed variable to {self!r}."
                        " In this case, output_name must be provided"
                    )
                name = array.name

            mask_current = mask
            if mask_current is None:
                mask_current = array.attrs.get("_mask_output", False)

            array.saved = True
            source_grid = array.grid

            # Resolve the data type per array: the default chosen for a float
            # array must not carry over to the arrays that follow it.
            dtype_current = dtype
            if dtype_current is None and array.dtype == float:
                dtype_current = self.default_dtype

            field = Field(array, dtype=dtype_current)
            for tf in transforms:
                field = tf(field)
            if time_average and field.time_varying:
                field = TimeAverage(field)
            if grid and array.grid is not grid:
                field = Regrid(field, grid=grid)
            if z is not None and array.z:
                if isinstance(z, (Iterable, float)):
                    field = InterpZ(field, z, "z1")
                elif array.z and array.z != z:
                    field = Regrid(field, z=z)
            if time_average or mask_current or grid:
                field = Mask(field)
            if not self.sub:
                field = field.gather()
            # NOTE(review): applied whether or not ``sub`` is set - confirm
            for tf in source_grid.default_output_transforms:
                field = tf(field)
            names.append(self._add_field(field, name, generate_unique_name))
        return tuple(names)

    def _add_field(self, field: Base, name: str, generate_unique_name: bool):
        """Register ``field`` and, if it is updatable, its updater.

        Args:
            field: field to add
            name: preferred output name
            generate_unique_name: if the name is taken, append ``_0``, ``_1``,
                ... until it is unique instead of raising an exception

        Returns:
            the name under which the field was registered
        """
        final_name = name
        if generate_unique_name:
            i = 0
            while final_name in self.fields:
                final_name = f"{name}_{i}"
                i += 1
        elif final_name in self.fields:
            raise Exception(
                f"A variable with name {name!r} has already been added to {self!r}."
            )

        if field.updatable:
            # Every updatable field is updated on macro updates; only those
            # flagged ALWAYS are updated on non-macro updates too.
            triggering_macro_values = [True]
            if field.updatable == Updatable.ALWAYS:
                triggering_macro_values.append(False)
            for key in triggering_macro_values:
                self._updaters.setdefault(key, []).append(field.updater)

        self.fields[final_name] = field
        self.expression2name[field.expression] = final_name
        return final_name

    def add_coordinates(self):
        """Make sure the coordinates of all current fields are part of the
        collection and record their names in each field's ``coordinates``."""
        for field in list(self.fields.values()):
            field.coordinates.extend(self.require(f) for f in field.coords)

    def require(self, field: Base) -> str:
        """Ensure that the specified field is included in the field collection.
        This is typically used to add coordinate variables.

        Args:
            field: field to include. If a field with the same expression is
                already part of the collection, nothing is added.

        Returns:
            name of the field in the collection
        """
        if field.expression in self.expression2name:
            return self.expression2name[field.expression]
        return self._add_field(field, field.default_name, generate_unique_name=True)

    def update(self, macro: bool = False):
        """Call all updaters registered for the given value of ``macro``."""
        for updater in self._updaters.get(macro, ()):
            updater()
[docs]
class Field(Base):
    """Output field that takes its values, mask and metadata from a
    :class:`pygetm.core.Array`."""

    __slots__ = "collection", "array"

    def __init__(self, array: pygetm.core.Array, dtype: Optional[DTypeLike] = None):
        """
        Args:
            array: array to wrap
            dtype: data type of the field; if not provided, the data type of
                the array is used
        """
        # Attributes with a leading underscore are not part of the field's
        # public attributes
        attrs = {
            key: value
            for key, value in array.attrs.items()
            if not key.startswith("_")
        }
        if "_global_values" in array.attrs:
            self._global_values = array.attrs["_global_values"]
        self.array = array
        default_time_varying = TimeVarying.MACRO if array.z else TimeVarying.MICRO
        time_varying = array.attrs.get("_time_varying", default_time_varying)

        # Arrays with 2 or more dimensions that are not on the boundary get
        # their last two dimensions widened by twice the halo size
        shape = list(self.array.shape)
        if array.ndim >= 2 and not array.on_boundary:
            shape[-1] += 2 * array.grid.halox
            shape[-2] += 2 * array.grid.haloy

        super().__init__(
            array.name,
            shape,
            array.grid._get_dims(array.ndim, array.z, array.on_boundary),
            # Explicit None test: numpy.dtype instances for unstructured types
            # have length 0 and would be discarded by a truthiness test.
            array.dtype if dtype is None else dtype,
            array.fill_value,
            time_varying,
            attrs,
        )

    def get(
        self, out: Optional[ArrayLike] = None, slice_spec: tuple[int, ...] = ()
    ) -> ArrayLike:
        """Return ``all_values`` of the array, or store them in
        ``out[slice_spec]`` and return ``out``."""
        if out is None:
            return self.array.all_values
        out[slice_spec] = self.array.all_values
        return out

    @property
    def default_name(self) -> str:
        """Name of the wrapped array."""
        return self.array.name

    @property
    def grid_info(self) -> GridInfo:
        """Grid, vertical position and boundary flag of the wrapped array."""
        return GridInfo(self.array.grid, self.array.z, self.array.on_boundary)

    @property
    def mask(self) -> np.ndarray:
        """``all_mask`` of the wrapped array."""
        return self.array.all_mask
[docs]
class Gather(UnivariateTransform):
    """Operator that provides the gathered values of its source field."""

    __slots__ = "root_has_global_values", "_slice", "_gather"

    def __init__(self, source: Base):
        """
        Args:
            source: field to gather; its ``_get_gather_info`` provides the
                gatherer, the local slice to extract, and keyword arguments
                for the base class
        """
        gatherer, local_slice, base_kwargs = source._get_gather_info(source)
        super().__init__(source, expression=source.expression, **base_kwargs)
        self._gather = gatherer
        self._slice = local_slice
        # The mere presence of the attribute (even if None) signals that the
        # global values are already available and no gathering is needed
        self.root_has_global_values = hasattr(source, "_global_values")

    def get(
        self, out: Optional[ArrayLike] = None, slice_spec: tuple[int, ...] = ()
    ) -> ArrayLike:
        """Return the gathered values.

        If the source has a ``_global_values`` attribute, no gathering takes
        place: when that attribute is not ``None`` its value is returned (or
        stored in ``out[slice_spec]``), otherwise ``out`` is returned
        untouched. Without the attribute, the local values are sliced and
        passed to the gatherer.
        """
        if not self.root_has_global_values:
            local_interior = self._source.get()[self._slice]
            return self._gather(local_interior, out, slice_spec)

        global_values = self._source._global_values
        if global_values is None:
            return out
        assert global_values.shape == self.shape
        if out is None:
            return global_values
        out[slice_spec] = global_values
        return out

    @property
    def coords(self) -> Iterable[Base]:
        """Gathered versions of the coordinates of the source."""
        yield from (coord.gather() for coord in self._source.coords)

    @property
    def grid_info(self) -> None:
        """A gathered field has no grid information."""
        return None
[docs]
class Mask(UnivariateTransformWithData):
    """Operator that sets masked points of its source to the fill value."""

    def __init__(self, source: Field):
        """
        Args:
            source: field to mask; its ``mask`` is retrieved once, here
        """
        super().__init__(source)
        source_mask = source.mask
        self._mask = source_mask
        # The mask must match the trailing dimensions of the field
        assert source_mask.shape == self.shape[-source_mask.ndim :]

    def get(
        self, out: Optional[ArrayLike] = None, slice_spec: tuple[int, ...] = ()
    ) -> ArrayLike:
        """Copy the source values into the internal buffer, overwrite masked
        points with the fill value, and hand over to the base class."""
        buffer = self.values
        self._source.get(out=buffer)
        buffer[..., self._mask] = self.fill_value
        return super().get(out, slice_spec)
[docs]
class TimeAverage(UnivariateTransformWithData):
    """Operator that accumulates its source on every update and returns the
    mean over the accumulated values when retrieved."""

    __slots__ = ("_n",)

    def __init__(self, source: Field):
        """
        Args:
            source: field to average in time
        """
        super().__init__(source)
        # Number of values accumulated since the last call to get()
        self._n = 0
        if "cell_methods" not in self.attrs:
            self.attrs["cell_methods"] = "time: mean"
        else:
            self.attrs["cell_methods"] += " time: mean"
        self.values.fill(self.fill_value)

    @property
    def updatable(self) -> bool:
        """``Updatable.MACRO_ONLY`` if the source varies on the macro time
        step, otherwise ``Updatable.ALWAYS``."""
        macro_only = self._source.time_varying == TimeVarying.MACRO
        return Updatable.MACRO_ONLY if macro_only else Updatable.ALWAYS

    @property
    def updater(self) -> Optional[Callable]:
        """The bound :meth:`update` method."""
        return self.update

    def update(self):
        """Add the current source values to the running sum. The first value
        after a reset overwrites the buffer instead of being added."""
        if self._n:
            self.values += self._source.get()
        else:
            self._source.get(out=self.values)
        self._n += 1

    def get(
        self, out: Optional[ArrayLike] = None, slice_spec: tuple[int, ...] = ()
    ) -> ArrayLike:
        """Convert the running sum to a mean, reset the counter, and hand over
        to the base class. Without accumulated values the buffer is left
        as it is."""
        count = self._n
        if count > 0:
            self.values *= 1.0 / count
            self._n = 0
        return super().get(out, slice_spec)

    @property
    def coords(self) -> Iterable[Base]:
        """Coordinates of the base class, with time-varying ones wrapped in a
        :class:`TimeAverage` of their own."""
        for coord in super().coords:
            yield TimeAverage(coord) if coord.time_varying else coord

    @property
    def default_name(self) -> str:
        """Default name of the source with ``_av`` appended."""
        return self._source.default_name + "_av"
[docs]
class Regrid(UnivariateTransformWithData):
    """Operator that interpolates its source either to another grid, or
    between layer centers and interfaces of the same grid."""

    __slots__ = ("interpolate", "_grid_info", "_slice")

    def __init__(
        self,
        source: Base,
        grid: Optional[pygetm.core.Grid] = None,
        z: Optional[Literal[None, CENTERS, INTERFACES]] = None,
    ):
        """
        Args:
            source: field to interpolate. It must have grid information and
                must not be located on the boundary.
            grid: target grid. If it is not provided or equal to the grid of
                the source, ``z`` must be given.
            z: target vertical position (CENTERS or INTERFACES). It must not
                be provided in combination with a different target grid.
        """
        src_info = source.grid_info
        assert src_info is not None
        assert not src_info.on_boundary
        grid = grid or src_info.grid

        if grid is src_info.grid:
            # Same grid: interpolate in the vertical only
            assert z is not None
            if z == CENTERS:
                assert src_info.z == INTERFACES
                offset, extra_levels, label = 0, -1, "centers"
            else:
                assert src_info.z == CENTERS
                offset, extra_levels, label = 1, 1, "interfaces"
            self.interpolate = functools.partial(
                pygetm._pygetm.interp_z, offset=offset
            )
            shape = (source.shape[0] + extra_levels,) + source.shape[1:]
            args = f", z={label}"
        else:
            # Different grid: interpolate in the horizontal only
            assert z is None
            self.interpolate = src_info.grid.interpolator(grid)
            shape = source.shape[:-2] + (grid.ny_, grid.nx_)
            args = f", grid={grid.postfix}"
            if source.ndim > 2:
                z = src_info.z

        super().__init__(
            source,
            shape=shape,
            dims=grid._get_dims(len(shape), z),
            expression=f"{self.__class__.__name__}({source.expression}{args})",
            inherit_grid_info=False,
        )
        # Arrays with fewer than 3 dimensions get a leading axis of length 1
        # before they are handed to the interpolator
        self._slice = (Ellipsis,) if source.ndim >= 3 else (np.newaxis, Ellipsis)
        self._grid_info = GridInfo(grid, z, False)

    def get(
        self, out: Optional[ArrayLike] = None, slice_spec: tuple[int, ...] = ()
    ) -> ArrayLike:
        """Interpolate the source values into the internal buffer and hand
        over to the base class."""
        source_values = self._source.get()[self._slice]
        self.interpolate(source_values, self.values[self._slice])
        return super().get(out, slice_spec)

    @property
    def grid_info(self) -> GridInfo:
        """Target grid and vertical position; never on the boundary."""
        return self._grid_info

    @property
    def mask(self) -> np.ndarray:
        """Mask on the target: the source mask is interpolated as floating
        point values, and points where the result exceeds 0.99 are masked."""
        weights_in = self._source.mask.astype(float, order="C")
        weights_out = np.zeros(self.shape, dtype=float)
        self.interpolate(weights_in[self._slice], weights_out[self._slice])
        return weights_out > 0.99
[docs]
class InterpZ(UnivariateTransformWithData):
    """Interpolate in the vertical.

    The vertical dimension must be the first dimension of the source array.
    The source array must have a vertical coordinate with ``axis`` attribute equal
    to ``Z``. This coordinate must have the same shape as the source array.
    """

    __slots__ = ("z_src", "z_tgt", "z_dim", "edges")

    def __init__(
        self,
        source: Base,
        z: ArrayLike,
        dim: str,
        edges: EdgeTreatment = EdgeTreatment.MISSING,
    ):
        """
        Args:
            source: field to interpolate
            z: target levels; converted to an array of floats
            dim: name of the vertical dimension of the result
            edges: passed on to the interpolator

        Raises:
            ValueError: if no coordinate of ``source`` has ``axis`` equal to
                ``Z``
        """
        self.z_tgt = np.asarray(z, dtype=float)
        self.z_dim = dim
        self.edges = edges
        # Only the first dimension changes in length and name
        super().__init__(
            source,
            shape=(self.z_tgt.size,) + source.shape[1:],
            dims=(dim,) + source.dims[1:],
            expression=f"{self.__class__.__name__}({source.expression}, z={self.z_tgt})",
        )
        z_coord = next(
            (c for c in source.coords if c.attrs.get("axis") == "Z"), None
        )
        if z_coord is None:
            raise ValueError("Could not find source z coordinate")
        self.z_src = z_coord.get()

    def get(
        self, out: Optional[ArrayLike] = None, slice_spec: tuple[int, ...] = ()
    ) -> ArrayLike:
        """Interpolate the source values to the target levels with a newly
        created linear interpolator and hand over to the base class."""
        interpolator = pygetm.util.interpolate.LinearVectorized1D(
            self.z_tgt, self.z_src, 0, self.fill_value, edges=self.edges
        )
        source_values = self._source.get()
        self.values[...] = interpolator(source_values)
        return super().get(out, slice_spec)

    @property
    def coords(self) -> Iterable[Base]:
        """Coordinates of the base class, with the vertical one replaced by a
        :class:`WrappedArray` of the target levels that keeps the original
        attributes."""
        for coord in super().coords:
            if c_is_vertical := (coord.attrs.get("axis") == "Z"):
                coord = WrappedArray(
                    self.z_tgt, self.z_dim, (self.z_dim,), attrs=coord.attrs
                )
            yield coord

    @property
    def mask(self) -> np.ndarray:
        """Points where the source is masked along its entire first axis."""
        source_mask = self._source.mask
        return source_mask.all(axis=0)
[docs]
class IndexXY(UnivariateTransform):
    """Operator that extracts the values of its source at selected
    horizontal locations."""

    __slots__ = "_i", "_j", "_slice", "_index_dims", "_local2global", "_index_coords"

    def __init__(
        self,
        source: Base,
        x: ArrayLike,
        y: ArrayLike,
        coordinate_type: CoordinateType = CoordinateType.IJ,
        dims: tuple[str, ...] = (),
        coords: Mapping[str, Union[Base, ArrayLike]] = {},
    ):
        """
        Args:
            source: field to index. It must have grid information and must
                not be located on the boundary.
            x: first coordinate of the locations; broadcast against ``y``
            y: second coordinate of the locations; broadcast against ``x``
            coordinate_type: how ``x`` and ``y`` are to be interpreted;
                passed on to the locator
            dims: one dimension name per dimension of the broadcast shape of
                ``x`` and ``y``
            coords: additional coordinates, as a mapping from name to values
                that can be broadcast to the shape of ``x`` and ``y``
        """
        index_shape = np.broadcast_shapes(np.shape(x), np.shape(y))
        assert len(dims) == len(index_shape)
        self._index_dims = dims

        info = source.grid_info
        assert info is not None and not info.on_boundary
        grid = info.grid

        self._i = np.empty(index_shape, dtype=np.intp)
        self._j = np.empty(index_shape, dtype=np.intp)
        if grid.tiling.rank == 0:
            # Only the root locates the points; the others receive the
            # resulting indices through the broadcast below

            def global_values_of(array):
                return None if array is None else array.attrs["_global_values"]

            locator = pygetm.core.Locator(
                grid.mask.attrs["_global_values"],
                x=global_values_of(grid.x),
                y=global_values_of(grid.y),
                lon=global_values_of(grid.lon),
                lat=global_values_of(grid.lat),
            )
            self._i[...], self._j[...] = locator(
                x, y, coordinate_type=coordinate_type
            )
        grid.tiling.comm.Bcast(self._i)
        grid.tiling.comm.Bcast(self._j)

        if hasattr(source, "_global_values"):
            # The root (rank=0) has the global values of the source array.
            # If we are the root, extract the values at the desired indices.
            # Otherwise just set the attribute to flag that the root does
            # have the global values.
            root_values = source._global_values
            self._global_values = (
                None if root_values is None else root_values[..., self._j, self._i]
            )

        # Determine which of the desired indices fall within the local
        # subdomain (excluding halos)
        i_local = self._i - grid.tiling.xoffset
        j_local = self._j - grid.tiling.yoffset
        is_local = (
            (i_local >= 0) & (i_local < grid.nx) & (j_local >= 0) & (j_local < grid.ny)
        )

        # Determine slice needed to extract local values of interest, and
        # mapping from the resulting local values to the global array with
        # desired indices. At this point both local and global arrays are
        # flattened (1D).
        i_with_halo = i_local[is_local] + grid.halox
        j_with_halo = j_local[is_local] + grid.haloy
        self._slice = (Ellipsis, j_with_halo, i_with_halo)
        self._local2global = np.flatnonzero(is_local)

        flat_dim_name = "_".join(dims) if dims else "station"
        super().__init__(
            source,
            shape=source.shape[:-2] + self._local2global.shape,
            dims=source.dims[:-2] + (flat_dim_name,),
            inherit_grid_info=False,
        )

        self._index_coords = []
        for coord_name, coord_values in coords.items():
            all_values = np.broadcast_to(coord_values, index_shape)
            global_coord = WrappedArray(all_values, coord_name, dims=dims)
            local_coord = WrappedArray(
                all_values[is_local],
                coord_name,
                dims=(flat_dim_name,),
                global_field=global_coord,
            )
            self._index_coords.append(local_coord)

    def get(
        self, out: Optional[ArrayLike] = None, slice_spec: tuple[int, ...] = ()
    ) -> ArrayLike:
        """Return the source values at the selected local points, or store
        them in ``out[slice_spec]`` and return ``out``."""
        selected = self._source.get()[self._slice]
        if out is not None:
            out[slice_spec] = selected
            return out
        return selected

    @property
    def coords(self) -> Iterable[Base]:
        """Coordinates of the source, indexed like the source itself where
        their last two dimensions match those of the source, followed by the
        additional coordinates provided at construction."""
        horizontal_dims = self._source.dims[-2:]
        for coord in self._source.coords:
            if coord.dims[-2:] == horizontal_dims:
                coord = IndexXY(coord, self._i, self._j, dims=self._index_dims)
            yield coord
        yield from self._index_coords

    @property
    def mask(self) -> np.ndarray:
        """Source mask at the selected local points."""
        return self._source.mask[self._slice]

    def _get_gather_info(
        self, source: Base
    ) -> tuple[Callable, tuple, Mapping[str, Any]]:
        """Gather across all processes and reshape the flattened
        data to the original index shape"""

        class gather_serial:
            """Gatherer for a single subdomain: only reshapes the values."""

            def __init__(self, final_shape):
                self.final_shape = final_shape

            def __call__(
                self, locvalues, globvalues: Optional[np.ndarray] = None, globslice=()
            ):
                reshaped = np.reshape(locvalues, self.final_shape)
                if globvalues is None:
                    return reshaped
                globvalues[globslice] = reshaped
                return globvalues

        tiling = self._source.grid_info.grid.tiling
        leading_shape = source.shape[:-1]
        final_shape = leading_shape + self._i.shape
        final_dims = source.dims[:-1] + self._index_dims
        if tiling.n != 1:
            gatherer = pygetm.parallel.GatherFromIndices(
                tiling.comm,
                self._local2global,
                self._i.shape,
                leading_shape,
                source.dtype,
                fill_value=source.fill_value,
            )
        else:
            gatherer = gather_serial(final_shape)
        return gatherer, (Ellipsis,), {"shape": final_shape, "dims": final_dims}

    @property
    def grid_info(self) -> None:
        """An indexed field has no grid information."""
        return None