Source code for pygetm.util.interpolate

from typing import Optional
import numpy as np
import numpy.typing as npt

from pygetm.constants import EdgeTreatment


[docs] class Linear2DGridInterpolator: def __init__( self, x: npt.ArrayLike, y: npt.ArrayLike, xp: npt.ArrayLike, yp: npt.ArrayLike, preslice=(Ellipsis,), ndim_trailing: int = 0, mask: Optional[npt.ArrayLike] = None, ): assert ndim_trailing >= 0 xp = np.array(xp, dtype=float) yp = np.array(yp, dtype=float) x = np.array(x, dtype=float) y = np.array(y, dtype=float) assert xp.ndim == 1, f"source x coordinate must be 1D but has shape {xp.shape}" assert yp.ndim == 1, f"source y coordinate must be 1D but has shape {yp.shape}" self.nxp, self.nyp = xp.size, yp.size assert ( self.nxp > 1 ), f"source x coordinate must have length > 1, but has length {self.nxp}" assert ( self.nyp > 1 ), f"source y coordinate must have length > 1, but has length {self.nyp}" x, y = np.broadcast_arrays(x, y) dxp = xp[1:] - xp[:-1] dyp = yp[1:] - yp[:-1] assert (dxp > 0).all() or ( dxp < 0 ).all(), "source x coordinate must be monotonically increasing or decreasing" assert (dyp > 0).all() or ( dyp < 0 ).all(), "source y coordinate must be monotonically increasing or decreasing" if dxp[0] < 0: # reversed source x xp = xp[::-1] if dyp[0] < 0: # reversed source y yp = yp[::-1] assert (x >= xp[0]).all() and (x <= xp[-1]).all(), ( f"One or more target x coordinates ({x.min()} - {x.max()})" f" fall outside of source range ({xp[0]} - {xp[-1]})" ) assert (y >= yp[0]).all() and (y <= yp[-1]).all(), ( f"One or more target y coordinates ({y.min()} - {y.max()})" f" fall outside of source range ({yp[0]} - {yp[-1]})" ) ix_right = np.minimum(xp.searchsorted(x, side="right"), xp.size - 1) ix_left = ix_right - 1 iy_right = np.minimum(yp.searchsorted(y, side="right"), yp.size - 1) iy_left = iy_right - 1 wx_left = (xp[ix_right] - x) / (xp[ix_right] - xp[ix_left]) wy_left = (yp[iy_right] - y) / (yp[iy_right] - yp[iy_left]) self.w11 = wx_left * wy_left self.w12 = wx_left * (1.0 - wy_left) self.w21 = (1.0 - wx_left) * wy_left self.w22 = (1.0 - wx_left) * (1.0 - wy_left) assert np.allclose(self.w11 + self.w12 + self.w21 + self.w22, 1.0, 1e-14) # Ensure weights are broadcastable to shape of data array wshape = x.shape + (1,) * ndim_trailing self.w11 = np.reshape(self.w11, wshape) self.w12 = np.reshape(self.w12, wshape) self.w21 = np.reshape(self.w21, wshape) self.w22 = np.reshape(self.w22, wshape) # If we reversed source coordinates, compute the correct indices if dxp[0] < 0: ix_left, ix_right = xp.size - ix_left - 1, xp.size - ix_right - 1 if dyp[0] < 0: iy_left, iy_right = yp.size - iy_left - 1, yp.size - iy_right - 1 # Store slices into data array self.slice11 = (Ellipsis, ix_left, iy_left) + (slice(None),) * ndim_trailing self.slice12 = (Ellipsis, ix_left, iy_right) + (slice(None),) * ndim_trailing self.slice21 = (Ellipsis, ix_right, iy_left) + (slice(None),) * ndim_trailing self.slice22 = (Ellipsis, ix_right, iy_right) + (slice(None),) * ndim_trailing if mask is not None: # Force weights to zero for masked points and renormalize weights # so their sum is 1 mask = np.asarray(mask) assert ( mask.ndim >= 2 + ndim_trailing ), f"Mask should have at least {2 + ndim_trailing} dimensions" ndim_no_trail = mask.ndim - ndim_trailing mask_xy_shape = mask.shape[ndim_no_trail - 2 : ndim_no_trail] xy_shape = (self.nxp, self.nyp) assert ( mask_xy_shape == xy_shape ), f"Bad mask shape for x, y: {mask_xy_shape} while expected {xy_shape}" target_shape = ( mask.shape[: ndim_no_trail - 2] + x.shape + mask.shape[ndim_no_trail:] ) use_nn = ( mask[self.slice11] | mask[self.slice12] | mask[self.slice21] | mask[self.slice22] ) if use_nn.any(): import scipy.spatial # Build kd-tree for nearest-neighbor lookup, using only unmasked points if mask.all(): raise ValueError("All source points are masked, cannot interpolate") source_coords = np.indices(mask.shape)[:, ~mask] tree = scipy.spatial.KDTree(source_coords.T) ix = ix_right - wx_left iy = iy_right - wy_left sl = (Ellipsis,) + (np.newaxis,) * ndim_trailing target_indices = np.indices(target_shape) target_coords = [] for i in range(len(target_shape) - 2 - ndim_trailing): target_coords.append(target_indices[i]) target_coords.append(np.broadcast_to(ix[sl], target_shape)) target_coords.append(np.broadcast_to(iy[sl], target_shape)) for i in range(ndim_trailing): target_coords.append(target_indices[-ndim_trailing + i]) target_coords = np.array(target_coords, dtype=float) _, inearest = tree.query(target_coords[:, use_nn].T, workers=-1) for k in ("slice11", "slice12", "slice21", "slice22"): s = getattr(self, k) # Broadcast original slices slc_indices = [] for i in range(len(target_shape) - 2 - ndim_trailing): slc_indices.append(target_indices[i]) ix = np.broadcast_to(s[-2 - ndim_trailing][sl], target_shape) iy = np.broadcast_to(s[-1 - ndim_trailing][sl], target_shape) slc_indices.extend([ix, iy]) for i in range(ndim_trailing): slc_indices.append(target_indices[-ndim_trailing + i]) # Substitute nearest for i in range(len(slc_indices)): slc_indices[i] = np.array(slc_indices[i]) slc_indices[i][use_nn] = source_coords[i, inearest] setattr(self, k, (Ellipsis,) + tuple(slc_indices)) assert not mask[self.slice11].any() assert not mask[self.slice12].any() assert not mask[self.slice21].any() assert not mask[self.slice22].any() self.idim1 = -2 - ndim_trailing self.idim2 = -1 - ndim_trailing def __call__(self, fp: np.ndarray) -> np.ndarray: assert fp.shape[self.idim1] == self.nxp assert fp.shape[self.idim2] == self.nyp result = self.w11 * fp[self.slice11] result += self.w12 * fp[self.slice12] result += self.w21 * fp[self.slice21] result += self.w22 * fp[self.slice22] return result
[docs] class LinearVectorized1D: """One-dimensional linear interpolation along a given axis, for nD source coordinates and 1D target coordinates. For instance, to go from 3D depths to a z grid (1D) """ def __init__( self, x: npt.ArrayLike, xp: npt.ArrayLike, axis: int = 0, fill_value: float = np.nan, mask: Optional[npt.ArrayLike] = None, edges: EdgeTreatment = EdgeTreatment.MISSING, ): """Initialize the interpolator. It can subsequently be called multiple times with different source values but the same coordinates. Args: x: Target coordinate values (1D) xp: Source coordinate values (nD) axis: Axis along which to interpolate fill_value: Value to use for out-of-bounds target coordinates (if `edges` is `MISSING`) and for locations where there are no valid source points along the interpolated dimension (if `mask` is given) mask: Optional boolean array of the same shape as `xp` indicating masked (invalid) source points. It is currently only used to detect locations where there are no valid source points along the interpolated dimension. There, `fill_value` will be used independent of the `edges` setting. edges: How to treat target coordinates that fall outside the range of valid source coordinates. If `MISSING`, these will be assigned `fill_value`. If `CLAMP`, these will be assigned the nearest valid source value. """ x = np.asarray(x, dtype=float) xp = np.asarray(xp, dtype=float) assert x.ndim == 1 assert axis >= -xp.ndim and axis < xp.ndim xp_slice = [slice(None)] * xp.ndim final_shape = list(xp.shape) final_shape[axis] = x.size ix_left = np.empty(final_shape, dtype=np.intp) if mask is not None: masked_xp = np.broadcast_to(np.asarray(mask, dtype=bool), xp.shape) any_valid_xp = ~masked_xp.all(axis=axis) # start_skip = np.logical_and.accumulate(masked_xp, axis=axis).sum(axis=axis) # stop_skip = np.logical_and.accumulate( # np.flip(masked_xp, axis=axis), axis=axis # ).sum(axis=axis) # assert (start_skip + stop_skip == masked_xp.sum(axis=axis)).all( # where=any_valid_xp # ) for ix, cur_x in enumerate(x): xp_slice[axis] = ix ix_left_cur = (xp < cur_x).sum(axis=axis) - 1 ix_left_cur += (xp == cur_x).any(axis=axis) ix_left[tuple(xp_slice)] = ix_left_cur if (np.diff(xp, axis=axis) >= 0.0).all(): # Source coordinate is monotonically INcreasing ix_right = np.minimum(ix_left + 1, xp.shape[axis] - 1) ix_left = np.maximum(ix_left, 0) else: # Source coordinate is monotonically DEcreasing ix_right = xp.shape[axis] - 1 - ix_left ix_left = np.maximum(ix_right - 1, 0) ix_right = np.minimum(ix_right, xp.shape[axis] - 1) valid = ix_left != ix_right xp_right = np.take_along_axis(xp, ix_right, axis=axis) xp_left = np.take_along_axis(xp, ix_left, axis=axis) dxp = xp_right - xp_left x_shape = [1] * xp.ndim x_shape[axis] = x.size x_bc = x.reshape(x_shape) w_left = np.ones(xp_right.shape) np.divide(xp_right - x_bc, dxp, out=w_left, where=valid) self.ix_left = ix_left self.ix_right = ix_right self.w_left = w_left self.axis = axis if edges == EdgeTreatment.MISSING: self.valid = valid elif mask is not None: self.valid = any_valid_xp else: self.valid = True self.fill_value = fill_value assert edges in (EdgeTreatment.MISSING, EdgeTreatment.CLAMP) self.edges = edges def __call__(self, yp) -> np.ndarray: yp = np.asarray(yp) yp_left = np.take_along_axis(yp, self.ix_left, axis=self.axis) yp_right = np.take_along_axis(yp, self.ix_right, axis=self.axis) y = self.w_left * yp_left + (1.0 - self.w_left) * yp_right y = np.where(self.valid, y, self.fill_value) return y
[docs] def interp_1d( x: npt.ArrayLike, xp: npt.ArrayLike, fp: npt.ArrayLike, axis: int = 0 ) -> np.ndarray: """One-dimensional linear interpolation along a given axis for 1D source coordinates and nD target coordinates. For instance, to interpolate from 3D values defined at z coordinates (1D) to 3D values at depth coordinates that vary in the horizontal. Source values may contain NaNs or masked values at the beginning or end of the interpolated dimension; these will be skipped during interpolation. Where target coordinates fall outside the range of valid source coordinates, the corresponding output values will be equal to the nearest valid source value. Args: x: Target coordinate values (nD, matching shape of fp except at `axis`) xp: Source coordinate values (1D) fp: Source values to interpolate (nD, matching the size of xp at `axis`) axis: Axis along which to interpolate Returns: Interpolated values at target coordinates """ x = np.asarray(x, dtype=float) xp = np.asarray(xp, dtype=float) fp = np.ma.filled(fp, np.nan) if fp.ndim != x.ndim: raise ValueError( f"Number of dimensions {fp.ndim} of source values" f" does not match {x.ndim} of target coordinate." ) if xp.ndim != 1: raise ValueError(f"Source coordinate must be 1D but has shape {xp.shape}.") if fp.shape[:axis] != x.shape[:axis] or fp.shape[axis + 1 :] != x.shape[axis + 1 :]: raise ValueError( f"Shapes of source values {fp.shape} and target coordinate {x.shape}" f" should match everywhere except at the interpolated dimension ({axis})" ) if fp.shape[axis] != xp.shape[0]: raise ValueError( f"Size of source values {fp.shape[axis]} and source coordinate {xp.shape[0]}" f" must match at the interpolated dimension ({axis})" ) dxp = xp[1:] - xp[:-1] if not ((dxp > 0).all() or (dxp < 0).all()): raise ValueError("xp must be monotonically increasing or decreasing") invalid = np.isnan(fp) if invalid.any(): # Source values include NaNs. Identify the inner valid (NaN-free) region. valid = ~invalid s = tuple(np.newaxis if i != axis else slice(None) for i in range(fp.ndim)) ind = np.broadcast_to(np.arange(fp.shape[axis])[s], fp.shape) first = ind.min(axis=axis, where=valid, initial=fp.shape[axis], keepdims=True) last = ind.max(axis=axis, where=valid, initial=0, keepdims=True) first = np.minimum(first, last) # if no valid elements at all, first=last=0 else: first = 0 last = xp.size - 1 xp_reversed = dxp[0] < 0 if xp_reversed: # Monotonically DEcreasing source coordinate. Flip it to make it # monotonically INcreasing for lookup of bounding intervals xp = xp[::-1] dxp = -dxp[::-1] first, last = (xp.size - 1) - last, (xp.size - 1) - first # Look up upper bound of interval around each target coordinate # This will be 0 [invalid!] if first source coordinate < minimum target coordinate # This will be xp.size [invalid!] if last source coordinate >= maximum target # coordinate i_right = xp.searchsorted(x, side="right") # Determine intervals (left and right bounds). # These will be zero-width (i_left == i_right) at the boundaries i_left = i_right - 1 i_left.clip(first, last, out=i_left) i_right.clip(first, last, out=i_right) scale = np.zeros(x.shape, dtype=float) idxp = np.append(1.0 / dxp, 0.0) np.multiply(x - xp[i_left], idxp[i_left], where=i_left != i_right, out=scale) if xp_reversed: # Correct indices into fp for flipped source coordinate i_left, i_right = (xp.size - 1) - i_left, (xp.size - 1) - i_right f_left = np.take_along_axis(fp, i_left, axis=axis) f_right = np.take_along_axis(fp, i_right, axis=axis) return f_left + (f_right - f_left) * scale