Source code for pygetm.util.compare_nc

import sys
from typing import Container, Mapping
import numpy as np
import netCDF4


[docs] def compare( path1: str, path2: str, rotate: bool = False, check_constant: bool = True, name_map: Mapping[str, str] = {}, flip: Container[str] = (), tolerance: float = 0.0, ) -> bool: mismatches = [] def load(slc=(Ellipsis,)): values1 = ncvar1[slc] values2 = ncvar2[slc] ndim = len([d for d in ncvar1.dimensions if d not in ("z", "zi", "time")]) if rotate and ndim >= 2: values2 = np.swapaxes(values2, -1, -2)[..., ::-1, :] if ncvar1.dimensions[-2:] == ("yv", "xv") and ncvar2.dimensions[-2:] == ( "yu", "xu", ): values1 = values1[..., :-1, :] values2 = values2[..., 1:, :] if rotate and name in flip: values2 = -values2 return np.ma.masked_invalid(values1), np.ma.masked_invalid(values2) with netCDF4.Dataset(path1) as nc1, netCDF4.Dataset(path2) as nc2: for name in frozenset(nc1.variables) & frozenset(nc2.variables): ncvar1 = nc1.variables[name] ncvar2 = nc2.variables[name_map.get(name, name)] if "time" in ncvar1.dimensions: # Quick check for very last time point; if values match, # values at all preceding time points are assumed to match too. # Then skip the full values check (expensive!) values1, values2 = load((-1, Ellipsis)) if np.ma.filled(values1 == values2, True).all(): continue elif not check_constant: continue # Values at last time did not match. Now check all time points values1, values2 = load() mismatch = np.ma.filled(values1 != values2, False) maxval = max(np.ma.abs(values1).max(), np.ma.abs(values2).max()) idim_start = 1 if "time" in ncvar1.dimensions else 0 mismatch_count = mismatch.sum(axis=tuple(range(idim_start, mismatch.ndim))) absdiff = np.abs(np.ma.filled(values1 - values2, 0.0)) mismatch_max = absdiff.max(axis=tuple(range(idim_start, mismatch.ndim))) if mismatch_count.sum() > 0: ifirst_mismatch = mismatch_count.nonzero()[0][0] mismatches.append( (name, mismatch_count, ifirst_mismatch, mismatch_max / maxval) ) success = True for name, mismatch_count, ifirst, mismatch_max in sorted( mismatches, key=lambda x: x[2] ): msg = ( "%s: mismatches at %i of %i times, first at time=%i with %i mismatches (max = %.3e, final max = %.3e)" % ( name, (mismatch_count > 0).sum(), mismatch_count.size, ifirst, mismatch_count.flat[ifirst], mismatch_max.flat[ifirst], mismatch_max.flat[-1], ) ) if mismatch_max.flat[-1] < tolerance: msg += " - WITHIN TOLERANCE (%.3e)" % (tolerance,) else: success = False print(msg) return success
[docs] def compare_at_time( path1: str, path2: str, itime1: int, itime2: int, verbose: bool = False ) -> bool: success = True with netCDF4.Dataset(path1) as nc1, netCDF4.Dataset(path2) as nc2: for name in frozenset(nc1.variables) & frozenset(nc2.variables): ncvar1, ncvar2 = nc1.variables[name], nc2.variables[name] if "time" not in ncvar1.dimensions: continue values1 = np.ma.masked_invalid(ncvar1[itime1, ...]) values2 = np.ma.masked_invalid(ncvar2[itime2, ...]) mismatch = np.ma.filled(values1 != values2, False) if mismatch.any(): success = False iworst = np.ma.filled(np.abs(values1 - values2), 0.0).ravel().argmax() print( "%s: %i mismatches (out of %i)" " - maximum absolute difference %s vs %s" % ( name, mismatch.sum(), mismatch.size, values1.flat[iworst], values2.flat[iworst], ) ) elif verbose: print("%s: identical" % name) return success
[docs] def compare_command(): import argparse parser = argparse.ArgumentParser() parser.add_argument("file1") parser.add_argument("file2") parser.add_argument("name_map", nargs="*") # , action="append", default=[]) parser.add_argument("--itime1", type=int) parser.add_argument("--itime2", type=int) parser.add_argument("--rotate", action="store_true") parser.add_argument("-v", "--verbose", action="store_true") args = parser.parse_intermixed_args() if (args.itime1 is None and args.itime2 is not None) or ( args.itime1 is not None and args.itime2 is None ): print( "ERROR: --itime1 and --itime2 must either be absent, or both be provided." ) sys.exit(2) name_map = {} flip = set() for s in args.name_map: try: v1, v2 = s.split("=") except ValueError: raise Exception("Invalid assignment: %s" % s) if ".flip" in v2: flip.add(v1) name_map[v1] = v2.split(".", 1)[0] if args.itime1 is None: success = compare( args.file1, args.file2, args.rotate, name_map=name_map, flip=flip ) else: success = compare_at_time( args.file1, args.file2, args.itime1, args.itime2, args.verbose ) if not success: sys.exit(1)