from typing import Mapping, Optional, Union, Iterable, Any, TYPE_CHECKING
import enum
import functools
import logging
import numpy as np
import numpy.typing as npt
import xarray as xr
from . import core
from . import parallel
from . import rivers
from . import open_boundaries
from .constants import CoordinateType, CellType, GRAVITY, EdgeTreatment
if TYPE_CHECKING:
import matplotlib.figure
import matplotlib.colors
def _get_rectangle_overlap(
istart1: int,
istop1: int,
jstart1: int,
jstop1: int,
istart2: int,
istop2: int,
jstart2: int,
jstop2: int,
):
istart = max(istart1, istart2)
ni = min(istop1, istop2) - istart
jstart = max(jstart1, jstart2)
nj = min(jstop1, jstop2) - jstart
iskip = istart - istart1
jskip = jstart - jstart1
slice1 = (slice(jskip, jskip + nj), slice(iskip, iskip + ni))
iskip = istart - istart2
jskip = jstart - jstart2
slice2 = (slice(jskip, jskip + nj), slice(iskip, iskip + ni))
return slice1, slice2
# Conversion factors between degrees and radians
DEG2RAD = np.pi / 180
RAD2DEG = 180 / np.pi

# Radius of the earth (m)
R_EARTH = 6378815.0

# Angular velocity of the earth's rotation (rad/s): one full revolution
# (2 pi rad) per sidereal day, which lasts 86164 s
OMEGA = 2.0 * np.pi / 86164.0
[docs]
def coriolis(lat: npt.ArrayLike) -> np.ndarray:
    """Return the Coriolis parameter f = 2 * OMEGA * sin(latitude).

    Args:
        lat: latitude (°North), scalar or array-like

    Returns:
        Coriolis parameter f (rad s-1), as a float array shaped like ``lat``
    """
    lat_rad = DEG2RAD * np.asarray(lat, dtype=float)
    return (2.0 * OMEGA) * np.sin(lat_rad)
[docs]
def centers_to_supergrid_1d(
    source: npt.ArrayLike,
    *,
    dtype: Optional[npt.DTypeLike] = None,
    edges: EdgeTreatment = EdgeTreatment.MISSING,
    missing_value=np.nan,
) -> np.ndarray:
    """Interpolate 1D cell-center values to a 1D supergrid.

    For ``n`` centers the result has ``2*n + 1`` entries. The centers sit at
    the odd indices and the interior interfaces (mean of the two neighbouring
    centers) at even indices ``2 .. 2n-2``. The two outer interfaces depend on
    ``edges``:

    - PERIODIC: both get the mean of the first and last center
    - EXTRAPOLATE: linear extrapolation using the nearest interior difference
    - EXTRAPOLATE_PERIODIC: extrapolation using the difference at the
      opposite end of the domain
    - CLAMP: copy of the nearest center
    - anything else: ``missing_value``

    Args:
        source: one-dimensional array with at least two values
        dtype: data type of the result (default: that of ``source``)
        edges: treatment of the two outer interfaces
        missing_value: value for the outer interfaces when they are not
            reconstructed
    """
    centers = np.asarray(source)
    assert centers.ndim == 1, "source must be one-dimensional"
    assert centers.size > 1, "source must have at least 2 elements"

    n = centers.size
    result = np.empty_like(centers, shape=(n * 2 + 1,), dtype=dtype)
    result[1::2] = centers
    result[2:-2:2] = 0.5 * (centers[1:] + centers[:-1])

    # None of the edge formulas reads result[0] or result[-1], so both values
    # can be computed before either is written.
    mode = EdgeTreatment(edges)
    if mode == EdgeTreatment.PERIODIC:
        left = right = 0.5 * (result[1] + result[-2])
    elif mode == EdgeTreatment.EXTRAPOLATE:
        left = 2 * result[1] - result[2]
        right = 2 * result[-2] - result[-3]
    elif mode == EdgeTreatment.EXTRAPOLATE_PERIODIC:
        left = result[1] + (result[-3] - result[-2])
        right = result[-2] + (result[2] - result[1])
    elif mode == EdgeTreatment.CLAMP:
        left, right = result[1], result[-2]
    else:
        left = right = missing_value
    result[0], result[-1] = left, right
    return result
[docs]
def _fill_edge_strips(out: np.ndarray, index, edges: EdgeTreatment, missing_value):
    """Fill, in place, the first and last strip of ``out`` along one axis.

    ``index(k)`` returns the index tuple selecting strip ``k`` along that axis;
    negative ``k`` counts from the end. ``out`` is indexed directly rather than
    through a separate view, so masked arrays get their mask updated as well.
    Unrecognized ``edges`` values leave the strips untouched.
    """
    if edges == EdgeTreatment.EXTRAPOLATE:
        out[index(0)] = 2 * out[index(1)] - out[index(2)]
        out[index(-1)] = 2 * out[index(-2)] - out[index(-3)]
    elif edges == EdgeTreatment.EXTRAPOLATE_PERIODIC:
        out[index(0)] = out[index(1)] + (out[index(-3)] - out[index(-2)])
        out[index(-1)] = out[index(-2)] + (out[index(2)] - out[index(1)])
    elif edges == EdgeTreatment.CLAMP:
        out[index(0)] = out[index(1)]
        out[index(-1)] = out[index(-2)]
    elif edges == EdgeTreatment.PERIODIC:
        out[index(0)] = out[index(-2)]
        out[index(-1)] = out[index(1)]
    elif edges == EdgeTreatment.MISSING:
        out[index(0)] = missing_value
        out[index(-1)] = missing_value


def expand_2d(
    source: npt.ArrayLike,
    *,
    dtype: Optional[npt.DTypeLike] = None,
    edges_x: EdgeTreatment = EdgeTreatment.MISSING,
    edges_y: EdgeTreatment = EdgeTreatment.MISSING,
    missing_value=np.nan,
) -> np.ndarray:
    """Pad the last two dimensions of ``source`` with a strip of one cell per side.

    The result has shape ``source.shape[:-2] + (ny + 2, nx + 2)`` with
    ``source`` in ``[..., 1:-1, 1:-1]``. This supports interpolation from
    cell centers to interfaces.

    Args:
        source: array with at least two dimensions; the last two are (y, x)
        dtype: data type of the result (default: that of ``source``)
        edges_x: how to fill the first and last column (see
            :class:`EdgeTreatment`: EXTRAPOLATE, EXTRAPOLATE_PERIODIC, CLAMP,
            PERIODIC or MISSING)
        edges_y: how to fill the first and last row
        missing_value: fill value for strips treated as MISSING

    Returns:
        The padded array. Corner cells are derived from the x strips by the
        y treatment.
    """
    source_shape = np.shape(source)
    out_shape = source_shape[:-2] + (source_shape[-2] + 2, source_shape[-1] + 2)
    # np.empty_like leaves entries uninitialized; only the interior is set here
    out = np.empty_like(source, shape=out_shape, dtype=dtype)
    out[..., 1:-1, 1:-1] = source
    edges_x = EdgeTreatment(edges_x)
    edges_y = EdgeTreatment(edges_y)

    # x strips: only for the interior rows, which are the initialized ones.
    # The original code also computed on the uninitialized rows 0 and -1, but
    # those rows are fully overwritten by the y pass below anyway.
    _fill_edge_strips(
        out, lambda k: (Ellipsis, slice(1, -1), k), edges_x, missing_value
    )
    # y strips: full rows, so the corners follow from the x strips just filled
    _fill_edge_strips(
        out, lambda k: (Ellipsis, k, slice(None)), edges_y, missing_value
    )
    return out
[docs]
def centers_to_supergrid_2d(
    source: npt.ArrayLike,
    *,
    dtype: Optional[npt.DTypeLike] = None,
    edges_x: EdgeTreatment = EdgeTreatment.MISSING,
    edges_y: EdgeTreatment = EdgeTreatment.MISSING,
    missing_value=np.nan,
) -> np.ndarray:
    """Interpolate values at cell centers to the supergrid.

    The last two dimensions of ``source`` hold cell-center (T point) values
    ``(ny, nx)``. Any leading dimensions are carried through unchanged. The
    result has shape ``(..., ny*2+1, nx*2+1)``:

    - T points ``[..., 1::2, 1::2]``: the center values themselves
    - U points ``[..., 1::2, ::2]``: mean of the centers on either side in x
    - V points ``[..., ::2, 1::2]``: mean of the centers on either side in y
    - X points ``[..., ::2, ::2]``: mean of the four surrounding centers

    Values beyond the domain edge come from :func:`expand_2d` with
    ``edges_x``/``edges_y``. For masked input, the means follow numpy.ma
    semantics, and masked results are replaced by ``missing_value``.

    Args:
        source: array with cell-center values in its last two dimensions
        dtype: data type of the result (default: that of ``source``)
        edges_x: edge treatment in x (see :class:`EdgeTreatment`)
        edges_y: edge treatment in y
        missing_value: value for points without data
    """
    source = np.asanyarray(source)
    ny, nx = source.shape[-2:]
    source_ex = expand_2d(
        source,
        dtype=dtype,
        edges_x=edges_x,
        edges_y=edges_y,
        missing_value=missing_value,
    )
    lead_shape = source_ex.shape[:-2]
    out = np.empty_like(source_ex, shape=lead_shape + (ny * 2 + 1, nx * 2 + 1))

    # Work buffer with up to four neighbouring center values (first axis) per
    # target point. It is sized for the X points, the largest target set.
    data_if_ip = np.empty_like(source_ex, shape=(4,) + lead_shape + (ny + 1, nx + 1))
    data_if_ip[0, ...] = source_ex[..., :-1, :-1]
    data_if_ip[1, ...] = source_ex[..., 1:, :-1]
    data_if_ip[2, ...] = source_ex[..., :-1, 1:]
    data_if_ip[3, ...] = source_ex[..., 1:, 1:]
    out[..., 1::2, 1::2] = source_ex[..., 1:-1, 1:-1]  # T points
    out[..., ::2, ::2] = data_if_ip.mean(axis=0)  # X points

    # V points: the two centers on either side in y
    data_if_ip[0, ..., :-1] = source_ex[..., :-1, 1:-1]
    data_if_ip[1, ..., :-1] = source_ex[..., 1:, 1:-1]
    out[..., ::2, 1::2] = data_if_ip[:2, ..., :-1].mean(axis=0)

    # U points: the two centers on either side in x
    data_if_ip[0, ..., :-1, :] = source_ex[..., 1:-1, :-1]
    data_if_ip[1, ..., :-1, :] = source_ex[..., 1:-1, 1:]
    out[..., 1::2, ::2] = data_if_ip[:2, ..., :-1, :].mean(axis=0)

    return np.ma.filled(out, missing_value)
[docs]
def interfaces_to_supergrid_1d(
    source: npt.ArrayLike,
    *,
    dtype: Optional[npt.DTypeLike] = None,
    out: Optional[np.ndarray] = None,
) -> np.ndarray:
    """Interpolate 1D interface values to a 1D supergrid.

    Args:
        source: one-dimensional array with at least two interface values
        dtype: data type of the result if ``out`` is not given
            (default: that of ``source``)
        out: optional array of length ``2*len(source) - 1`` to write into

    Returns:
        Array with the interface values at even indices and the mean of each
        neighbouring pair (the cell centers) at odd indices.
    """
    nodes = np.asarray(source)
    assert nodes.ndim == 1, "data must be one-dimensional"
    assert nodes.size > 1, "data must have at least 2 elements"

    result = out
    if result is None:
        result = np.empty_like(nodes, shape=(nodes.size * 2 - 1,), dtype=dtype)
    result[0::2] = nodes
    result[1::2] = 0.5 * (nodes[1:] + nodes[:-1])
    return result
[docs]
def interfaces_to_supergrid_2d(
    source: npt.ArrayLike,
    *,
    dtype: Optional[npt.DTypeLike] = None,
    out: Optional[np.ndarray] = None,
) -> np.ndarray:
    """Interpolate 2D corner (interface) values to the supergrid.

    Args:
        source: values at cell corners, shape ``(ny+1, nx+1)``, with both
            dimensions of length >= 2
        dtype: data type of the result if ``out`` is not given
            (default: that of ``source``)
        out: optional array of shape ``(ny*2+1, nx*2+1)`` to write into

    Returns:
        Supergrid array with the corner values at ``[::2, ::2]``, the mean of
        two neighbouring corners on the edges between them, and the mean of
        the four surrounding corners at ``[1::2, 1::2]`` (cell centers).
    """
    corners = np.asarray(source)
    assert corners.ndim == 2, "data must be two-dimensional"
    rows, cols = corners.shape
    assert rows > 1 and cols > 1, "dimensions must have length >= 2"

    result = out
    if result is None:
        result = np.empty_like(corners, shape=(rows * 2 - 1, cols * 2 - 1), dtype=dtype)

    lower, upper = corners[:-1, :], corners[1:, :]
    left, right = corners[:, :-1], corners[:, 1:]
    result[0::2, 0::2] = corners
    result[1::2, 0::2] = 0.5 * (lower + upper)
    result[0::2, 1::2] = 0.5 * (left + right)
    result[1::2, 1::2] = 0.25 * (
        corners[:-1, :-1] + corners[:-1, 1:] + corners[1:, :-1] + corners[1:, 1:]
    )
    return result
[docs]
def create_cartesian(
    x: npt.ArrayLike,
    y: npt.ArrayLike,
    *,
    interfaces: bool = False,
    central_lon: Optional[float] = None,
    central_lat: Optional[float] = None,
    **kwargs,
) -> "Domain":
    """Create Cartesian domain from x and y coordinates.

    Args:
        x: array with x coordinates (m).
            It can have shape ``(nx,)`` or ``(ny, nx)``.
            Coordinates are interpreted to be positioned at cell interfaces
            if ``interfaces=True``, at cell centers otherwise.
        y: array with y coordinates (m).
            It can have shape ``(ny,)`` or ``(ny, nx)``.
            Coordinates are interpreted to be positioned at cell interfaces
            if ``interfaces=True``, at cell centers otherwise.
        interfaces: coordinates are given at cell interfaces rather than cell centers.
        central_lon: longitude of the center of the domain (°East).
            The center is ``[x.min() + x.max()]/2, [y.min() + y.max()]/2``.
            Only used if ``central_lat`` is given too.
        central_lat: latitude of the center of the domain (°North).
            The center is ``[x.min() + x.max()]/2, [y.min() + y.max()]/2``.
            Only used if ``central_lon`` is given too.
        **kwargs: additional arguments passed to :class:`Domain`.
            ``coordinate_type`` defaults to ``CoordinateType.XY`` but can be
            overridden. Explicit ``lon``/``lat`` take precedence over the
            values derived from ``central_lon``/``central_lat``.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    if y.ndim == 1:
        # Column vector: y varies along the first (y) dimension and broadcasts
        # against x
        y = y[:, np.newaxis]
    if interfaces:
        nx, ny = x.shape[-1] - 1, y.shape[0] - 1
    else:
        nx, ny = x.shape[-1], y.shape[0]
    if central_lon is not None and central_lat is not None:
        # Approximate lon/lat from the offsets to the domain center, using a
        # fixed number of meters per degree latitude and scaling longitude
        # offsets by 1/cos(latitude)
        relx = x - 0.5 * (x.min() + x.max())
        rely = y - 0.5 * (y.min() + y.max())
        m_per_degree = DEG2RAD * R_EARTH
        lats = rely / m_per_degree + central_lat
        lons = relx / m_per_degree / np.cos(DEG2RAD * lats) + central_lon
        lons, lats = np.broadcast_arrays(lons, lats)
        kwargs.setdefault("lon", lons)
        kwargs.setdefault("lat", lats)
    # Previously passed explicitly, which made a caller-supplied
    # coordinate_type raise TypeError (duplicate keyword argument)
    kwargs.setdefault("coordinate_type", CoordinateType.XY)
    return Domain(nx, ny, x=x, y=y, **kwargs)
[docs]
def create_spherical(
    lon: npt.ArrayLike, lat: npt.ArrayLike, *, interfaces: bool = False, **kwargs
) -> "Domain":
    """Create spherical domain from longitudes and latitudes.

    Args:
        lon: array with longitude coordinates (°East).
            It can have shape ``(nx,)`` or ``(ny, nx)``.
            Coordinates are interpreted to be positioned at cell interfaces
            if ``interfaces=True``, at cell centers otherwise.
        lat: array with latitude coordinates (°North).
            It can have shape ``(ny,)`` or ``(ny, nx)``.
            Coordinates are interpreted to be positioned at cell interfaces
            if ``interfaces=True``, at cell centers otherwise.
        interfaces: coordinates are given at cell interfaces rather than cell centers.
        **kwargs: additional arguments passed to :class:`Domain`.
            ``coordinate_type`` defaults to ``CoordinateType.LONLAT`` but can
            be overridden.
    """
    lon = np.asarray(lon)
    lat = np.asarray(lat)
    if lat.ndim == 1:
        # Column vector: latitude varies along the first (y) dimension
        lat = lat[:, np.newaxis]
    if interfaces:
        nx, ny = lon.shape[-1] - 1, lat.shape[0] - 1
    else:
        nx, ny = lon.shape[-1], lat.shape[0]
    # Previously passed explicitly, which made a caller-supplied
    # coordinate_type raise TypeError (duplicate keyword argument)
    kwargs.setdefault("coordinate_type", CoordinateType.LONLAT)
    return Domain(nx, ny, lon=lon, lat=lat, **kwargs)
[docs]
def create_spherical_at_resolution(
    minlon: float,
    maxlon: float,
    minlat: float,
    maxlat: float,
    resolution: float,
    **kwargs,
) -> "Domain":
    """Create spherical domain encompassing the specified longitude range and latitude
    range and desired resolution in m.

    Args:
        minlon: minimum longitude (°East)
        maxlon: maximum longitude (°East)
        minlat: minimum latitude (°North)
        maxlat: maximum latitude (°North)
        resolution: maximum grid cell length and width (m)
        **kwargs: additional arguments passed to :class:`Domain`

    Raises:
        Exception: if a maximum does not exceed its minimum, or if
            ``resolution`` is not positive
    """
    if maxlon <= minlon:
        raise Exception(
            f"Maximum longitude {maxlon} must exceed minimum longitude {minlon}"
        )
    if maxlat <= minlat:
        raise Exception(
            f"Maximum latitude {maxlat} must exceed minimum latitude {minlat}"
        )
    if resolution <= 0.0:
        raise Exception(f"Desired resolution must exceed 0, but is {resolution} m")
    # Latitude step (degrees) spanning `resolution` meters
    dlat = resolution / (DEG2RAD * R_EARTH)
    # A degree of longitude is longest where |latitude| is smallest. That is
    # the equator if the range includes it, otherwise the boundary nearest to
    # it. Sizing dlon there keeps every cell width <= resolution.
    if minlat <= 0.0 <= maxlat:
        minabslat = 0.0
    else:
        minabslat = min(abs(minlat), abs(maxlat))
    dlon = resolution / (DEG2RAD * R_EARTH) / np.cos(DEG2RAD * minabslat)
    # Number of interfaces such that the actual spacing does not exceed dlon/dlat
    nx = int(np.ceil((maxlon - minlon) / dlon)) + 1
    ny = int(np.ceil((maxlat - minlat) / dlat)) + 1
    return create_spherical(
        np.linspace(minlon, maxlon, nx),
        np.linspace(minlat, maxlat, ny),
        interfaces=True,
        **kwargs,
    )
[docs]
def apply_on_root_and_bcast(method):
    """Decorate a method so that it runs on the root rank only.

    The decorated method is called when ``self.comm.rank == 0``. The root's
    result (None elsewhere) is passed through ``self.comm.bcast`` so that
    every rank returns the root's value.
    """

    @functools.wraps(method)
    def _run_on_root_then_share(self: "Domain", *args, **kwargs):
        on_root = self.comm.rank == 0
        value = method(self, *args, **kwargs) if on_root else None
        return self.comm.bcast(value)

    return _run_on_root_then_share
[docs]
def apply_only_on_root(method):
    """Decorate a method so that it runs on the root rank only.

    On ranks other than 0 the method is not called and None is returned.
    Nothing is communicated between ranks.
    """

    @functools.wraps(method)
    def _root_only(self: "Domain", *args, **kwargs):
        if self.comm.rank != 0:
            return None
        return method(self, *args, **kwargs)

    return _root_only
def _rotation(x: np.ndarray, y: np.ndarray) -> np.ndarray:
# For each point, draw lines to the nearest neighbor (1/2 a grid cell) on
# the left, right, top and bottom.
rot_left = np.arctan2(y[:, 1:-1] - y[:, :-2], x[:, 1:-1] - x[:, :-2])
rot_right = np.arctan2(y[:, 2:] - y[:, 1:-1], x[:, 2:] - x[:, 1:-1])
rot_bot = np.arctan2(y[1:-1, :] - y[:-2, :], x[1:-1, :] - x[:-2, :]) - 0.5 * np.pi
rot_top = np.arctan2(y[2:, :] - y[1:-1, :], x[2:, :] - x[1:-1, :]) - 0.5 * np.pi
x_dum = (
np.cos(rot_left[1:-1, :])
+ np.cos(rot_right[1:-1, :])
+ np.cos(rot_bot[:, 1:-1])
+ np.cos(rot_top[:, 1:-1])
)
y_dum = (
np.sin(rot_left[1:-1, :])
+ np.sin(rot_right[1:-1, :])
+ np.sin(rot_bot[:, 1:-1])
+ np.sin(rot_top[:, 1:-1])
)
return np.arctan2(y_dum, x_dum)
[docs]
def from_xarray(
    ds: xr.Dataset, comm: Optional[parallel.MPI.Comm] = None, **kwargs
) -> "Domain":
    """Create domain from :class:`xarray.Dataset`.

    This dataset must represent the arguments to :class:`Domain` by variables
    (for 2D arrays) or attributes (for scalars) with the same names.
    Such a dataset is produced by :meth:`Domain.to_xarray`.

    Only the root rank reads the array variables. The domain extent is derived
    on every rank from the supergrid shape of ``H``.

    Args:
        ds: dataset with domain data
        comm: MPI communicator that comprises all processes that should get
            access to the domain
        **kwargs: additional arguments passed to :class:`Domain`.
            These override corresponding values in the dataset.

    Returns:
        domain created from the dataset
    """
    comm = comm or parallel.mpi4py_autofree(parallel.MPI.COMM_WORLD.Dup())
    kwargs.setdefault("coordinate_type", ds.attrs.get("coordinate_type"))
    for flag in ("periodic_x", "periodic_y"):
        if flag in ds.attrs:
            kwargs.setdefault(flag, bool(ds.attrs[flag]))

    if comm.rank == 0:
        array_names = ("lon", "lat", "x", "y", "mask", "H", "z0", "f")
        kwargs.update(
            {
                name: ds[name].values
                for name in array_names
                if name in ds and name not in kwargs
            }
        )

    assert "H" in ds, "Dataset must contain bathymetric depth H"
    ny_super, nx_super = ds["H"].shape
    return Domain((nx_super - 1) // 2, (ny_super - 1) // 2, comm=comm, **kwargs)
[docs]
class Domain:
def __init__(
    self,
    nx: int,
    ny: int,
    *,
    lon: Optional[npt.ArrayLike] = None,
    lat: Optional[npt.ArrayLike] = None,
    x: Optional[npt.ArrayLike] = None,
    y: Optional[npt.ArrayLike] = None,
    coordinate_type: Optional[CoordinateType] = None,
    mask: Optional[npt.ArrayLike] = 1,
    H: Optional[npt.ArrayLike] = None,
    z0: Optional[npt.ArrayLike] = 0.0,
    f: Optional[npt.ArrayLike] = None,
    periodic_x: bool = False,
    periodic_y: bool = False,
    comm: Optional[parallel.MPI.Comm] = None,
    logger: Optional[logging.Logger] = None,
):
    """Create new domain.

    Args:
        nx: number of tracer points in x-direction
        ny: number of tracer points in y-direction
        lon: longitude (°East)
        lat: latitude (°North)
        x: x coordinate (m)
        y: y coordinate (m)
        coordinate_type: preferred coordinate type for plots and output
        mask: initial mask (0: land, 1: water)
        H: initial bathymetric depth. This is the distance between the bottom and
            some arbitrary depth reference (m, positive if bottom lies below the
            depth reference). Typically the depth reference is mean sea level.
            Points with NaN or masked values will be marked as land in the mask.
        z0: minimum hydrodynamic bottom roughness (m)
        f: Coriolis parameter (rad s-1). If not provided, it will be calculated
            from latitude. In that case argument `lat` must be provided.
        periodic_x: use periodic boundary in x-direction (left == right)
        periodic_y: use periodic boundary in y-direction (top == bottom)
        comm: MPI communicator that comprises all processes that should get access
            to the domain
        logger: logger for diagnostic messages

    Global arrays are stored on the root rank only; on other ranks all
    array attributes are None.
    """
    if nx <= 0:
        raise Exception(f"Number of x points is {nx} but must be > 0")
    if ny <= 0:
        raise Exception(f"Number of y points is {ny} but must be > 0")

    self.comm = comm or parallel.mpi4py_autofree(parallel.MPI.COMM_WORLD.Dup())

    # Non-root ranks may not receive the arrays (e.g. from_xarray reads them on
    # root only), so availability is decided on root and broadcast.
    has_xy = self.comm.bcast(x is not None and y is not None)
    has_lonlat = self.comm.bcast(lon is not None and lat is not None)
    if not (has_xy or has_lonlat):
        raise Exception("Either x and y, or lon and lat, must be provided")
    has_f = self.comm.bcast(f is not None or lat is not None)
    if not has_f:
        raise Exception(
            "Either lat or f must be provided to determine the Coriolis parameter."
        )

    if coordinate_type is None:
        coordinate_type = CoordinateType.XY if has_xy else CoordinateType.LONLAT
    elif isinstance(coordinate_type, str):
        coordinate_type = CoordinateType[coordinate_type]
    assert (
        (coordinate_type == CoordinateType.XY and has_xy)
        or (coordinate_type == CoordinateType.LONLAT and has_lonlat)
        or (coordinate_type == CoordinateType.IJ and (has_xy or has_lonlat))
    ), f"Coordinate type {coordinate_type} requires the corresponding coordinates"

    self.nx = nx
    self.ny = ny
    self.periodic_x = periodic_x
    self.periodic_y = periodic_y
    self.coordinate_type = coordinate_type
    self.root_logger = logger or parallel.get_logger()
    self.logger = self.root_logger.getChild("domain")

    if self.comm.rank != 0:
        self._x = self._y = self._lon = self._lat = None
        self._mask = self._H = self._z0 = None
        self._dx = self._dy = self._rotation = self._f = self._area = None
    else:
        self._x = self._map_array(x, edges=EdgeTreatment.EXTRAPOLATE)
        self._y = self._map_array(y, edges=EdgeTreatment.EXTRAPOLATE)
        if lon is not None and np.shape(lon) != (1 + ny * 2, 1 + nx * 2):
            # Interpolate longitude in cos-sin space to handle periodic boundary
            # condition, but skip this if longitude is already provided on supergrid
            # (no interpolation needed) to improve accuracy of rotation tests.
            lon = np.asarray(lon).astype(float, copy=False)
            lon_rad = DEG2RAD * lon
            coslon = np.cos(lon_rad)
            sinlon = np.sin(lon_rad)
            coslon = self._map_array(coslon, edges=EdgeTreatment.EXTRAPOLATE)
            sinlon = self._map_array(sinlon, edges=EdgeTreatment.EXTRAPOLATE)
            self._lon = np.arctan2(sinlon, coslon) * RAD2DEG
            # Wrap into a 360° window centered on the input's longitude range
            central_lon = 0.5 * (lon.min() + lon.max())
            wrap_lon = central_lon - 180.0
            self._lon = (self._lon - wrap_lon) % 360.0 + wrap_lon
        else:
            self._lon = self._map_array(lon)
        self._lat = self._map_array(lat, edges=EdgeTreatment.EXTRAPOLATE)
        self._f = self._map_array(f)
        self.mask = mask
        self.H = H
        self.z0 = z0

        kwargs_expand = {}
        if self.periodic_x:
            kwargs_expand["edges_x"] = EdgeTreatment.EXTRAPOLATE_PERIODIC
        if self.periodic_y:
            kwargs_expand["edges_y"] = EdgeTreatment.EXTRAPOLATE_PERIODIC
        if has_xy:
            # Expand x, y by 1 in each direction to calculate dx, dy, rotation
            x_ex = expand_2d(self._x, **kwargs_expand)
            y_ex = expand_2d(self._y, **kwargs_expand)
        if has_lonlat:
            # Expand lon, lat by 1 in each direction to calculate dx, dy, rotation
            lon_rad_ex = DEG2RAD * expand_2d(self._lon, **kwargs_expand)
            lat_rad_ex = DEG2RAD * expand_2d(self._lat, **kwargs_expand)

        if has_xy:
            # Distances between the two neighbouring supergrid points (1/2 cell
            # on either side), in m
            dx_x = x_ex[1:-1, 2:] - x_ex[1:-1, :-2]
            dy_x = y_ex[1:-1, 2:] - y_ex[1:-1, :-2]
            dx_y = x_ex[2:, 1:-1] - x_ex[:-2, 1:-1]
            dy_y = y_ex[2:, 1:-1] - y_ex[:-2, 1:-1]
            scale = 1.0
        else:
            # Same distances on the sphere, from half-angle terms that are
            # scaled by 2 * R_EARTH below
            dlon_rad_x = lon_rad_ex[1:-1, 2:] - lon_rad_ex[1:-1, :-2]
            dlat_rad_x = lat_rad_ex[1:-1, 2:] - lat_rad_ex[1:-1, :-2]
            coslat_x = np.cos(0.5 * (lat_rad_ex[1:-1, 2:] + lat_rad_ex[1:-1, :-2]))
            dx_x = coslat_x * np.sin(0.5 * dlon_rad_x)
            dy_x = np.sin(0.5 * dlat_rad_x) * np.cos(0.5 * dlon_rad_x)
            dlon_rad_y = lon_rad_ex[2:, 1:-1] - lon_rad_ex[:-2, 1:-1]
            dlat_rad_y = lat_rad_ex[2:, 1:-1] - lat_rad_ex[:-2, 1:-1]
            coslat_y = np.cos(0.5 * (lat_rad_ex[2:, 1:-1] + lat_rad_ex[:-2, 1:-1]))
            dx_y = coslat_y * np.sin(0.5 * dlon_rad_y)
            dy_y = np.sin(0.5 * dlat_rad_y) * np.cos(0.5 * dlon_rad_y)
            scale = R_EARTH * 2.0
        self._dx = scale * np.hypot(dx_x, dy_x)
        self._dy = scale * np.hypot(dx_y, dy_y)

        # Fall back to x/y only when lon/lat are uniform (a single point, which
        # carries no orientation information) and x/y exist. The previous
        # `.any()` made this check always true, so every domain with both
        # coordinate sets silently used the x/y rotation.
        if has_lonlat and not (
            (self._lat == self._lat[0, 0]).all()
            and (self._lon == self._lon[0, 0]).all()
            and has_xy
        ):
            # Proper rotation with respect to true North
            rotation = _rotation(lon_rad_ex, lat_rad_ex)
        else:
            # Rotation with respect to y-axis - assumes y-axis always points to
            # true North (can be valid only for infinitesimally small domain)
            rotation = _rotation(x_ex, y_ex)
        self._rotation = rotation if rotation[1:-1, 1:-1].any() else None
        self._area = self._dx * self._dy

    self.open_boundaries = open_boundaries.GlobalOpenBoundaryCollection(
        nx, ny, self.logger.getChild("open_boundaries"), **self._tcoords()
    )
    self.rivers = rivers.GlobalRiverCollection(
        nx, ny, coordinate_type, self.logger.getChild("rivers")
    )
    self.default_output_transforms = []
    self.extra_output_coordinates = []
    self.input_grid_mappers = []
def _tcoords(self, **kwargs: np.ndarray) -> Mapping[str, np.ndarray]:
    """Coordinates for T points (cell centers).

    Every supergrid array passed as a keyword argument is subsampled at
    ``[1::2, 1::2]``. The domain's own x, y, lon and lat (those not None) are
    added unless a keyword argument with that name was given.
    """
    own_coordinates = (
        ("x", self._x),
        ("y", self._y),
        ("lon", self._lon),
        ("lat", self._lat),
    )
    for name, supergrid_values in own_coordinates:
        if supergrid_values is not None:
            kwargs.setdefault(name, supergrid_values)
    return {name: values[1::2, 1::2] for name, values in kwargs.items()}
def _map_array(
    self,
    values: Optional[npt.ArrayLike],
    *,
    dtype: npt.DTypeLike = float,
    edges: EdgeTreatment = EdgeTreatment.MISSING,
    missing_value=np.nan,
) -> Optional[np.ndarray]:
    """Map values to the supergrid of shape ``(ny*2+1, nx*2+1)``.

    Accepted inputs are scalars and arrays broadcastable to the T grid
    ``(ny, nx)``, the X grid ``(ny+1, nx+1)`` or the supergrid itself.
    One-dimensional arrays are treated as varying along x.

    Args:
        values: values to map; None yields None
        dtype: data type of the result
        edges: treatment of the outer interfaces when interpolating from cell
            centers. On periodic axes, any treatment other than
            ``EdgeTreatment.EXTRAPOLATE`` is replaced by ``EdgeTreatment.PERIODIC``.
        missing_value: value for points without data (outer interfaces treated
            as MISSING, masked supergrid values)

    Returns:
        The supergrid array, or None on non-root ranks or if ``values`` is
        None. If the mapped values do not already have the full supergrid
        shape, a read-only broadcast view is returned.
    """
    if self.comm.rank != 0 or values is None:
        return None

    source_shape = np.shape(values)
    source_shape = (1,) * (2 - len(source_shape)) + source_shape  # broadcast
    target_shape = (self.ny * 2 + 1, self.nx * 2 + 1)

    def can_cast(*shape: int) -> bool:
        # True if every source dimension is 1 or equal to the requested one
        assert len(shape) == len(source_shape)
        return all(n == 1 or n == m for n, m in zip(source_shape, shape))

    if source_shape[0] == 1 and source_shape[1] == 1:
        # scalar value
        mapped_values = np.array(values, dtype=dtype)
    elif can_cast(self.ny, self.nx):
        # values provided at cell centers
        edges_x = edges_y = edges
        if self.periodic_x and edges_x != EdgeTreatment.EXTRAPOLATE:
            edges_x = EdgeTreatment.PERIODIC
        if self.periodic_y and edges_y != EdgeTreatment.EXTRAPOLATE:
            edges_y = EdgeTreatment.PERIODIC
        if source_shape[0] == 1:
            values_sup = centers_to_supergrid_1d(
                np.ravel(values),
                edges=edges_x,
                missing_value=missing_value,
                dtype=dtype,
            )
            mapped_values = values_sup[np.newaxis, :]
        elif source_shape[1] == 1:
            values_sup = centers_to_supergrid_1d(
                np.ravel(values),
                edges=edges_y,
                missing_value=missing_value,
                dtype=dtype,
            )
            mapped_values = values_sup[:, np.newaxis]
        else:
            mapped_values = centers_to_supergrid_2d(
                values,
                edges_x=edges_x,
                edges_y=edges_y,
                missing_value=missing_value,
                dtype=dtype,
            )
    elif can_cast(self.ny + 1, self.nx + 1):
        # values provided at cell corners
        if source_shape[0] == 1:
            values_sup = interfaces_to_supergrid_1d(np.ravel(values), dtype=dtype)
            mapped_values = values_sup[np.newaxis, :]
        elif source_shape[1] == 1:
            values_sup = interfaces_to_supergrid_1d(np.ravel(values), dtype=dtype)
            mapped_values = values_sup[:, np.newaxis]
        else:
            # Pass dtype like the 1D branches do. Without it, integer input
            # (e.g. coordinates from np.meshgrid of integers) produced
            # truncated integer midpoints.
            mapped_values = interfaces_to_supergrid_2d(values, dtype=dtype)
    else:
        # values provided on supergrid
        # np.shape works for lists too (values.shape would raise AttributeError)
        assert can_cast(*target_shape), (
            f"Cannot map array with shape {np.shape(values)} to supergrid with"
            f" shape {target_shape}"
        )
        mapped_values = np.ma.MaskedArray(values, dtype=dtype).filled(missing_value)
    if mapped_values.shape != target_shape:
        mapped_values = np.broadcast_to(mapped_values, target_shape)
    return mapped_values
@property
def x(self) -> Optional[np.ndarray]:
    """x coordinate (m) on the supergrid, shape ``(ny*2+1, nx*2+1)``.

    T points (cell centers) are at ``[1::2, 1::2]``, U points at
    ``[1::2, ::2]``, V points at ``[::2, 1::2]`` and X points (corners) at
    ``[::2, ::2]``.

    None on non-root MPI nodes, and also when no ``x`` was provided at domain
    creation (for instance, for spherical domains).
    """
    return self._x
@property
def y(self) -> Optional[np.ndarray]:
    """y coordinate (m) on the supergrid, shape ``(ny*2+1, nx*2+1)``.

    T points (cell centers) are at ``[1::2, 1::2]``, U points at
    ``[1::2, ::2]``, V points at ``[::2, 1::2]`` and X points (corners) at
    ``[::2, ::2]``.

    None on non-root MPI nodes, and also when no ``y`` was provided at domain
    creation (for instance, for spherical domains).
    """
    return self._y
@property
def lon(self) -> Optional[np.ndarray]:
    """Longitude (°East) on the supergrid, shape ``(ny*2+1, nx*2+1)``.

    T points (cell centers) are at ``[1::2, 1::2]``, U points at
    ``[1::2, ::2]``, V points at ``[::2, 1::2]`` and X points (corners) at
    ``[::2, ::2]``.

    None on non-root MPI nodes, and also when no ``lon`` was provided at
    domain creation (for instance, for Cartesian domains).
    """
    return self._lon
@property
def lat(self) -> Optional[np.ndarray]:
    """Latitude (°North) on the supergrid, shape ``(ny*2+1, nx*2+1)``.

    T points (cell centers) are at ``[1::2, 1::2]``, U points at
    ``[1::2, ::2]``, V points at ``[::2, 1::2]`` and X points (corners) at
    ``[::2, ::2]``.

    None on non-root MPI nodes, and also when no ``lat`` was provided at
    domain creation (for instance, for Cartesian domains with a prescribed
    Coriolis parameter).
    """
    return self._lat
@property
def f(self) -> Optional[np.ndarray]:
    """Coriolis parameter (rad s-1) on the supergrid, shape ``(ny*2+1, nx*2+1)``.

    T points (cell centers) are at ``[1::2, 1::2]``, U points at
    ``[1::2, ::2]``, V points at ``[::2, 1::2]`` and X points (corners) at
    ``[::2, ::2]``.

    None on non-root MPI nodes, and also when no ``f`` was provided at domain
    creation. In the latter case it will be calculated from :attr:`lat`.
    """
    return self._f
@property
def dx(self) -> Optional[np.ndarray]:
    """Grid cell length in x-direction (m), shape ``(ny*2+1, nx*2+1)``.

    T points (cell centers) are at ``[1::2, 1::2]``, U points at
    ``[1::2, ::2]``, V points at ``[::2, 1::2]`` and X points (corners) at
    ``[::2, ::2]``.

    None on non-root MPI nodes.
    """
    return self._dx
@property
def dy(self) -> Optional[np.ndarray]:
    """Grid cell length in y-direction (m), shape ``(ny*2+1, nx*2+1)``.

    T points (cell centers) are at ``[1::2, 1::2]``, U points at
    ``[1::2, ::2]``, V points at ``[::2, 1::2]`` and X points (corners) at
    ``[::2, ::2]``.

    None on non-root MPI nodes.
    """
    return self._dy
@property
def rotation(self) -> Optional[np.ndarray]:
    """Grid rotation with respect to true North (rad), shape ``(ny*2+1, nx*2+1)``.

    T points (cell centers) are at ``[1::2, 1::2]``, U points at
    ``[1::2, ::2]``, V points at ``[::2, 1::2]`` and X points (corners) at
    ``[::2, ::2]``.

    None on non-root MPI nodes, and also when the y-axis points to true
    North everywhere, i.e. the rotation is zero.
    """
    return self._rotation
@property
def area(self) -> Optional[np.ndarray]:
    """Grid cell area (m²), shape ``(ny*2+1, nx*2+1)``.

    T points (cell centers) are at ``[1::2, 1::2]``, U points at
    ``[1::2, ::2]``, V points at ``[::2, 1::2]`` and X points (corners) at
    ``[::2, ::2]``.

    None on non-root MPI nodes.
    """
    return self._area
@property
def mask(self) -> Optional[np.ndarray]:
    """land-sea mask (0: land, 1: water)

    It is defined on the supergrid and thus has shape ``(ny*2+1, nx*2+1)``.
    Cell centers (T points) are at ``[1::2, 1::2]``, interfaces at ``[1::2, ::2]``
    (U points) and ``[::2, 1::2]`` (V points), corners (X points) at ``[::2, ::2]``.

    It can be changed by assigning directly to the attribute or to slices of it.
    If assigning directly, values defined on the T grid ``(ny x nx)``, the X grid
    ``(ny+1 x nx+1)``, or the supergrid ``(ny*2+1 x nx*2+1)`` are accepted,
    as are scalars. Assigned values are then interpolated to the supergrid.

    This attribute is None on non-root MPI nodes.
    """
    # The stored mask may be a read-only broadcast view (e.g. from a scalar);
    # make a private writable copy so slices can be assigned. The None check
    # matches H/z0; without it, non-root ranks raised AttributeError instead
    # of returning None as documented.
    if self._mask is not None and not self._mask.flags.writeable:
        self._mask = self._mask.copy()
    return self._mask

@mask.setter
def mask(self, values: npt.ArrayLike):
    self._mask = self._map_array(values, missing_value=0, dtype=int)
@property
def H(self) -> Optional[np.ndarray]:
    """bathymetric depth (m)

    Distance between the bottom and some arbitrary depth reference, positive
    if the bottom lies below that reference; typically the reference is mean
    sea level. Defined on the supergrid, shape ``(ny*2+1, nx*2+1)``: T points
    at ``[1::2, 1::2]``, U points at ``[1::2, ::2]``, V points at
    ``[::2, 1::2]``, X points at ``[::2, ::2]``.

    The returned array is writable, so slices of it can be assigned. Assigning
    to the attribute itself accepts scalars and values on the T grid
    ``(ny x nx)``, the X grid ``(ny+1 x nx+1)`` or the supergrid
    ``(ny*2+1 x nx*2+1)``, which are interpolated to the supergrid. Supergrid
    points where the resulting depth is not finite (e.g. because the input was
    NaN or masked there) are set to land in :attr:`mask`.

    None on non-root MPI nodes, and when no depth has been provided yet.
    """
    depth = self._H
    if depth is not None and not depth.flags.writeable:
        # Replace a read-only broadcast view by a private writable copy
        depth = self._H = depth.copy()
    return depth

@H.setter
def H(self, values: Optional[npt.ArrayLike]):
    if self.comm.rank != 0 or values is None:
        self._H = None
        return
    depth = self._map_array(np.ma.masked_invalid(values), edges=EdgeTreatment.CLAMP)
    self._H = depth
    self._mask = self._mask & np.isfinite(depth)
@property
def z0(self) -> Optional[np.ndarray]:
    """minimum hydrodynamic bottom roughness (m)

    Defined on the supergrid, shape ``(ny*2+1, nx*2+1)``: T points at
    ``[1::2, 1::2]``, U points at ``[1::2, ::2]``, V points at
    ``[::2, 1::2]``, X points at ``[::2, ::2]``.

    The returned array is writable, so slices of it can be assigned. Assigning
    to the attribute itself accepts scalars and values on the T grid
    ``(ny x nx)``, the X grid ``(ny+1 x nx+1)`` or the supergrid
    ``(ny*2+1 x nx*2+1)``, which are interpolated to the supergrid.

    None on non-root MPI nodes.
    """
    roughness = self._z0
    if roughness is not None and not roughness.flags.writeable:
        # Replace a read-only broadcast view by a private writable copy
        roughness = self._z0 = roughness.copy()
    return roughness

@z0.setter
def z0(self, values: npt.ArrayLike):
    mapped = self._map_array(values)
    self._z0 = mapped
[docs]
def create_tiling(self) -> parallel.Tiling:
    """Create a subdomain decomposition from the mask at T points.

    The root rank's mask at cell centers (``[1::2, 1::2]``) is broadcast to
    all ranks and passed to :meth:`parallel.Tiling.autodetect`, together with
    the periodicity settings and this domain's communicator.
    """
    tmask = self._mask[1::2, 1::2] if self.comm.rank == 0 else None
    return parallel.Tiling.autodetect(
        self.comm.bcast(tmask),
        periodic_x=self.periodic_x,
        periodic_y=self.periodic_y,
        comm=self.comm,
        logger=self.logger.getChild("subdomain_decomposition"),
    )
[docs]
def create_grids(
    self,
    nz: Optional[int],
    halox: int,
    haloy: int,
    fields: Optional[Mapping[str, core.Array]] = None,
    tiling: Optional[parallel.Tiling] = None,
    input_manager=None,
    velocity_grids: int = 0,
    t_postfix: str = "",
) -> core.Grid:
    """Create the subdomain grids for this domain and return the T grid.

    Must be called on all ranks of :attr:`comm`.

    Args:
        nz: vertical dimension passed on to :class:`core.Grid`
        halox: halo width in x passed on to :class:`core.Grid`
        haloy: halo width in y passed on to :class:`core.Grid`
        fields: dictionary shared by all created grids (a new one if None)
        tiling: subdomain decomposition. If None, it is created with
            :meth:`create_tiling`; if its global extent is unset, it is set
            to ``(nx, ny)``.
        input_manager: assigned to every created grid
        velocity_grids: 0 creates only the T and X grids; 1 or more also the
            U and V grids; 2 or more also the advection grids (UU, UV, VU, VV)
        t_postfix: postfix for the T grid

    Returns:
        The T grid, linked to the U, V and X grids through its
        ``ugrid``/``vgrid``/``xgrid`` arguments.

    Raises:
        Exception: on all ranks, if no water depth (H) has been provided
    """
    final_mask = self.get_final_mask()

    # H exists only on root. Broadcast the verdict so that all ranks raise
    # together; previously only root raised, leaving the other ranks blocked
    # in the next collective call (the bcast in create_tiling).
    if self.comm.bcast(self._H is None):
        raise Exception("Water depth at rest (H) has not been provided")
    if self.comm.rank == 0:
        # Map river coordinates to global grid indices
        self.rivers.map_to_grid(core.Locator(**self._tcoords(mask=final_mask)))

    if tiling is None:
        tiling = self.create_tiling()
    elif tiling.nx_glob is None:
        tiling.set_extent(self.nx, self.ny)

    # Map grid attributes to supergrid arrays
    domain_vars = dict(
        x=self._x,
        y=self._y,
        lon=self._lon,
        lat=self._lat,
        cor=self._f,
        dx=self._dx,
        dy=self._dy,
        rotation=self._rotation,
        area=self._area,
        mask=final_mask,
        H=self._H,
        z0b_min=self._z0,
    )

    # NB the fields argument cannot have a default of {}, as that causes
    # that global dictionary to be shared among all calls to create_grids.
    if fields is None:
        fields = {}

    def create_grid(
        postfix: str,
        ioffset: int,
        joffset: int,
        *,
        overlap: int = 0,
        **kwargs,
    ) -> core.Grid:
        # ioffset/joffset: position of the grid's first point on the supergrid
        grid = core.Grid(
            tiling.nx_sub + overlap,
            tiling.ny_sub + overlap,
            nz,
            halox=halox,
            haloy=haloy,
            postfix=postfix,
            fields=fields,
            tiling=tiling,
            ioffset=ioffset,
            joffset=joffset,
            overlap=overlap,
            **kwargs,
        )
        self._populate_grid(grid, domain_vars)
        grid.input_manager = input_manager
        grid.default_output_transforms = self.default_output_transforms
        grid.extra_output_coordinates = self.extra_output_coordinates
        grid.input_grid_mappers = [m(grid=grid) for m in self.input_grid_mappers]
        return grid

    U = V = X = UU = UV = VU = VV = None
    if velocity_grids > 1:
        UU = create_grid("_uu_adv", 3, 1)
        UV = create_grid("_uv_adv", 2, 2)
        VU = create_grid("_vu_adv", 2, 2)
        VV = create_grid("_vv_adv", 1, 3)
    if velocity_grids > 0:
        U = create_grid("u", 2, 1, ugrid=UU, vgrid=UV)
        V = create_grid("v", 1, 2, ugrid=VU, vgrid=VV)
    X = create_grid("x", 0, 0, overlap=1, istart=0, jstart=0)
    T = create_grid(t_postfix, 1, 1, ugrid=U, vgrid=V, xgrid=X)

    if velocity_grids > 0:
        T.infer_water_contact()
        T.close_flux_interfaces()
    if velocity_grids > 1:
        U.close_flux_interfaces()
        V.close_flux_interfaces()

        # No transport between velocity points along an open boundary (just outside)
        # This is done to state that no valid values (in e.g. h and D) are required
        # in these points.
        U_mirror_ext = U.mask.all_values == CellType.MIRROR_EXT
        UV.mask.all_values[:-1, :][
            U_mirror_ext[:-1, :] & U_mirror_ext[1:, :]
        ] = CellType.UNRESOLVED
        V_mirror_ext = V.mask.all_values == CellType.MIRROR_EXT
        VU.mask.all_values[:, :-1][
            V_mirror_ext[:, :-1] & V_mirror_ext[:, 1:]
        ] = CellType.UNRESOLVED

    T.freeze()
    open_boundaries.LocalOpenBoundaryCollection(
        self.open_boundaries, T, logger=self.logger
    )
    T.rivers = self.rivers.initialize(T)
    return T
def _populate_grid(
    self, grid: core.Grid, domain_vars: Mapping[str, Optional[np.ndarray]]
):
    """Fill the arrays of a (sub)domain grid from global supergrid fields.

    For every entry of ``domain_vars`` that the root process reports as
    present, a grid array with that name is created and its global values
    are sampled from the supergrid at the positions of ``grid`` (every second
    row/column, starting at the parity of ``grid.joffset``/``grid.ioffset``).
    The root then cuts these into per-rank pieces and distributes them with
    an MPI ``Scatter``. Afterwards the Coriolis parameter is derived from
    latitude if it was not provided, and default horizontal coordinates are
    registered according to the domain's coordinate type.

    Args:
        grid: grid whose arrays are to be created and filled
        domain_vars: mapping from array name to global supergrid values.
            Only the root rank's values are read; whether a variable is
            present is broadcast from the root.
            NOTE(review): presumably non-root ranks pass None — confirm
            against the caller.
    """
    edges_x = EdgeTreatment.PERIODIC if self.periodic_x else EdgeTreatment.MISSING
    edges_y = EdgeTreatment.PERIODIC if self.periodic_y else EdgeTreatment.MISSING
    # Parity of the grid offset selects which supergrid rows/columns it uses
    ioffset = grid.ioffset % 2
    joffset = grid.joffset % 2
    # Index (in the grid's own global index space) of the first sampled
    # supergrid value: 0 for offsets 0/1, -1 for offsets 2/3
    istart_glob = (ioffset - grid.ioffset) // 2
    jstart_glob = (joffset - grid.joffset) // 2
    assert istart_glob <= 0, "Supergrid starts after first x"
    assert jstart_glob <= 0, "Supergrid starts after first y"
    # Global extent of this grid; grids created with overlap have extra points
    nx_glob = self.nx + grid.overlap
    ny_glob = self.ny + grid.overlap

    def _transfer_variable(source: np.ndarray, target: core.Array):
        """Distribute global field ``source`` (read on root only) into ``target``."""
        sendbuf = None
        if grid.tiling.rank == 0:
            # We are on the root node that has the global values
            if grid.joffset > 1 or grid.ioffset > 1:
                # UU or VV grid that needs one more strip of 1 cell
                # beyond the end of the supergrid. By default, that is
                # left at missing_value, but that is inappropriate for
                # periodic boundaries that need mirroring. Therefore we
                # expand the grid properly, any mirroring included.
                source = expand_2d(
                    source,
                    edges_x=edges_x,
                    edges_y=edges_y,
                    missing_value=target.fill_value,
                )[1:, 1:]
            global_values = source[joffset::2, ioffset::2]
            istop_glob = istart_glob + global_values.shape[-1]
            jstop_glob = jstart_glob + global_values.shape[-2]
            assert istop_glob >= nx_glob, "Supergrid stops before last x"
            assert jstop_glob >= ny_glob, "Supergrid stops before last y"
            # One local-sized slab per rank, prefilled with the fill value so
            # that points not covered by global data keep that value
            sendbuf = grid.tiling._get_work_array(
                (grid.tiling.comm.size,) + target.all_values.shape,
                target.dtype,
                target.fill_value,
            )
            sendbuf.fill(target.fill_value)
            for irow, icol, rank in parallel._iterate_rankmap(grid.tiling.map):
                if rank < 0:
                    # Subdomain without an assigned rank: nothing to send
                    continue
                # Global index range of this subdomain, halos included
                jslice, islice = grid.tiling.subdomain2rawslices(
                    irow,
                    icol,
                    halox_sub=grid.halox,
                    haloy_sub=grid.haloy,
                    share=grid.overlap,
                )
                # Copy only where the subdomain overlaps the available data
                slice_glob, slice_loc = _get_rectangle_overlap(
                    istart_glob,
                    istop_glob,
                    jstart_glob,
                    jstop_glob,
                    islice.start,
                    islice.stop,
                    jslice.start,
                    jslice.stop,
                )
                sendbuf[(rank,) + slice_loc] = global_values[slice_glob]
            # Keep a pointer to the full global field
            # This will be used preferentially for full-domain output
            target.attrs["_global_values"] = global_values[
                -jstart_glob : ny_glob - jstart_glob,
                -istart_glob : nx_glob - istart_glob,
            ]
            assert target.attrs["_global_values"].shape == (ny_glob, nx_glob)
        else:
            target.attrs["_global_values"] = None
        # Collective call: every rank receives its slab into target.all_values
        grid.tiling.comm.Scatter(sendbuf, target.all_values)
        if self.periodic_x or self.periodic_y:
            # Halo update (presumably) supplies the wrapped-around values
            target.update_halos()

    retrieved_from_domain = set()
    for name, source in domain_vars.items():
        # All ranks must agree on whether to participate in the Scatter,
        # so the root's view of availability is broadcast
        if grid.tiling.comm.bcast(source is not None):
            target = grid.create_array(name)
            if target is not None:
                _transfer_variable(source, target)
                retrieved_from_domain.add(name)

    # If the Coriolis parameter was not set explicitly at domain level,
    # calculate it from latitude
    if "cor" not in retrieved_from_domain:
        grid.create_array("cor").all_values = coriolis(grid._lat.all_values)

    # Set default horizontal coordinates (e.g., for output and online plotting)
    # based on the coordinate type set at domain level.
    if self.coordinate_type == CoordinateType.XY:
        grid.horizontal_coordinates += [grid.x, grid.y]
    elif self.coordinate_type == CoordinateType.LONLAT:
        grid.horizontal_coordinates += [grid.lon, grid.lat]
[docs]
@apply_on_root_and_bcast
def cfl_check(
    self, z: float = 0.0, return_location: bool = False
) -> Union[float, tuple[float, int, int, float]]:
    """Determine maximum time step (s) for depth-integrated equations

    The local limit at each T point is
    ``dx * dy / sqrt(2 g D (dx**2 + dy**2))`` with ``D = H + z``; masked
    points and points without water (``D <= 0``) impose no limit (inf).

    Args:
        z: maximum surface elevation (m)
        return_location: whether to also return the location
            and depth that determined the maximum step

    Returns:
        the maximum time step, or, if ``return_location`` is set, a tuple
        ``(maximum time step, i, j, depth at that point)``.

    Note: this returns global indices for the T grid, not the supergrid
    """
    centers = (slice(1, None, 2), slice(1, None, 2))
    dx = self._dx[centers]
    dy = self._dy[centers]
    depth = self._H[centers] + z
    wet = (self._mask[centers] > 0) & (depth > 0.0)

    # Gravity-wave speed term; sqrt only evaluated where water is present
    wave_term = np.sqrt(
        (2.0 * GRAVITY) * depth * (dx**2 + dy**2),
        where=wet,
        out=np.ones_like(depth),
    )
    dt_local = dx * dy / wave_term
    dt_local[~wet] = np.inf
    dt_max = dt_local.min()

    if not return_location:
        return dt_max
    j, i = np.unravel_index(np.argmin(dt_local), dt_local.shape)
    return (dt_max, i, j, self._H[1 + 2 * j, 1 + 2 * i])
@property
def maxdt(self) -> float:
    """Maximum time step (s) for depth-integrated equations

    Shorthand for :meth:`cfl_check` with its defaults (zero surface
    elevation, no location information).
    """
    limit = self.cfl_check()
    return limit
[docs]
def get_rx0(
    self, zmin: float = 0.0, Dmin: float = 0.0
) -> tuple[np.ndarray, np.ndarray]:
    """Calculates the slope factor ``rx0`` as defined in
    https://doi.org/10.1016/j.ocemod.2009.03.009

    For each interface between two T points this is
    ``|H1 - H2| / (H1 + H2 + 2 z)``, where ``z`` is the lowest surface
    elevation admissible at both points: at least ``zmin``, and raised where
    needed so that both points keep a water depth of at least ``Dmin``.
    Interfaces touching a masked T point get a slope factor of 0.

    Args:
        zmin: minimum surface elevation (m)
        Dmin: minimum depth (m)

    Returns:
        a tuple with slope factors at U and V points,
        positioned at ``[1::2, 2:-2:2]`` and ``[2:-2:2, 1::2]``, respectively
    """
    depth = self._H[1::2, 1::2]
    land = self._mask[1::2, 1::2] == 0

    # Lowest admissible surface elevation per T point
    zlow = np.maximum(Dmin - depth, zmin)

    def slope_factor(h_a, h_b, z_a, z_b):
        # Elevation shared by both neighbours: the higher of their minima
        z_pair = np.maximum(z_a, z_b)
        return np.abs(h_a - h_b) / (h_a + h_b + 2 * z_pair)

    rx0_u = slope_factor(depth[:, 1:], depth[:, :-1], zlow[:, 1:], zlow[:, :-1])
    rx0_v = slope_factor(depth[1:, :], depth[:-1, :], zlow[1:, :], zlow[:-1, :])
    rx0_u[land[:, 1:] | land[:, :-1]] = 0.0
    rx0_v[land[1:, :] | land[:-1, :]] = 0.0
    return rx0_u, rx0_v
@property
@apply_on_root_and_bcast
def max_rx0(self) -> float:
    """Maximum slope factor rx0 as defined in https://doi.org/10.1016/j.ocemod.2009.03.009

    Evaluated with the defaults of :meth:`get_rx0`; 0 when there are no
    interfaces.
    """
    slopes = self.get_rx0()
    return max(component.max(initial=0.0) for component in slopes)
[docs]
@apply_only_on_root
def smooth(self, rx0: float = 0.2) -> Optional[np.ndarray]:
    """Smooth bathymetry by reducing slope factor to specified maximum.

    The criterion to be satisfied at every interface between wet points is:

        abs(H1 - H2) - rx0 * H1 - rx0 * H2 < 0

    Handling abs as described at https://lpsolve.sourceforge.net/5.1/absolute.htm,
    we obtain the two criteria:

        ( 1-rx0) * H1 + (-1-rx0) * H2 < 0
        (-1-rx0) * H1 + ( 1-rx0) * H2 < 0

    Splitting into original bathymetry H and correction H', and rearranging:

        ( 1-rx0)H1' + (-1-rx0)H2' < -( 1-rx0)H1 - (-1-rx0)H2
        (-1-rx0)H1' + ( 1-rx0)H2' < -(-1-rx0)H1 - ( 1-rx0)H2

    H1' and H2' are the corrections to the bathymetry, estimated by linear
    programming (minimizing the sum of absolute corrections) as described in
    https://doi.org/10.1016/j.ocemod.2009.03.009

    If the current maximum slope factor already satisfies ``rx0``, nothing
    is changed and zero corrections are returned. Otherwise :attr:`H` is
    replaced by the corrected T-point depths.

    Args:
        rx0: maximum slope factor

    Returns:
        bathymetry corrections (m).
        These are defined at T points, that is, at ``[1::2, 1::2]``

    Raises:
        RuntimeError: if the linear program could not be solved
    """
    rx0_u, rx0_v = self.get_rx0()
    current_max_rx0 = max(rx0_u.max(initial=0.0), rx0_v.max(initial=0.0))
    if current_max_rx0 <= rx0:
        return np.zeros_like(self.H[1::2, 1::2])

    import scipy.optimize
    import scipy.sparse

    # Number the wet T points; these are the unknowns of the problem
    twet = self._mask[1::2, 1::2] != 0
    nwet = twet.sum()
    H = self.H[1::2, 1::2][twet]
    iwet = np.full(twet.shape, -1, dtype=np.intp)
    iwet[twet] = np.arange(nwet)

    # Interfaces between two wet points in x (U) and y (V) direction
    uwet = twet[:, 1:] & twet[:, :-1]
    vwet = twet[1:, :] & twet[:-1, :]
    nu = uwet.sum()
    nv = vwet.sum()
    n = nu + nv

    # Every constraint row involves exactly two unknowns, so the sparse
    # matrix is assembled from (nconstraints, 2) value/row/column arrays
    nconstraints = n * 2 + nwet * 2
    A_values = np.empty((nconstraints, 2))
    A_i = np.empty_like(A_values, dtype=np.intp)
    A_j = np.empty_like(A_values, dtype=np.intp)
    b = np.zeros(nconstraints)

    # Constraints on raw [not absolute] bathymetry corrections
    A_values[:n, 0] = 1.0 - rx0
    A_values[:n, 1] = -1.0 - rx0
    A_values[n : 2 * n, 0] = -1.0 - rx0
    A_values[n : 2 * n, 1] = 1.0 - rx0
    A_i[:, :] = np.arange(nconstraints)[:, np.newaxis]
    A_j[:nu, 0] = iwet[:, :-1][uwet]
    A_j[:nu, 1] = iwet[:, 1:][uwet]
    A_j[nu:n, 0] = iwet[:-1, :][vwet]
    A_j[nu:n, 1] = iwet[1:, :][vwet]
    A_j[n : 2 * n, :] = A_j[:n, :]
    b[: 2 * n] = -(A_values[: 2 * n] * H[A_j[: 2 * n]]).sum(axis=1)

    # Trick to minimize sum of absolute corrections
    # We introduce one new variable M per wet point (correction),
    # constrained by M >= H' and M >= -H'
    # The objective is to minimize sum(M)
    A_j[2 * n : 2 * n + nwet, 0] = np.arange(nwet)
    A_j[2 * n : 2 * n + nwet, 1] = np.arange(nwet, 2 * nwet)
    A_j[2 * n + nwet :, :] = A_j[2 * n : 2 * n + nwet, :]
    A_values[2 * n : 2 * n + nwet, 0] = 1.0
    A_values[2 * n + nwet :, 0] = -1.0
    A_values[2 * n :, 1] = -1.0
    c = np.zeros(2 * nwet)
    c[nwet:] = 1.0

    # Unknowns: nwet corrections H' followed by nwet auxiliary variables M.
    # State the shape explicitly rather than relying on the largest index.
    A = scipy.sparse.coo_matrix(
        (A_values.ravel(), (A_i.ravel(), A_j.ravel())),
        shape=(nconstraints, 2 * nwet),
    )
    res = scipy.optimize.linprog(c=c, A_ub=A, b_ub=b, bounds=(None, None))
    if not res.success:
        raise RuntimeError(
            f"Failed to optimize bathymetry for rx0<={rx0}: {res.message}"
        )

    Hcor = np.zeros_like(self.H[1::2, 1::2])
    Hcor[twet] = res.x[:nwet]
    self.H = self.H[1::2, 1::2] + Hcor
    return Hcor
[docs]
def get_final_mask(self) -> Optional[np.ndarray]:
    """Infer masks for U, V, X points from T point mask.

    A U, V or X point stays active only if every T point around it is
    active, which closes interfaces and corners that neighbor one or more
    dry points. Finally ``self.open_boundaries.adjust_mask`` is applied,
    which sets the mask (values 2,3,4) within and alongside open boundaries.

    Returns:
        integer supergrid mask, or None if the domain has no mask
    """
    if self._mask is None:
        return None

    # 0/1 copy of the supergrid mask. The staggered slices taken below are
    # views, so zeroing them updates `final` directly.
    final = np.where(self._mask, 1, 0)
    t_points = final[1::2, 1::2]

    # Pad the T mask by one row/column on each side (wrapping in periodic
    # directions, 0 elsewhere) so that every U/V/X point has all neighbours
    wrap_x = EdgeTreatment.PERIODIC if self.periodic_x else EdgeTreatment.MISSING
    wrap_y = EdgeTreatment.PERIODIC if self.periodic_y else EdgeTreatment.MISSING
    t_dry = (
        expand_2d(t_points, edges_x=wrap_x, edges_y=wrap_y, missing_value=0) == 0
    )

    final[1::2, ::2][t_dry[1:-1, 1:] | t_dry[1:-1, :-1]] = 0
    final[::2, 1::2][t_dry[1:, 1:-1] | t_dry[:-1, 1:-1]] = 0
    corner_dry = t_dry[1:, 1:] | t_dry[:-1, 1:] | t_dry[1:, :-1] | t_dry[:-1, :-1]
    final[::2, ::2][corner_dry] = 0

    self.open_boundaries.adjust_mask(final)
    return final
[docs]
@apply_only_on_root
def mask_shallow(self, minimum_depth: float):
    """Mask all points shallower than the specified value.

    Args:
        minimum_depth: minimum bathymetric depth :attr:`H`; points that are
            shallower will be masked
    """
    too_shallow = self._H < minimum_depth
    self.mask[too_shallow] = 0
[docs]
@apply_only_on_root
def mask_subbasins(self, nkeep: int = 1):
    """Identify all separate basins (each a collection of unmasked cell
    centers connected via top/bottom/left/right interfaces), and mask all
    but the largest one(s).

    Basins of equal size are ranked in the order they are found (row-major
    scan for the first unlabeled cell). Requires scikit-image.

    Args:
        nkeep: number of basins to keep
    """
    import skimage.segmentation

    # Working copy of the T mask: 1 = unlabeled water, 0 = masked,
    # negative = id of the basin the cell was assigned to.
    # NOTE(review): negative ids assume the mask dtype is signed — confirm.
    labels = np.minimum(self.mask[1::2, 1::2], 1)
    sizes = {}
    basin_id = -1
    while True:
        rows, cols = (labels == 1).nonzero()
        if rows.size == 0:
            break
        skimage.segmentation.flood_fill(
            labels, (rows[0], cols[0]), basin_id, connectivity=1, in_place=True
        )
        sizes[basin_id] = (labels == basin_id).sum()
        basin_id -= 1

    by_size = sorted(sizes, key=sizes.get, reverse=True)
    for dropped in by_size[nkeep:]:
        labels[labels == dropped] = 0
    self.mask = np.where(labels == 0, 0, self.mask[1::2, 1::2])
[docs]
@apply_only_on_root
def limit_velocity_depth(self, critical_depth: float = np.inf):
    """Decrease bathymetric depth of velocity (U, V) points to the minimum of the
    bathymetric depth of both neighboring T points, wherever one of these two
    points is shallower than the specified critical depth.

    Args:
        critical_depth: neighbor depth at which the limiting starts. If either
            neighbor (T grid) is shallower than this value, the depth of velocity
            point (U or V grid) is restricted.
    """
    # NOTE(review): selected points are set to the neighbour minimum
    # unconditionally (putmask), so a velocity point that is already
    # shallower than both neighbours would be deepened — confirm intended.

    # self._H may be read-only, so every write goes through self.H.
    # Pad T depths by one row/column per side (wrapping where periodic)
    # so that every velocity point has two T neighbours.
    wrap_x = EdgeTreatment.PERIODIC if self.periodic_x else EdgeTreatment.MISSING
    wrap_y = EdgeTreatment.PERIODIC if self.periodic_y else EdgeTreatment.MISSING
    t_depth = expand_2d(self._H[1::2, 1::2], edges_x=wrap_x, edges_y=wrap_y)

    v_min = np.minimum(t_depth[1:, 1:-1], t_depth[:-1, 1:-1])
    v_hit = np.isfinite(v_min) & (v_min <= critical_depth)
    np.putmask(self.H[::2, 1::2], v_hit, v_min)

    u_min = np.minimum(t_depth[1:-1, 1:], t_depth[1:-1, :-1])
    u_hit = np.isfinite(u_min) & (u_min <= critical_depth)
    np.putmask(self.H[1::2, ::2], u_hit, u_min)

    n_u = u_hit.sum()
    n_u_open = u_hit.sum(where=self._mask[1::2, ::2] > 0)
    n_v = v_hit.sum()
    n_v_open = v_hit.sum(where=self._mask[::2, 1::2] > 0)
    self.logger.info(
        f"limit_velocity_depth has decreased depth in {n_u} U points"
        f" ({n_u_open} currently unmasked),"
        f" {n_v} V points ({n_v_open}"
        " currently unmasked)."
    )
[docs]
@apply_only_on_root
def mask_rectangle(
    self,
    xmin: Optional[float] = None,
    xmax: Optional[float] = None,
    ymin: Optional[float] = None,
    ymax: Optional[float] = None,
    mask_value: int = 0,
    coordinate_type: Optional[CoordinateType] = None,
):
    """Mask all points that fall within the specified rectangle.

    Bounds are inclusive. A bound left at ``None`` does not restrict the
    selection, so calling this without any bound assigns ``mask_value`` to
    the entire domain.

    Args:
        xmin: lower native x coordinate of the rectangle to mask
            (default: left boundary of the domain)
        xmax: upper native x coordinate of the rectangle to mask
            (default: right boundary of the domain)
        ymin: lower native y coordinate of the rectangle to mask
            (default: bottom boundary of the domain)
        ymax: upper native y coordinate of the rectangle to mask
            (default: top boundary of the domain)
        mask_value: value to assign to the mask within the rectangle
            (0 masks the points)
        coordinate_type: coordinate system of the bounds (default: the
            coordinate type of the domain). Coordinates are interpreted as
            longitude, latitude for :attr:`CoordinateType.LONLAT` and as
            Cartesian x and y (m) for :attr:`CoordinateType.XY`. For index
            coordinates, T point ``i`` lies at ``i`` and its interfaces at
            ``i - 0.5`` and ``i + 0.5``.
    """
    if coordinate_type is None:
        coordinate_type = self.coordinate_type
    if coordinate_type == CoordinateType.LONLAT:
        x, y = (self._lon, self._lat)
    elif coordinate_type == CoordinateType.XY:
        x, y = (self._x, self._y)
    else:
        # Supergrid index coordinates: 2 * n + 1 points at half-cell spacing,
        # from the outer interface at -0.5 to the last interface at n - 0.5.
        # (np.linspace(-0.5, n + 0.5, 2 * n + 1) would not give 0.5 spacing.)
        x = -0.5 + 0.5 * np.arange(1 + 2 * self.nx)
        y = -0.5 + 0.5 * np.arange(1 + 2 * self.ny)
        x, y = np.broadcast_arrays(x[np.newaxis, :], y[:, np.newaxis])

    # Start from "everything selected" and narrow down per given bound
    selected = True
    if xmin is not None:
        selected &= x >= xmin
    if xmax is not None:
        selected &= x <= xmax
    if ymin is not None:
        selected &= y >= ymin
    if ymax is not None:
        selected &= y <= ymax
    self.mask[selected] = mask_value
[docs]
@apply_only_on_root
def mask_indices(
    self, istart: int, istop: int, jstart: int, jstop: int, mask_value: int = 0
):
    """Mask all points that fall within the specified rectangle.

    Indices must be provided for the T grid, and thus range between 0 and `nx`
    for i, and between 0 and `ny` for j.

    Args:
        istart: lower x index (first that is included)
        istop: upper x index (first that is EXcluded)
        jstart: lower y index (first that is included)
        jstop: upper y index (first that is EXcluded)
        mask_value: value to assign to the mask within the rectangle
    """
    # T point k corresponds to supergrid index 2k + 1
    rows = slice(2 * jstart + 1, 2 * jstop + 1)
    cols = slice(2 * istart + 1, 2 * istop + 1)
    self.mask[rows, cols] = mask_value
def _transform(self, nx: int, ny: int, tf, **kwargs) -> "Domain":
    """Create a new ``nx`` x ``ny`` Domain whose supergrid fields are ``tf``
    applied to this domain's fields.

    Args:
        nx: number of T points in x of the new domain
        ny: number of T points in y of the new domain
        tf: callable applied to every supergrid field (lon, lat, x, y, mask,
            H, z0, f); it must accept None for absent fields
        **kwargs: extra arguments for :class:`Domain`; coordinate type,
            communicator and logger default to those of this domain

    The mask of the new domain is clipped to at most 1. Fields are only
    passed on the root rank.
    """
    defaults = (
        ("coordinate_type", self.coordinate_type),
        ("comm", self.comm),
        ("logger", self.root_logger),
    )
    for key, value in defaults:
        kwargs.setdefault(key, value)

    if self.comm.rank == 0:
        sources = {
            "lon": self._lon,
            "lat": self._lat,
            "x": self._x,
            "y": self._y,
            "mask": self._mask,
            "H": self._H,
            "z0": self._z0,
            "f": self._f,
        }
        transformed = {name: tf(values) for name, values in sources.items()}
        transformed["mask"] = np.minimum(transformed["mask"], 1)
        kwargs.update(transformed)
    return Domain(nx, ny, **kwargs)
[docs]
def extract_rectangle(
    self, istart: int, istop: int, jstart: int, jstop: int
) -> "Domain":
    """Return a copy of the domain corresponding to the specified rectangle.

    Indices must be provided for the T grid. Negative indices can be used
    to indicate positions relative to the end of the domain. For example,
    istart=1, istop=-1, jstart=1, jstop=-1 removes the outermost strip of cells,
    and thus shrinks the domain by two cells in both x- and y-direction.
    A stop index equal to ``nx`` (``ny``) extends the rectangle to the end
    of the domain.

    Open boundaries are copied with their indices shifted to the new origin.
    Rivers given in index coordinates are shifted as well and dropped when
    they fall outside the rectangle; rivers in other coordinate types are
    copied unchanged.

    Args:
        istart: lower x index (first that is included)
        istop: upper x index (first that is EXcluded)
        jstart: lower y index (first that is included)
        jstop: upper y index (first that is EXcluded)

    Returns:
        Domain corresponding to the specified rectangle
    """
    assert istart >= -self.nx and istart < self.nx
    assert istop > -self.nx and istop <= self.nx
    assert jstart >= -self.ny and jstart < self.ny
    assert jstop > -self.ny and jstop <= self.ny

    # Make end-relative (negative) indices absolute. Only negative values
    # are shifted: a modulo would also turn a stop index equal to nx/ny
    # (the end of the domain) into 0 and produce an empty rectangle.
    if istart < 0:
        istart += self.nx
    if istop < 0:
        istop += self.nx
    if jstart < 0:
        jstart += self.ny
    if jstop < 0:
        jstop += self.ny
    assert (
        istop > istart and jstop > jstart
    ), "Rectangle must contain at least one cell"

    # T point k sits at supergrid index 2k + 1; include bounding interfaces
    slc = (slice(2 * jstart, 2 * jstop + 1), slice(2 * istart, 2 * istop + 1))

    def extract(array):
        return None if array is None else array[slc]

    subdomain = self._transform(istop - istart, jstop - jstart, extract)

    # NOTE(review): open boundaries are not filtered or clipped to the
    # rectangle; confirm add_by_index copes with out-of-range indices.
    for b in self.open_boundaries:
        if b.side in (open_boundaries.Side.NORTH, open_boundaries.Side.SOUTH):
            moffset, loffset = istart, jstart
        else:
            moffset, loffset = jstart, istart
        subdomain.open_boundaries.add_by_index(
            b.side,
            b.l - loffset,
            b.mstart - moffset,
            b.mstop - moffset,
            type_2d=b.type_2d,
            type_3d=b.type_3d,
        )

    for r in self.rivers.values():
        x, y = r.x, r.y
        if r.coordinate_type == CoordinateType.IJ:
            if x < istart or x >= istop or y < jstart or y >= jstop:
                continue
            x -= istart
            y -= jstart
        subdomain.rivers.add_by_location(
            r.name, x, y, r.coordinate_type, zl=r.zl, zu=r.zu, **r.attrs
        )
    return subdomain
[docs]
def rotate(self) -> "Domain":
    """Return a copy of the domain rotated 90° clockwise

    T point ``(i, j)`` of this domain becomes ``(j, nx - 1 - i)`` in the
    rotated domain. Open boundaries move to the next side clockwise and
    rivers given in index coordinates are remapped accordingly.
    """

    def turn(array):
        if array is None:
            return None
        return np.transpose(array)[::-1, :]

    result = self._transform(self.ny, self.nx, turn)

    sides = open_boundaries.Side
    next_side = {
        sides.WEST: sides.NORTH,
        sides.NORTH: sides.EAST,
        sides.EAST: sides.SOUTH,
        sides.SOUTH: sides.WEST,
    }
    last_i = self.nx - 1
    for b in self.open_boundaries:
        l, mstart, mstop = b.l, b.mstart, b.mstop
        if b.side in (sides.NORTH, sides.SOUTH):
            # Along-boundary i index becomes a reversed j index
            mstart, mstop = (last_i - mstart, last_i - mstop)
        else:
            # Cross-boundary i position becomes a reversed j position
            l = last_i - l
        result.open_boundaries.add_by_index(
            next_side[b.side], l, mstart, mstop, type_2d=b.type_2d, type_3d=b.type_3d
        )

    for r in self.rivers.values():
        x, y = r.x, r.y
        if r.coordinate_type == CoordinateType.IJ:
            x, y = y, last_i - x
        result.rivers.add_by_location(
            r.name, x, y, r.coordinate_type, zl=r.zl, zu=r.zu, **r.attrs
        )
    return result
[docs]
@apply_on_root_and_bcast
def nearest_point(
    self,
    x: float,
    y: float,
    *,
    coordinate_type: Optional[CoordinateType] = None,
    valid_cell_types: Iterable[CellType] = (CellType.ACTIVE,),
) -> tuple[int, int]:
    """Find the grid point nearest to the specified location.

    The search uses a :class:`core.Locator` built from the domain's T-point
    coordinates and mask.

    Args:
        x: x coordinate of the location
        y: y coordinate of the location
        coordinate_type: coordinate system of ``x`` and ``y``
            (default: the coordinate type of the domain)
        valid_cell_types: cell types that may be returned

    Returns:
        the indices found by the locator.
        NOTE(review): presumably (i, j) on the T grid — confirm with core.Locator.
    """
    locate = core.Locator(**self._tcoords(mask=self._mask))
    return locate(
        x,
        y,
        coordinate_type=(
            self.coordinate_type if coordinate_type is None else coordinate_type
        ),
        valid_cell_types=valid_cell_types,
    )
[docs]
@apply_on_root_and_bcast
def contains(
    self,
    x: float,
    y: float,
    coordinate_type: Optional[CoordinateType] = None,
) -> bool:
    """Determine whether the domain contains the specified point.

    For geographic or Cartesian coordinates the point is tested against the
    polygon formed by the outer ring of supergrid points (even-odd rule).
    For index coordinates, ``0 <= x < nx`` and ``0 <= y < ny`` is required.

    Args:
        x: x coordinate
        y: y coordinate
        coordinate_type: coordinate system of the provided coordinates

    Returns:
        True if the point falls within the domain, False otherwise
    """
    if coordinate_type is None:
        coordinate_type = self.coordinate_type
    if coordinate_type == CoordinateType.LONLAT:
        xs, ys = self._lon, self._lat
    elif coordinate_type == CoordinateType.XY:
        xs, ys = self._x, self._y
    else:
        return 0 <= x < self.nx and 0 <= y < self.ny

    nrow, ncol = xs.shape

    def perimeter(c):
        # Outer ring of the 2D array, each point visited exactly once
        return np.concatenate(
            (
                c[0, :-1],
                c[:-1, -1],
                c[-1, ncol - 1 : 0 : -1],
                c[nrow - 1 : 0 : -1, 0],
            )
        )

    px = perimeter(xs)
    py = perimeter(ys)
    assert not np.isnan(px).any(), f"Invalid x boundary: {px}."
    assert not np.isnan(py).any(), f"Invalid y boundary: {py}."
    assert px.size == 2 * nrow + 2 * ncol - 4

    # Point-in-polygon test based on
    # https://wrf.ecse.rpi.edu/Research/Short_Notes/pnpoly.html:
    # toggle on every polygon edge crossed by a ray from the point.
    inside = False
    for xa, ya, xb, yb in zip(px, py, np.roll(px, 1), np.roll(py, 1)):
        if (ya > y) != (yb > y) and x < (xb - xa) * (y - ya) / (yb - ya) + xa:
            inside = not inside
    return inside
[docs]
def plot(
    self,
    field: Optional[np.ndarray] = None,
    *,
    fig: Optional["matplotlib.figure.Figure"] = None,
    show_bathymetry: bool = True,
    show_mask: bool = False,
    show_mesh: bool = False,
    show_rivers: bool = True,
    show_open_boundaries: bool = True,
    show_subdomains: bool = False,
    editable: bool = False,
    coordinate_type: Optional[CoordinateType] = None,
    tiling: Optional[parallel.Tiling] = None,
    label: Optional[str] = None,
    cmap: Union[None, "matplotlib.colors.Colormap", str] = None,
) -> Optional["matplotlib.figure.Figure"]:
    """Plot the domain, optionally including bathymetric depth, mesh and
    river positions.

    Args:
        field: 2D field to plot. It must be defined on the supergrid.
        fig: :class:`matplotlib.figure.Figure` instance to plot to.
            If not provided, a new figure is created.
        show_bathymetry: show bathymetry as color map
        show_mask: show mask as color map (this disables ``show_bathymetry``)
        show_mesh: show model grid
        show_rivers: show rivers with position and name
        show_open_boundaries: highlight the cells along open boundaries
        show_subdomains: show subdomain decomposition
        editable: allow interactive selection of rectangular regions in the domain
            plot that are subsequently masked out
        coordinate_type: coordinates to use for x and y axes (x/y, lon/lat, or i/j)
        tiling: subdomain decomposition. This must be provided if `show_subdomains`
            is set
        label: label for the colorbar of the plotted field
        cmap: colormap to use (replaced by a bathymetry colormap when
            bathymetry is shown because no ``field`` was given)

    Returns:
        :class:`matplotlib.figure.Figure` instance for processes with rank 0,
        otherwise ``None``
    """
    if self.comm.rank != 0:
        return

    import copy
    import matplotlib.pyplot
    import matplotlib.collections
    import matplotlib.widgets

    if fig is None:
        fig, ax = matplotlib.pyplot.subplots(
            figsize=(0.15 * self.nx, 0.15 * self.ny)
        )
    else:
        ax = fig.gca()

    if coordinate_type is None:
        coordinate_type = self.coordinate_type
    if coordinate_type == CoordinateType.LONLAT:
        x, y = (self._lon, self._lat)
        xlabel, ylabel = "longitude (°East)", "latitude (°North)"
    elif coordinate_type == CoordinateType.XY:
        x, y = (self._x, self._y)
        xlabel, ylabel = "x (m)", "y (m)"
    else:
        # Supergrid index coordinates: T point i at i, interfaces at i +/- 0.5
        x = -0.5 + 0.5 * np.arange(1 + self.nx * 2)
        y = -0.5 + 0.5 * np.arange(1 + self.ny * 2)
        x, y = np.broadcast_arrays(x, y[:, np.newaxis])
        xlabel, ylabel = "cell index", "cell index"

    if show_mask or show_rivers:
        mask = self.get_final_mask()

    if field is None:
        if show_mask:
            field = mask
            label = "mask value"
        elif show_bathymetry:
            import cmocean

            # Copy the colormap so that set_bad does not modify cmocean's
            # shared instance for every other user in this process
            cmap = copy.copy(cmocean.cm.deep)
            cmap.set_bad("gray")
            field = np.ma.array(self._H, mask=self._mask == 0)
            label = "undisturbed water depth (m)"
    c = ax.pcolormesh(
        x,
        y,
        field,
        alpha=0.5 if show_mesh else 1,
        shading="auto",
        cmap=cmap,
    )
    cb = fig.colorbar(c)
    if label is not None:
        cb.set_label(label)

    if show_rivers and self.rivers:
        self.rivers.map_to_grid(core.Locator(**self._tcoords(mask=mask)))
        for name, river in self.rivers.items():
            i_sup, j_sup = 1 + river.i * 2, 1 + river.j * 2
            river_x, river_y = x[j_sup, i_sup], y[j_sup, i_sup]
            ax.plot([river_x], [river_y], ".r")
            ax.text(river_x, river_y, name, color="r")

    if show_open_boundaries:
        # Outline each boundary using the corner (X) points around its cells
        x_X, y_X = x[::2, ::2], y[::2, ::2]
        for b in self.open_boundaries:
            if b.side in (open_boundaries.Side.WEST, open_boundaries.Side.EAST):
                imin, imax = b.l, b.l + 1
                jmin = min(b.mstart, b.mstop - b.mstep)
                jmax = max(b.mstart, b.mstop - b.mstep) + 1
            else:
                jmin, jmax = b.l, b.l + 1
                imin = min(b.mstart, b.mstop - b.mstep)
                imax = max(b.mstart, b.mstop - b.mstep) + 1
            x_b = x_X[jmin : jmax + 1, imin : imax + 1]
            y_b = y_X[jmin : jmax + 1, imin : imax + 1]
            if x_b.shape[1] == 2:
                x_b, y_b = x_b.T, y_b.T
            # Closed polygon: one edge forward, the opposite edge backward
            x_b = np.hstack((x_b[0, :], x_b[1, ::-1]))
            y_b = np.hstack((y_b[0, :], y_b[1, ::-1]))
            ax.fill(x_b, y_b, alpha=0.3, fc="r", ec="None")
            ax.fill(x_b, y_b, fc="None", ec="r")

    def plot_mesh(ax, x, y, **kwargs):
        # Draw grid lines along both array directions
        segs1 = np.stack((x, y), axis=2)
        segs2 = segs1.transpose(1, 0, 2)
        ax.add_collection(matplotlib.collections.LineCollection(segs1, **kwargs))
        ax.add_collection(matplotlib.collections.LineCollection(segs2, **kwargs))

    if show_mesh:
        plot_mesh(
            ax, x[::2, ::2], y[::2, ::2], colors="w", linestyle="-", linewidth=0.5
        )
        plot_mesh(
            ax, x[::2, ::2], y[::2, ::2], colors="k", linestyle="-", linewidth=0.3
        )

    if show_subdomains:
        assert tiling is not None
        for icol in range(tiling.ncol + 1):
            i = 2 * (tiling.xoffset_global + icol * tiling.nx_sub)
            if i >= 0 and i < x.shape[-1]:
                ax.plot(x[:, i], y[:, i], "-k", linewidth=2.0)
                ax.plot(x[:, i], y[:, i], "--w", linewidth=2.0)
        for irow in range(tiling.nrow + 1):
            j = 2 * (tiling.yoffset_global + irow * tiling.ny_sub)
            if j >= 0 and j < x.shape[-2]:
                ax.plot(x[j, :], y[j, :], "-k", linewidth=2.0)
                ax.plot(x[j, :], y[j, :], "--w", linewidth=2.0)
        # Rank of each subdomain (cross for subdomains that are not used, e.g. land)
        for irow in range(tiling.nrow):
            for icol in range(tiling.ncol):
                imin = 2 * (tiling.xoffset_global + icol * tiling.nx_sub)
                imax = imin + 2 * tiling.nx_sub
                jmin = 2 * (tiling.yoffset_global + irow * tiling.ny_sub)
                jmax = jmin + 2 * tiling.ny_sub
                imin, imax = max(0, imin), min(x.shape[-1] - 1, imax)
                jmin, jmax = max(0, jmin), min(x.shape[-2] - 1, jmax)
                imid = (imin + imax) // 2
                jmid = (jmin + jmax) // 2
                rank = tiling.map[irow, icol]
                if rank >= 0:
                    ax.text(
                        x[jmid, imid],
                        y[jmid, imid],
                        str(rank),
                        horizontalalignment="center",
                        verticalalignment="center",
                    )
                else:
                    ax.plot(
                        [x[jmin, imin], x[jmax, imax]],
                        [y[jmin, imin], y[jmax, imax]],
                        "-k",
                    )
                    ax.plot(
                        [x[jmin, imax], x[jmax, imin]],
                        [y[jmin, imax], y[jmax, imin]],
                        "-k",
                    )

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    if coordinate_type != CoordinateType.LONLAT:
        ax.axis("equal")
    xmin, xmax = np.nanmin(x), np.nanmax(x)
    ymin, ymax = np.nanmin(y), np.nanmax(y)
    xmargin = 0.05 * (xmax - xmin)
    ymargin = 0.05 * (ymax - ymin)
    ax.set_xlim(xmin - xmargin, xmax + xmargin)
    ax.set_ylim(ymin - ymargin, ymax + ymargin)

    def on_select(eclick, erelease):
        # Mask the selected rectangle. Its corners are in the coordinate
        # system of this plot, which need not be the domain's default, so
        # that coordinate type is passed on explicitly.
        xmin, xmax = (
            min(eclick.xdata, erelease.xdata),
            max(eclick.xdata, erelease.xdata),
        )
        ymin, ymax = (
            min(eclick.ydata, erelease.ydata),
            max(eclick.ydata, erelease.ydata),
        )
        self.mask_rectangle(
            xmin, xmax, ymin, ymax, coordinate_type=coordinate_type
        )
        c.set_array(np.ma.array(self._H, mask=self._mask == 0).ravel())
        fig.canvas.draw()

    if editable:
        # Stored on the domain to keep the selector widget alive
        self.sel = matplotlib.widgets.RectangleSelector(
            ax, on_select, useblit=True, button=[1], interactive=False
        )
    return fig
[docs]
@apply_only_on_root
def to_xarray(self) -> xr.Dataset:
    """Convert domain to :class:`xarray.Dataset`.

    The result can be loaded into a domain with :func:`from_xarray`.
    Fields that are None are omitted. Periodicity is recorded as
    ``periodic_x``/``periodic_y`` attributes (value 1) only when set.

    Returns:
        xarray Dataset with domain data
    """

    def gather(*names, **extra) -> dict[str, xr.DataArray]:
        # Wrap every available (non-None) supergrid field as a (y, x) DataArray
        available = ((name, getattr(self, name)) for name in names)
        return {
            name: xr.DataArray(values, dims=("y", "x"), **extra)
            for name, values in available
            if values is not None
        }

    coords = {
        **gather("x", "lon", attrs={"axis": "X"}),
        **gather("y", "lat", attrs={"axis": "Y"}),
    }
    data_vars = gather("mask", "H", "z0", "f")

    attrs: dict[str, Any] = {"coordinate_type": self.coordinate_type.name}
    for flag in ("periodic_x", "periodic_y"):
        if getattr(self, flag):
            attrs[flag] = 1
    return xr.Dataset(data_vars=data_vars, coords=coords, attrs=attrs)