From cacc9a967b6fa2eb738134516da02481550c515a Mon Sep 17 00:00:00 2001
From: s-scherrer
Date: Tue, 30 Aug 2022 16:55:31 +0200
Subject: [PATCH] Smaller fixes and better tests, merging of private fork (#12)

---
 CHANGELOG.rst                                 | 12 +++-
 src/qa4sm_preprocessing/reading/__init__.py   |  4 +-
 src/qa4sm_preprocessing/reading/base.py       |  6 +-
 src/qa4sm_preprocessing/reading/cli.py        | 49 ++++-----------
 src/qa4sm_preprocessing/reading/image.py      | 26 ++++----
 src/qa4sm_preprocessing/reading/imagebase.py  | 19 +++---
 src/qa4sm_preprocessing/reading/stack.py      |  4 +-
 src/qa4sm_preprocessing/reading/timeseries.py | 38 ++++++++----
 src/qa4sm_preprocessing/reading/transpose.py  | 18 ++++--
 src/qa4sm_preprocessing/reading/write.py      |  4 +-
 tests/conftest.py                             |  8 +--
 tests/test_reading/test_cli.py                |  8 +--
 tests/test_reading/test_image_synthetic.py    |  6 +-
 tests/test_reading/test_readers_realdata.py   | 52 +++++++++-------
 tests/test_reading/test_timeseries.py         | 60 +++++++++++++++----
 tests/test_reading/test_transpose.py          | 52 ++++++++++++++++
 tests/test_reading/test_utils.py              | 58 ++++++++++++++++++
 17 files changed, 298 insertions(+), 126 deletions(-)
 create mode 100644 tests/test_reading/test_transpose.py
 create mode 100644 tests/test_reading/test_utils.py

diff --git a/CHANGELOG.rst b/CHANGELOG.rst
index 6489c64..e8a119a 100644
--- a/CHANGELOG.rst
+++ b/CHANGELOG.rst
@@ -5,7 +5,17 @@ Changelog
 Unreleased
 ==========
 
-
+v0.1.2
+======
+
+- more features for ``DirectoryImageReader``
+  - handling of "image" files with multiple time steps
+  - better documentation for subclassing
+- renaming of other readers
+  - ``XarrayImageStackReader`` to ``StackImageReader``
+  - ``XarrayTSReader`` to ``StackTs``
+- ``repurpose`` function for image readers
+- improved test coverage
 
 v0.1.0
 ======
diff --git a/src/qa4sm_preprocessing/reading/__init__.py b/src/qa4sm_preprocessing/reading/__init__.py
index afc52f8..1c5f9c2 100644
--- a/src/qa4sm_preprocessing/reading/__init__.py
+++ b/src/qa4sm_preprocessing/reading/__init__.py
@@ -1,4 +1,4 @@
 from .image import DirectoryImageReader
-from .stack import XarrayImageStackReader
-from .timeseries import XarrayTSReader, GriddedNcOrthoMultiTs
+from .stack import StackImageReader
+from .timeseries import StackTs, GriddedNcOrthoMultiTs
 from .transpose import write_transposed_dataset
diff --git a/src/qa4sm_preprocessing/reading/base.py b/src/qa4sm_preprocessing/reading/base.py
index 9d063d2..9045c52 100644
--- a/src/qa4sm_preprocessing/reading/base.py
+++ b/src/qa4sm_preprocessing/reading/base.py
@@ -15,7 +15,7 @@
 from .utils import mkdate
 
 
-class XarrayReaderBase:
+class ReaderBase:
     """
     Base class for readers backed by xarray objects (images, image stacks,
     timeseries).
@@ -239,7 +239,7 @@ def finalize_grid(self, grid):
         num_gpis = len(grid.activegpis)
         logging.debug(f"finalize_grid: Number of active gpis: {num_gpis}")
 
-        if hasattr(self, "cellsize"):
+        if hasattr(self, "cellsize"):  # pragma: no branch
             if self.cellsize is None:
                 # Automatically set a suitable cell size, aiming at cell sizes
                 # of about 30**2 pixels.
@@ -249,7 +249,7 @@ def finalize_grid(self, grid):
             grid = grid.to_cell_grid(cellsize=self.cellsize)
         num_cells = len(grid.get_cells())
         logging.debug(
-            f"_grid_from_xarray: Number of grid cells: {num_cells}"
+            f"finalize_grid: Number of grid cells: {num_cells}"
         )
         return grid
 
diff --git a/src/qa4sm_preprocessing/reading/cli.py b/src/qa4sm_preprocessing/reading/cli.py
index d3f74be..a38fa33 100644
--- a/src/qa4sm_preprocessing/reading/cli.py
+++ b/src/qa4sm_preprocessing/reading/cli.py
@@ -31,7 +31,7 @@
 
 from repurpose.img2ts import Img2Ts
 
-from . import XarrayImageStackReader, DirectoryImageReader
+from . import StackImageReader, DirectoryImageReader
 from .transpose import write_transposed_dataset
 from .utils import mkdate, str2bool
 
@@ -211,40 +211,24 @@ def __init__(self, description):
             default=True,
             help="Whether to use compression or not. Default is true",
         )
+        self.add_argument(
+            "--memory",
+            type=float,
+            default=2,
+            help="The amount of memory to use as buffer in GB",
+        )
 
 
 class RepurposeArgumentParser(ReaderArgumentParser):
     def __init__(self):
         super().__init__("Converts data to time series format.")
         self.prog = "repurpose_images"
-        self.add_argument(
-            "--imgbuffer",
-            type=int,
-            default=365,
-            help=(
-                "How many images to read at once. Bigger "
-                "numbers make the conversion faster but "
-                "consume more memory. Default is 365."
-            ),
-        )
-        self.add_argument(
-            "--cellsize",
-            type=float,
-            default=5.0,
-            help=("Size of single file cells. Default is 5.0."),
-        )
 
 
 class TransposeArgumentParser(ReaderArgumentParser):
     def __init__(self):
         super().__init__("Converts data to transposed netCDF.")
         self.prog = "transpose_images"
-        self.add_argument(
-            "--memory",
-            type=float,
-            default=2,
-            help="The amount of memory to use as buffer in GB",
-        )
         self.add_argument(
             "--n_threads",
             type=int,
@@ -300,7 +284,7 @@ def parse_args(parser, args):
 
     input_path = Path(args.dataset_root)
     if input_path.is_file():
-        reader = XarrayImageStackReader(
+        reader = StackImageReader(
             input_path,
             args.parameter,
             **common_reader_kwargs,
@@ -327,19 +311,12 @@ def repurpose(args):
     outpath = Path(args.output_root)
     outpath.mkdir(exist_ok=True, parents=True)
 
-    reshuffler = Img2Ts(
-        input_dataset=reader,
-        outputpath=args.output_root,
-        startdate=args.start,
-        enddate=args.end,
-        ts_attributes=reader.global_attrs,
-        zlib=args.zlib,
-        imgbuffer=args.imgbuffer,
-        # this is necessary currently due to bug in repurpose
-        cellsize_lat=args.cellsize,
-        cellsize_lon=args.cellsize,
+    reader.repurpose(
+        args.output_root,
+        start=args.start,
+        end=args.end,
+        memory=args.memory
     )
-    reshuffler.calc()
 
 
 def transpose(args):
diff --git a/src/qa4sm_preprocessing/reading/image.py b/src/qa4sm_preprocessing/reading/image.py
index fbf00ac..a834359 100644
--- a/src/qa4sm_preprocessing/reading/image.py
+++ b/src/qa4sm_preprocessing/reading/image.py
@@ -5,7 +5,7 @@
 The ``DirectoryImageReader`` aims to provide an easy to use class to read
 directories of single images, to either create a single image stack file, a
 transposed stack (via ``write_transposed_dataset``), or a cell-based timeseries
-dataset (via the repurpose package).
+dataset (via the ``repurpose`` method).
 
 The advantage over ``xr.open_mfdataset`` is that the reader typically does not
 need to open each dataset to get information on coordinates. Instead, the
@@ -35,7 +35,7 @@
 * ``_metadata_from_dataset``: If metadata cannot be read from the files.
 * ``_tstamps_in_file``: If the timestamp cannot be inferred from the filename
   but via other info specific to the dataset this can be used to avoid having
-  to reading all files only to get the timestamps.
+  to read all files only to get the timestamps.
 * ``_landmask_from_dataset``: If a landmask is required (only if the ``read``
   function is used), and it cannot be read with ``_open_dataset`` and also not
   with other options. Should not be necessary very often.
@@ -43,7 +43,8 @@
   the data as dictionary that maps from variable names to 3d data arrays (numpy
   or dask). If it is hard to read the data as xr.Dataset, so that overriding
   `_open_dataset` is not feasible, this could be overriden instead, but then
-  all the other routines for obtaining metadata also have to be overriden.
+  all the other routines for obtaining grid/metadata/landmask info also have to
+  be overridden.
 
 In the following some examples for subclassing are provided.
 
@@ -87,7 +88,7 @@ def __init__(self, directory):
 It is often necessary to preprocess the data before using it. For example,
 many datasets contain quality flags as additional variable that need to be
 applied to mask out unreliable data. Another example would be a case where one is
-interested in a sum of multiple variables. In this case it is necessary to
+interested in a sum of multiple variables. In these cases it is necessary to
 override the ``_open_dataset`` method.
 
 As an example, consider that we have images containing a field "soil_moisture"
@@ -106,9 +107,9 @@ def __init__(self, directory):
         )
 
     def _open_dataset(self, fname):
-        ds = super()_open_dataset(fname)
+        ds = super()._open_dataset(fname)
         qc = ds["quality_flag"]
-        # We check if the first bit is zero by doing a bitwise and with 1.
+        # We check if the first bit is zero by doing a bitwise AND with 1.
         # The result is 0 if the first bit is zero, and 1 otherwise.
         valid = (qc & 2**0) == 0
         return ds[["soil_moisture"]].where(valid)
@@ -151,8 +152,7 @@ def __init__(self, directory):
             fmt="%Y%m%d",
             time_regex_pattern=r"SMAP_L3_SM_P_([0-9]+)_R.*.h5",
             pattern="**/*.h5",
-            use_tqdm=True,  # for nice progress bar
-            #
+            # there are 2 timestamps in each file
             timestamps=[pd.Timedelta("6H"), pd.Timedelta("18H")]
         )
 
@@ -198,7 +198,7 @@ def _latlon_from_dataset(self, fname):
 don't contain timestamps, it might be necessary to also override
 `_tstamps_in_file`, `_metadata_from_dataset`, or
 `_landmask_from_dataset`. Since these are edge cases, they are not shown in
-detail here, but it works similar to the other examples.
+detail here, but it works similarly to the other examples.
 """
 
 import cftime
@@ -214,12 +214,12 @@ def _latlon_from_dataset(self, fname):
 import warnings
 import xarray as xr
 
-from .imagebase import XarrayImageReaderBase
+from .imagebase import ImageReaderBase
 from .base import LevelSelectionMixin
 from .exceptions import ReaderError
 
 
-class DirectoryImageReader(LevelSelectionMixin, XarrayImageReaderBase):
+class DirectoryImageReader(LevelSelectionMixin, ImageReaderBase):
     r"""
     Image reader for a directory containing netcdf files.
@@ -637,7 +637,7 @@ def _read_block(
         # If we have to average multiple images to a single image, we will
         # read image by image
         block_dict = {varname: [] for varname in self.varnames}
-        if self.use_tqdm:
+        if self.use_tqdm:  # pragma: no branch
             times = tqdm(times)
         for tstamp in times:
             # read all sub-images that have to be averaged later on
@@ -696,7 +696,7 @@ def _read_all_files(self, times, use_tqdm):
         # now we can open each file and extract the timestamps we need
         block_dict = {varname: [] for varname in self.varnames}
         iterator = file_tstamp_map.items()
-        if use_tqdm:
+        if use_tqdm:  # pragma: no branch
             iterator = tqdm(iterator)
         for fname, tstamps in iterator:
             _blockdict = self._read_single_file(fname, tstamps)
diff --git a/src/qa4sm_preprocessing/reading/imagebase.py b/src/qa4sm_preprocessing/reading/imagebase.py
index f0528ea..aa4b984 100644
--- a/src/qa4sm_preprocessing/reading/imagebase.py
+++ b/src/qa4sm_preprocessing/reading/imagebase.py
@@ -3,6 +3,7 @@
 import datetime
 import numpy as np
 from pathlib import Path
+import shutil
 from typing import Union, Iterable, List, Tuple, Sequence, Dict
 import xarray as xr
 
@@ -11,11 +12,11 @@
 from .exceptions import ReaderError
 from .utils import mkdate, nimages_for_memory
-from .base import XarrayReaderBase
+from .base import ReaderBase
 from .timeseries import GriddedNcOrthoMultiTs
 
 
-class XarrayImageReaderBase(XarrayReaderBase):
+class ImageReaderBase(ReaderBase):
     """
     Base class for image readers backed by xarray objects (multiple single
     images or single stack of multiple images).
@@ -38,11 +39,11 @@ def _validate_start_end(
     ) -> Tuple[datetime.datetime]:
         if start is None:
             start = self.timestamps[0]
-        elif isinstance(start, str):
+        elif isinstance(start, str):  # pragma: no cover
             start = mkdate(start)
         if end is None:
             end = self.timestamps[-1]
-        elif isinstance(end, str):
+        elif isinstance(end, str):  # pragma: no cover
             end = mkdate(end)
         return start, end
 
@@ -206,7 +207,7 @@ def read(
         ------
         KeyError
         """
-        if isinstance(timestamp, str):
+        if isinstance(timestamp, str):  # pragma: no cover
             timestamp = mkdate(timestamp)
 
         if timestamp not in self.timestamps:  # pragma: no cover
@@ -268,13 +269,13 @@ def repurpose(
         """
         outpath = Path(outpath)
         start, end = self._validate_start_end(start, end)
-        if outpath.exists() and overwrite:
+        if (outpath / "grid.nc").exists() and overwrite:
             shutil.rmtree(outpath)
-        if not outpath.exists():  # if overwrite=True, it was deleted now
+        if not (outpath / "grid.nc").exists():  # if overwrite=True, it was deleted now
             outpath.mkdir(exist_ok=True, parents=True)
             testimg = self._testimg()
             n = nimages_for_memory(testimg, memory)
-            if hasattr(self, "use_tqdm"):
+            if hasattr(self, "use_tqdm"):  # pragma: no branch
                 orig_tqdm = self.use_tqdm
                 self.use_tqdm = False
             reshuffler = Img2Ts(
@@ -290,7 +291,7 @@ def repurpose(
                 imgbuffer=n,
             )
             reshuffler.calc()
-            if hasattr(self, "use_tqdm"):
+            if hasattr(self, "use_tqdm"):  # pragma: no branch
                 self.use_tqdm = orig_tqdm
         reader = GriddedNcOrthoMultiTs(str(outpath), timevarname=timevarname, read_bulk=True)
         return reader
diff --git a/src/qa4sm_preprocessing/reading/stack.py b/src/qa4sm_preprocessing/reading/stack.py
index ea891d3..47fc342 100644
--- a/src/qa4sm_preprocessing/reading/stack.py
+++ b/src/qa4sm_preprocessing/reading/stack.py
@@ -5,10 +5,10 @@
 from typing import Iterable, Union
 import xarray as xr
 
-from .imagebase import XarrayImageReaderBase
+from .imagebase import ImageReaderBase
 
 
-class XarrayImageStackReader(XarrayImageReaderBase):
+class StackImageReader(ImageReaderBase):
     """
     Image reader that wraps a xarray.Dataset.
diff --git a/src/qa4sm_preprocessing/reading/timeseries.py b/src/qa4sm_preprocessing/reading/timeseries.py
index 2ec4306..f513376 100644
--- a/src/qa4sm_preprocessing/reading/timeseries.py
+++ b/src/qa4sm_preprocessing/reading/timeseries.py
@@ -4,15 +4,16 @@
 import pandas as pd
 from pathlib import Path
 from typing import Union, Iterable, Sequence
+import warnings
 import xarray as xr
 
 from pygeogrids.netcdf import load_grid
 from pynetcf.time_series import GriddedNcOrthoMultiTs as _GriddedNcOrthoMultiTs
 
-from .base import XarrayReaderBase
+from .base import ReaderBase
 
 
-class XarrayTSReader(XarrayReaderBase):
+class StackTs(ReaderBase):
     """
     Wrapper for xarray.Dataset when timeseries of the data should be read.
 
@@ -91,9 +92,6 @@ def __init__(
         cellsize: float = None,
         construct_grid: bool = True,
     ):
-        if isinstance(varnames, str):
-            varnames = [varnames]
-        varnames = list(varnames)
 
         if isinstance(ds, (str, Path)):
             ds = xr.open_dataset(ds)
@@ -172,7 +170,8 @@ def __init__(
         ts_path,
         grid_path=None,
         timevarname=None,
-        read_bulk=True,
+        read_bulk=None,
+        kd_tree_name="pykdtree",
         **kwargs,
     ):
         """
@@ -186,10 +185,15 @@ def __init__(
             Path to grid file, that is used to organize the
             location of time series to read. If None is passed, grid.nc is
             searched for in the ts_path.
-        read_bulk : boolean, optional (default:False)
-            if set to True the data of all locations is read into memory,
+        read_bulk : boolean, optional (default: None)
+            If set to True (default), the data of all locations is read into memory,
             and subsequent calls to read_ts read from the cache and not from
-            disk this makes reading complete files faster#
+            disk; this makes reading complete files faster.
+        timevarname : str, optional (default: None)
+            Name of the time variable to use instead of the original timestamps.
+        kd_tree_name : str, optional (default: "pykdtree")
+            Name of the Kd-tree engine used in the grid. Available options are
+            "pykdtree" and "scipy".
 
         Additional keyword arguments
         ----------------------------
@@ -211,11 +215,21 @@ def __init__(
         """
         if grid_path is None:  # pragma: no branch
             grid_path = os.path.join(ts_path, "grid.nc")
-        grid = load_grid(grid_path)
-        ioclass_kws = {}
+        grid = load_grid(grid_path, kd_tree_name=kd_tree_name)
+
+        ioclass_kws = kwargs.get("ioclass_kws", {})
         if "ioclass_kws" in kwargs:
-            ioclass_kws.update(kwargs["ioclass_kws"])
             del kwargs["ioclass_kws"]
+        # if read_bulk is not given, we use the value from ioclass_kws, or True
+        # if this is also not given. Otherwise we overwrite the value in ioclass_kws.
+        if read_bulk is None:
+            read_bulk = ioclass_kws.get("read_bulk", True)
+        else:
+            if "read_bulk" in ioclass_kws and read_bulk != ioclass_kws["read_bulk"]:
+                warnings.warn(
+                    f"read_bulk={read_bulk} but ioclass_kws['read_bulk']="
+                    f"{ioclass_kws['read_bulk']}. The first takes precedence."
+                )
         ioclass_kws["read_bulk"] = read_bulk
         super().__init__(ts_path, grid, ioclass_kws=ioclass_kws, **kwargs)
         self.timevarname = timevarname
diff --git a/src/qa4sm_preprocessing/reading/transpose.py b/src/qa4sm_preprocessing/reading/transpose.py
index b165539..25da05e 100644
--- a/src/qa4sm_preprocessing/reading/transpose.py
+++ b/src/qa4sm_preprocessing/reading/transpose.py
@@ -39,7 +39,7 @@ def write_transposed_dataset(
 
     Parameters
     ----------
-    reader : XarrayImageReaderBase
+    reader : ImageReaderBase instance
         Reader for the dataset.
     outfname : str or Path
         Output filename. Must end with ".nc" for netCDF output or with ".zarr"
Must end with ".nc" for netCDF output or with ".zarr" @@ -103,7 +103,7 @@ def write_transposed_dataset( ) with dask.config.set(**dask_config): _transpose(*args, **kwargs) - else: + else: # distributed is True with dask.config.set(**dask_config), Client( n_workers=1, threads_per_worker=n_threads, @@ -129,7 +129,7 @@ def _transpose( orig_cache = reader.open_dataset_kwargs.get("cache") orig_chunks = reader.open_dataset_kwargs.get("chunks") reader.open_dataset_kwargs.update({"cache": False, "chunks": None}) - if hasattr(reader, "use_tqdm"): + if hasattr(reader, "use_tqdm"): # pragma: no branch orig_tqdm = reader.use_tqdm new_tqdm = stepsize == 1 reader.use_tqdm = new_tqdm @@ -142,7 +142,13 @@ def _transpose( testds = reader._testimg() coords = dict(testds.coords) coords[reader.timename] = timestamps - dims = dict(testds.dims) + # to get the correct order of dimensions, we use the first variable and + # assert that the other variables follow the same order + testvar = testds[reader.varnames[0]] + dims = dict(zip(testvar.dims, testvar.shape)) + for var in reader.varnames: + testvar = testds[var] + assert dims == dict(zip(testvar.dims, testvar.shape)) del dims[reader.timename] dims[reader.timename] = len(timestamps) new_dimsizes = tuple(size for size in dims.values()) @@ -285,9 +291,9 @@ def _transpose( ds.to_netcdf(outfname, encoding=encoding) # restore the reader settings - if hasattr(reader, "use_tqdm"): + if hasattr(reader, "use_tqdm"): # pragma: no branch reader.use_tqdm = orig_tqdm - if hasattr(reader, "open_dataset_kwargs"): + if hasattr(reader, "open_dataset_kwargs"): # pragma: no branch reader.open_dataset_kwargs.update( {"cache": orig_cache, "chunks": orig_chunks} ) diff --git a/src/qa4sm_preprocessing/reading/write.py b/src/qa4sm_preprocessing/reading/write.py index af9d184..e2950b8 100644 --- a/src/qa4sm_preprocessing/reading/write.py +++ b/src/qa4sm_preprocessing/reading/write.py @@ -8,7 +8,7 @@ def write_images( - dataset: xr.Dataset, + dataset: Union[xr.Dataset, xr.DataArray], directory: Union[Path, str], dsname: str, fmt: str = "%Y%m%dT%H%M", @@ -41,6 +41,8 @@ def write_images( invertlats : bool, optional (default: False) Whether to ensure that the latitude axis is inverted. 
""" + if isinstance(dataset, xr.DataArray): + dataset = dataset.to_dataset(name=dsname) directory = Path(directory) directory.mkdir(exist_ok=True, parents=True) ntime = len(dataset[dim]) diff --git a/tests/conftest.py b/tests/conftest.py index 39a1e3d..eeb7fa5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,7 +7,7 @@ from qa4sm_preprocessing.reading import ( DirectoryImageReader, - XarrayImageStackReader, + StackImageReader, ) @@ -113,10 +113,10 @@ def synthetic_test_args(request): kwargs = {} elif request.param == "curvilinear": ds = make_curvilinear_test_dataset() - kwargs = {"curvilinear": True, "latdim": "x", "londim": "y"} + kwargs = {"curvilinear": True, "latdim": "y", "londim": "x"} elif request.param == "unstructured": ds = make_unstructured_test_dataset() - kwargs = {"locdim": "loc"} + kwargs = {"locdim": "location"} else: raise NotImplementedError return ds, kwargs @@ -151,7 +151,7 @@ def lis_noahmp_stacked(lis_noahmp_directory_image_reader): @pytest.fixture def default_xarray_reader(lis_noahmp_stacked): - return XarrayImageStackReader(lis_noahmp_stacked, "SoilMoist_inst") + return StackImageReader(lis_noahmp_stacked, "SoilMoist_inst") @pytest.fixture diff --git a/tests/test_reading/test_cli.py b/tests/test_reading/test_cli.py index 7733736..3301152 100644 --- a/tests/test_reading/test_cli.py +++ b/tests/test_reading/test_cli.py @@ -4,7 +4,7 @@ import netCDF4 import xarray as xr -from qa4sm_preprocessing.reading import GriddedNcOrthoMultiTs, XarrayTSReader +from qa4sm_preprocessing.reading import GriddedNcOrthoMultiTs, StackTs from qa4sm_preprocessing.reading.cli import repurpose, transpose # this is defined in conftest.py @@ -65,7 +65,7 @@ def test_repurpose_lis(cli_args_lis, lis_noahmp_stacked): outpath = Path(cli_args_lis[1]) repurpose(cli_args_lis) reader = GriddedNcOrthoMultiTs(outpath) - ref = XarrayTSReader(lis_noahmp_stacked, "SoilMoist_inst") + ref = StackTs(lis_noahmp_stacked, "SoilMoist_inst") assert np.all( np.sort(reader.grid.activegpis) == np.sort(ref.grid.activegpis) ) @@ -122,10 +122,10 @@ def test_repurpose_cmip(cli_args_cmip, cmip_ds): repurpose(cli_args_cmip) reader = GriddedNcOrthoMultiTs(outpath) ref_ds = ref = cmip_ds.sel(lat=slice(20, 30), lon=slice(90, 100)) - ref = XarrayTSReader(ref_ds, "mrsos") + ref = StackTs(ref_ds, "mrsos") # not comparing the grid GPIs here, because for "repurpose", the grid # started of as a global grid, from which a bbox was selected, while for - # XarrayTSReader the grid was already only points the bbox. + # StackTs the grid was already only points the bbox. 
     _, lons, lats, _ = reader.grid.get_grid_points()
     for lon, lat in zip(lons, lats):
         ts = reader.read(lon, lat)["mrsos"]
diff --git a/tests/test_reading/test_image_synthetic.py b/tests/test_reading/test_image_synthetic.py
index 82326c4..f97416c 100644
--- a/tests/test_reading/test_image_synthetic.py
+++ b/tests/test_reading/test_image_synthetic.py
@@ -29,7 +29,7 @@
 # Untested features
 # - get lat/lon from (start, stop, step) tuple: this is tested with the
 #   lis-noahmp test images
-# - landmask, bbox, cellsize: covered in XarrayImageStackReader tests
+# - landmask, bbox, cellsize: covered in StackImageReader tests
 
 import numpy as np
 import pandas as pd
@@ -331,7 +331,7 @@ def test_directory_image_reader_latlon_from_2d(regular_test_dataset):
     LON, LAT = np.meshgrid(ds.lon, ds.lat)
     ds["LAT"] = (["lat", "lon"], LAT)
     ds["LON"] = (["lat", "lon"], LON)
-    ds = ds.drop(["lat", "lon"])
+    ds = ds.drop_vars(["lat", "lon"])
     write_images(ds, test_data_path / "synthetic", "synthetic")
     reader = DirectoryImageReader(
         test_data_path / "synthetic",
@@ -467,7 +467,7 @@ def _write_multistep_files(ds, directory, drop_time=False):
     directory.mkdir(exist_ok=True)
     time = ds.indexes["time"]
     if drop_time:
-        ds = ds.drop("time")
+        ds = ds.drop_vars("time")
     ds1 = ds.isel(time=slice(0, 4))
     ds1.to_netcdf(directory / time[0].strftime("synthetic_%Y%m%d.nc"))
     ds2 = ds.isel(time=slice(4, 8))
diff --git a/tests/test_reading/test_readers_realdata.py b/tests/test_reading/test_readers_realdata.py
index 0e7ece6..44d9d50 100644
--- a/tests/test_reading/test_readers_realdata.py
+++ b/tests/test_reading/test_readers_realdata.py
@@ -7,7 +7,7 @@
 
 from qa4sm_preprocessing.reading import (
     DirectoryImageReader,
-    XarrayImageStackReader,
+    StackImageReader,
     GriddedNcOrthoMultiTs,
 )
 from qa4sm_preprocessing.reading.utils import mkdate
@@ -116,7 +116,7 @@ def test_read_block(lis_noahmp_directory_image_reader):
     block = lis_noahmp_directory_image_reader.read_block()
     assert block["SoilMoist_inst"].shape == (6, 100, 50)
 
-    reader = XarrayImageStackReader(block, "SoilMoist_inst")
+    reader = StackImageReader(block, "SoilMoist_inst")
     validate_reader(reader)
 
     start_date = next(iter(lis_noahmp_directory_image_reader.timestamps))
@@ -133,7 +133,7 @@ def test_xarray_reader_basic(default_xarray_reader):
 
 def test_stack_reader_basic(cmip_ds):
     num_gpis = cmip_ds["mrsos"].isel(time=0).size
-    reader = XarrayImageStackReader(cmip_ds, "mrsos", cellsize=5.0)
+    reader = StackImageReader(cmip_ds, "mrsos", cellsize=5.0)
     assert len(reader.grid.activegpis) == num_gpis
     assert len(np.unique(reader.grid.activearrcell)) == 100
     block = reader.read_block()["mrsos"]
@@ -152,7 +152,7 @@ def test_bbox_landmask_cellsize(cmip_ds):
     max_lon = 100
     max_lat = 30
     bbox = [min_lon, min_lat, max_lon, max_lat]
-    reader = XarrayImageStackReader(cmip_ds, "mrsos", bbox=bbox, cellsize=5.0)
+    reader = StackImageReader(cmip_ds, "mrsos", bbox=bbox, cellsize=5.0)
     num_gpis_box = len(reader.grid.activegpis)
     assert num_gpis_box < num_gpis
     assert len(np.unique(reader.grid.activearrcell)) == 4
@@ -169,7 +169,7 @@ def test_bbox_landmask_cellsize(cmip_ds):
 
     # now additionally using a landmask
     landmask = ~np.isnan(cmip_ds.mrsos.isel(time=0))
-    reader = XarrayImageStackReader(
+    reader = StackImageReader(
         cmip_ds, "mrsos", bbox=bbox, landmask=landmask, cellsize=5.0
     )
     assert len(reader.grid.activegpis) < num_gpis
@@ -188,7 +188,7 @@ def test_bbox_landmask_cellsize(cmip_ds):
     # with landmask as variable name
     ds = cmip_ds
     ds["landmask"] = landmask
-    reader = XarrayImageStackReader(
+    reader = StackImageReader(
ds, "mrsos", bbox=bbox, landmask=landmask, cellsize=5.0 ) assert len(reader.grid.activegpis) < num_gpis @@ -215,18 +215,30 @@ def test_SMOS(test_output_path): outpath = test_output_path / "SMOS_ts" ts_reader = reader.repurpose(outpath, overwrite=True, timevarname="Mean_Acq_Time") - df = ts_reader.read(ts_reader.grid.activegpis[100]) - expected_timestamps = list( - map( - mkdate, - [ - "2015-05-06T03:50:13", - "2015-05-07T03:11:27", - "2015-05-08T02:33:02", - ], + + def validate(ts_reader): + df = ts_reader.read(ts_reader.grid.activegpis[100]) + expected_timestamps = list( + map( + mkdate, + [ + "2015-05-06T03:50:13", + "2015-05-07T03:11:27", + "2015-05-08T02:33:02", + ], + ) ) - ) - expected_values = np.array([0.162236, 0.013245, np.nan]) - assert np.all(expected_timestamps == df.index) - np.testing.assert_almost_equal(expected_values, df.Soil_Moisture.values, 6) - assert np.all(df.columns == ["Soil_Moisture"]) + expected_values = np.array([0.162236, 0.013245, np.nan]) + assert np.all(expected_timestamps == df.index) + np.testing.assert_almost_equal(expected_values, df.Soil_Moisture.values, 6) + assert np.all(df.columns == ["Soil_Moisture"]) + + validate(ts_reader) + + # test overwriting again with an existing directory + ts_reader = reader.repurpose(outpath, overwrite=True, timevarname="Mean_Acq_Time") + validate(ts_reader) + + # test reading without overwriting + ts_reader = reader.repurpose(outpath, overwrite=False, timevarname="Mean_Acq_Time") + validate(ts_reader) diff --git a/tests/test_reading/test_timeseries.py b/tests/test_reading/test_timeseries.py index ea79a5d..d6c3fcb 100644 --- a/tests/test_reading/test_timeseries.py +++ b/tests/test_reading/test_timeseries.py @@ -1,21 +1,61 @@ import numpy as np +import shutil -from qa4sm_preprocessing.reading import XarrayTSReader +from qa4sm_preprocessing.reading import StackTs, GriddedNcOrthoMultiTs, StackImageReader +import pytest +from pytest import test_data_path -def test_xarray_ts_reader(regular_test_dataset): - reader = XarrayTSReader(regular_test_dataset, "X") - _, lons, lats, _ = reader.grid.get_grid_points() - for lon, lat in zip(lons, lats): - ts = reader.read(lon, lat)["X"] + +def test_StackTs(regular_test_dataset): + reader = StackTs(regular_test_dataset, "X") + gpis, lons, lats, _ = reader.grid.get_grid_points() + for gpi, lon, lat in zip(gpis, lons, lats): ref = regular_test_dataset.X.sel(lat=lat, lon=lon) + ts = reader.read(lon, lat)["X"] + assert np.all(ts == ref) + ts = reader.read(gpi)["X"] assert np.all(ts == ref) -def test_xarray_ts_reader_locdim(unstructured_test_dataset): - reader = XarrayTSReader(unstructured_test_dataset, "X", locdim="location") - gpis, _, _, _ = reader.grid.get_grid_points() - for gpi in gpis: +def test_StackTs_locdim(unstructured_test_dataset): + reader = StackTs(unstructured_test_dataset, "X", locdim="location") + gpis, lons, lats, _ = reader.grid.get_grid_points() + for gpi, lon, lat in zip(gpis, lons, lats): ts = reader.read(gpi)["X"] ref = unstructured_test_dataset.X.isel(location=gpi) assert np.all(ts == ref) + ts = reader.read(lon, lat)["X"] + assert np.all(ts == ref) + + +def test_GriddedNcOrthoMultiTs(synthetic_test_args): + + ds, kwargs = synthetic_test_args + stack = StackImageReader(ds, ["X", "Y"], **kwargs) + + tspath = test_data_path / "ts_test_path" + tsreader = stack.repurpose(tspath, overwrite=True) + + gpis, lons, lats, _ = tsreader.grid.get_grid_points() + for gpi, lon, lat in zip(gpis, lons, lats): + for var in ["X", "Y"]: + ref = ds[var].where((ds.lat == lat) & (ds.lon == 
+            ts = tsreader.read(gpi)[var]
+            assert np.all(ts == ref)
+            ts = tsreader.read(lon, lat)[var]
+            assert np.all(ts == ref)
+
+    # manually create tsreader and test read_bulk logic
+    assert tsreader.ioclass_kws["read_bulk"] is True
+    tsreader = GriddedNcOrthoMultiTs(tspath, read_bulk=False)
+    assert tsreader.ioclass_kws["read_bulk"] is False
+    tsreader = GriddedNcOrthoMultiTs(tspath, ioclass_kws={"read_bulk": False})
+    assert tsreader.ioclass_kws["read_bulk"] is False
+    tsreader = GriddedNcOrthoMultiTs(tspath, ioclass_kws={"read_bulk": False}, read_bulk=False)
+    assert tsreader.ioclass_kws["read_bulk"] is False
+    with pytest.warns(
+        UserWarning, match="read_bulk=False but"
+    ):
+        tsreader = GriddedNcOrthoMultiTs(tspath, ioclass_kws={"read_bulk": True}, read_bulk=False)
+        assert tsreader.ioclass_kws["read_bulk"] is False
diff --git a/tests/test_reading/test_transpose.py b/tests/test_reading/test_transpose.py
new file mode 100644
index 0000000..448a0aa
--- /dev/null
+++ b/tests/test_reading/test_transpose.py
@@ -0,0 +1,52 @@
+import xarray as xr
+
+from qa4sm_preprocessing.reading import StackImageReader
+from qa4sm_preprocessing.reading.transpose import write_transposed_dataset
+
+from pytest import test_data_path
+
+def test_write_transposed_dataset(synthetic_test_args):
+
+    ds, kwargs = synthetic_test_args
+    stack = StackImageReader(ds, ["X", "Y"], **kwargs)
+
+    transposed_path = test_data_path / "transposed.zarr"
+    write_transposed_dataset(stack, transposed_path)
+    transposed = xr.open_zarr(transposed_path, consolidated=True)
+    xr.testing.assert_equal(ds.transpose(..., "time"), transposed)
+
+
+def test_write_transposed_dataset_given_chunks(synthetic_test_args):
+
+    ds, kwargs = synthetic_test_args
+    stack = StackImageReader(ds, ["X", "Y"], **kwargs)
+
+    if kwargs == {}:
+        chunks = {"lat": 5, "lon": 5}
+    elif "curvilinear" in kwargs:
+        chunks = {"y": 5, "x": 5}
+    else:
+        chunks = {"location": 25}
+
+    transposed_path = test_data_path / "transposed.zarr"
+    write_transposed_dataset(stack, transposed_path, chunks=chunks)
+    transposed = xr.open_zarr(transposed_path, consolidated=True)
+    xr.testing.assert_equal(ds.transpose(..., "time"), transposed)
+
+    if kwargs == {}:
+        assert dict(transposed.chunks) == {"lat": (5,), "lon": (5, 5), "time": (20,)}
+    elif "curvilinear" in kwargs:
+        assert dict(transposed.chunks) == {"y": (5,), "x": (5, 5), "time": (20,)}
+    else:
+        assert dict(transposed.chunks) == {"location": (25, 25), "time": (20,)}
+
+
+def test_write_transposed_dataset_fixed_stepsize(synthetic_test_args):
+
+    ds, kwargs = synthetic_test_args
+    stack = StackImageReader(ds, ["X", "Y"], **kwargs)
+
+    transposed_path = test_data_path / "transposed.zarr"
+    write_transposed_dataset(stack, transposed_path, stepsize=1)
+    transposed = xr.open_zarr(transposed_path, consolidated=True)
+    xr.testing.assert_equal(ds.transpose(..., "time"), transposed)
diff --git a/tests/test_reading/test_utils.py b/tests/test_reading/test_utils.py
new file mode 100644
index 0000000..b6a616d
--- /dev/null
+++ b/tests/test_reading/test_utils.py
@@ -0,0 +1,58 @@
+import datetime
+
+from qa4sm_preprocessing.reading.utils import *
+
+
+def test_mkdate():
+    assert mkdate("2000-12-17") == datetime.datetime(2000, 12, 17)
+    assert mkdate("2000-12-17T01:56") == datetime.datetime(2000, 12, 17, 1, 56)
+    assert mkdate("2000-12-17T01:56:33") == datetime.datetime(2000, 12, 17, 1, 56, 33)
+
+
+def test_str2bool():
+    for val in ["True", "true", "t", "T", "1", "yes", "y"]:
+        assert str2bool(val)
+    for val in ["False", "false", "f", "F", "0", "no", "n"]:
["False", "false", "f", "F", "0", "no", "n"]: + assert not str2bool(val) + +def test_nimages_for_memory(): + + nx, ny, nz = 100, 100, 10 + ds = xr.Dataset({ + "X": (("x", "y", "z"), np.random.randn(nx, ny, nz).astype(np.float64)), + "Y": (("x", "y", "z"), np.random.randn(nx, ny, nz).astype(np.float32)), + "Z": (("x", "y", "z"), np.random.randn(nx, ny, nz).astype(np.float32)), + }) + + size_X = nx*ny*nz*8 + size_Y = nx*ny*nz*4 + totalsize = size_X + 2*size_Y + + # set memory to get 8 images when using full dataset + memory = totalsize * 8 / 1024**3 + nimages = nimages_for_memory(ds, memory) + assert nimages == 8 + + # when using only X, we can get twice as many images into the same memory + nimages = nimages_for_memory(ds[["X"]], memory) + assert nimages == 16 + # same when using Y and Z + nimages = nimages_for_memory(ds[["Y", "Z"]], memory) + assert nimages == 16 + # when using only Y, 32 images should be possible + nimages = nimages_for_memory(ds[["Y"]], memory) + assert nimages == 32 + + +def test_infer_chunksizes(): + + dimsizes = (1000, 1000, 1000) + chunksizes = (100, 100, 1000) + + target_size = 8 * np.prod(chunksizes) / 1024**2 + assert infer_chunksizes(dimsizes, target_size, np.float64) == chunksizes + assert infer_chunksizes(dimsizes, target_size, 8) == chunksizes + + target_size = 4 * np.prod(chunksizes) / 1024**2 + assert infer_chunksizes(dimsizes, target_size, np.float32) == chunksizes + assert infer_chunksizes(dimsizes, target_size, 4) == chunksizes