# Source code for pygetm.output.netcdf

from typing import Optional, Mapping, Union
import os
import logging
import datetime
import sys
from pathlib import Path
import contextlib

import cftime
import netCDF4

from . import File
from . import operators
import pygetm.core
import pygetm._pygetm


def _create_file(path: Union[os.PathLike, str], **kwargs) -> netCDF4.Dataset:
    """Open a NetCDF file for writing and stamp it with provenance attributes.

    Args:
        path: location of the file to open in write mode (``"w"``)
        **kwargs: additional keyword arguments passed to
            :class:`netCDF4.Dataset`

    Returns:
        The open dataset, with global attributes ``history`` (current local
        time followed by the command line) and ``source`` (the pygetm version)
        already set.
    """
    dataset = netCDF4.Dataset(path, "w", **kwargs)

    # history: "<timestamp> <command line as found in sys.argv>"
    timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    command_line = " ".join(sys.argv)
    dataset.history = f"{timestamp} {command_line}"

    # source: name and version of the model that produced the file
    version = pygetm._pygetm.get_version()
    dataset.source = f"pygetm {version}"
    return dataset


def _create_dimensions(
    nc: netCDF4.Dataset, fields: Mapping[str, operators.Base]
) -> tuple[bool, bool]:
    needs_time = False
    needs_time_bounds = False
    for output_name, field in fields.items():
        for dim, length in zip(field.dims, field.shape):
            if dim not in nc.dimensions:
                assert length > 0
                nc.createDimension(dim, length)
            elif length != nc.dimensions[dim].size:
                raise Exception(
                    f"Error adding {output_name} with shape {field.shape}:"
                    f" existing dimension {dim} has incompatible length"
                    f" {nc.dimensions[dim].size} (needs {length})"
                )
        needs_time |= bool(field.time_varying)
        needs_time_bounds |= "time: mean" in field.attrs.get("cell_methods", "")

    if needs_time_bounds:
        nc.createDimension("nv", 2)
    if needs_time:
        nc.createDimension("time", None)
    return needs_time, needs_time_bounds


def _add_time_coordinate(
    nc: netCDF4.Dataset, attrs: dict[str, str]
) -> netCDF4.Variable:
    nctime = nc.createVariable("time", float, ("time",))
    for att, value in attrs.items():
        setattr(nctime, att, value)
    return nctime


def _add_time_bounds(nc: netCDF4.Dataset, nctime: netCDF4.Variable) -> netCDF4.Variable:
    nctime_ave = nc.createVariable("time_av", float, ("time",))
    nctime_ave.coordinates = nctime_ave.name
    nctime_bnds = nc.createVariable("time_bnds", float, ("time", "nv"))
    nctime_ave.bounds = nctime_bnds.name
    for att in ("units", "calendar"):
        if hasattr(nctime, att):
            value = getattr(nctime, att)
            setattr(nctime_ave, att, value)
            setattr(nctime_bnds, att, value)
    return nctime_ave, nctime_bnds


def _add_variable(
    nc: netCDF4.Dataset, name: str, field: operators.Base, **kwargs
) -> netCDF4.Variable:
    dims = field.dims
    if field.time_varying:
        dims = ("time",) + dims
    ncvar = nc.createVariable(
        name, field.dtype, dims, fill_value=field.fill_value, **kwargs
    )

    # Only some NetCDF engines support set_auto_maskandscale
    # (netCDF4 does, h5netcdf.legacyapi does not)
    with contextlib.suppress(AttributeError):
        ncvar.set_auto_maskandscale(False)

    # Variable attributes
    ncvar.expression = field.expression
    for att, value in field.attrs.items():
        setattr(ncvar, att, value)
    coords = field.coordinates
    if "time: mean" in field.attrs.get("cell_methods", ""):
        coords = coords + ["time_av"]
    if coords:
        ncvar.coordinates = " ".join(coords)

    return ncvar


class NetCDFFile(File):
    """Output file in NetCDF format, written through :mod:`netCDF4`."""

    def __init__(
        self,
        available_fields: Mapping[str, pygetm.core.Array],
        logger: logging.Logger,
        path: Union[os.PathLike[str], str],
        rank: int,
        sync_interval: Optional[int] = 1,
        time_reference: Optional[cftime.datetime] = None,
        format: str = "NETCDF4",
        compression: Optional[str] = None,
        **kwargs,
    ):
        """Create a NetCDF file for output

        Args:
            available_fields: collection of model fields that may be added
            logger: target for log messages
            path: file to create. If it exists it will be clobbered.
            rank: rank of this subdomain. If all output is gathered and
                written to a single file, this determines whether we are the
                root (rank 0). Otherwise the rank will be used as suffix for
                the subdomain-specific files.
            sync_interval: number of saved time steps between calls to NetCDF
                sync, which forces all output to be written to disk. If set to
                None, synchronization will happen only when the file is closed
                at the end of a simulation.
            time_reference: time reference (epoch) to use as offset for the
                time coordinate.
            format: underlying file format (see :class:`netCDF4.Dataset`
                documentation)
            compression: compression algorithm to apply to all variables
                (see :meth:`netCDF4.Dataset.createVariable` documentation)
            **kwargs: additional keyword arguments passed to
                :class:`pygetm.output.File`
        """
        super().__init__(available_fields, logger, **kwargs)
        path = Path(path)
        if self.sub:
            # One file per subdomain: append the zero-padded rank to the name
            path = path.with_stem(f"{path.stem}_{rank:05}")
        self.path = path
        self.nc: Optional[netCDF4.Dataset] = None
        self.itime = 0
        self.is_root = rank == 0
        self.time_offset = 0.0
        self.time_reference = time_reference
        self.sync_interval = sync_interval
        self.format = format
        self.compression = compression
        self._field2nc: dict[operators.Base, netCDF4.Variable] = {}
        self._varying_fields: list[operators.Base] = []

        # Time-related variables; they stay None unless start_now creates them
        self.nctime: Optional[netCDF4.Variable] = None
        self.nctime_ave: Optional[netCDF4.Variable] = None
        self.nctime_bnds: Optional[netCDF4.Variable] = None

        # Time coordinate at the previous save (start of the current interval)
        self.previous_time_coord = 0.0

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}('{self.path}')"

    def start_now(
        self,
        seconds_passed: float,
        time: Optional[cftime.datetime],
        default_time_reference: Optional[cftime.datetime],
    ) -> bool:
        """Create the file with its dimensions and variables; save static fields.

        The file is created only if we are the root or write per-subdomain
        output, and only if there are non-empty fields to write.

        Args:
            seconds_passed: number of seconds passed; added to the time offset
                to obtain the value of the time coordinate
            time: current time, passed on to ``get_cf_time_attrs``
            default_time_reference: time reference to use if none was
                specified upon creation of this object

        Returns:
            True if there are time-varying fields, False otherwise.
        """
        if self.is_root or self.sub:
            included_fields = self.select_nonempty_fields()
            if included_fields:
                self.nc = _create_file(self.path, format=self.format)
                has_time, has_time_bounds = _create_dimensions(self.nc, included_fields)
                if has_time:
                    time_reference = self.time_reference or default_time_reference
                    attrs, self.time_offset = self.get_cf_time_attrs(
                        time, seconds_passed, time_reference
                    )
                    self.nctime = _add_time_coordinate(self.nc, attrs)
                    if has_time_bounds:
                        self.nctime_ave, self.nctime_bnds = _add_time_bounds(
                            self.nc, self.nctime
                        )
                        # Start of the first averaging interval
                        self.previous_time_coord = self.time_offset + seconds_passed
                for output_name, field in included_fields.items():
                    self._field2nc[field] = _add_variable(
                        self.nc, output_name, field, compression=self.compression
                    )

        # Fields without a variable in a file opened here get None as target
        for field in self.fields.values():
            if field.time_varying:
                # Store field for update at each time step
                self._varying_fields.append(field)
            else:
                # Write static field now
                field.get(self._field2nc.get(field))
        return len(self._varying_fields) > 0

    def save_now(self, seconds_passed: float, time: Optional[cftime.datetime]) -> None:
        """Append the current state of all time-varying fields to the file.

        Args:
            seconds_passed: number of seconds passed; added to the time offset
                to obtain the value of the time coordinate
            time: current time (not used here)
        """
        # Update time coordinate(s), if used
        if self.nctime is not None:
            time_coord = self.time_offset + seconds_passed
            self.nctime[self.itime] = time_coord
            if self.nctime_bnds is not None:
                # Averaging interval: from the previous save to now;
                # time_av holds its midpoint
                self.nctime_ave[self.itime] = 0.5 * (
                    self.previous_time_coord + time_coord
                )
                self.nctime_bnds[self.itime, :] = [self.previous_time_coord, time_coord]
                self.previous_time_coord = time_coord

        # Update time-varying fields
        for field in self._varying_fields:
            field.get(self._field2nc.get(field), slice_spec=(self.itime,))

        # Increment time index and sync to disk if needed
        self.itime += 1
        if (
            self.nc is not None
            and self.sync_interval is not None
            and self.itime % self.sync_interval == 0
        ):
            self.nc.sync()

    def close_now(self, seconds_passed: float, time: Optional[cftime.datetime]) -> None:
        """Close the file, if one was opened.

        Args:
            seconds_passed: number of seconds passed (not used here)
            time: current time (not used here)
        """
        if self.nc is not None:
            self.nc.close()