From 5de0bcbfd241da01807c85dc45178f6d83ae5f1b Mon Sep 17 00:00:00 2001 From: mats-knmi <145579783+mats-knmi@users.noreply.github.com> Date: Mon, 12 Aug 2024 08:57:19 +0200 Subject: [PATCH] make dimension.py xarray compatible (#397) * make dimension.py xarray compatible * convert final method in the dimension module * nanmin in stead of zerovalue in square domain method * make test steps skill run * undo accidental change * remove commented out code * The dataset can contain more than one dataarray * Address pull request comments * Add links to dataset documentation everywhere --- pysteps/converters.py | 18 +- pysteps/io/importers.py | 87 ++++ pysteps/tests/helpers.py | 40 +- pysteps/tests/test_nowcasts_steps.py | 20 +- pysteps/tests/test_utils_dimension.py | 230 ++++++---- pysteps/utils/conversion.py | 21 +- pysteps/utils/dimension.py | 618 ++++++++++---------------- pysteps/utils/transformation.py | 28 +- 8 files changed, 522 insertions(+), 540 deletions(-) diff --git a/pysteps/converters.py b/pysteps/converters.py index 6c576c658..2825af612 100644 --- a/pysteps/converters.py +++ b/pysteps/converters.py @@ -12,6 +12,7 @@ """ import numpy as np +import numpy.typing as npt import pyproj import xarray as xr @@ -67,6 +68,15 @@ def _convert_proj4_to_grid_mapping(proj4str): return grid_mapping_var_name, grid_mapping_name, params +def compute_lat_lon( + x_r: npt.ArrayLike, y_r: npt.ArrayLike, projection: str +) -> tuple[npt.ArrayLike, npt.ArrayLike]: + x_2d, y_2d = np.meshgrid(x_r, y_r) + pr = pyproj.Proj(projection) + lon, lat = pr(x_2d.flatten(), y_2d.flatten(), inverse=True) + return lat.reshape(x_2d.shape), lon.reshape(x_2d.shape) + + def convert_to_xarray_dataset( precip: np.ndarray, quality: np.ndarray | None, @@ -105,9 +115,7 @@ def convert_to_xarray_dataset( if metadata["yorigin"] == "upper": y_r = np.flip(y_r) - x_2d, y_2d = np.meshgrid(x_r, y_r) - pr = pyproj.Proj(metadata["projection"]) - lon, lat = pr(x_2d.flatten(), y_2d.flatten(), inverse=True) + lat, lon = compute_lat_lon(x_r, y_r, metadata["projection"]) ( grid_mapping_var_name, @@ -166,7 +174,7 @@ def convert_to_xarray_dataset( ), "lon": ( ["y", "x"], - lon.reshape(precip.shape), + lon, { "long_name": "longitude coordinate", "standard_name": "longitude", @@ -176,7 +184,7 @@ def convert_to_xarray_dataset( ), "lat": ( ["y", "x"], - lat.reshape(precip.shape), + lat, { "long_name": "latitude coordinate", "standard_name": "latitude", diff --git a/pysteps/io/importers.py b/pysteps/io/importers.py index 5c2928203..f61d4b25b 100644 --- a/pysteps/io/importers.py +++ b/pysteps/io/importers.py @@ -65,6 +65,93 @@ | zr_b | the Z-R exponent b in Z = a*R**b | +------------------+----------------------------------------------------------+ +The data and metadata is then postprocessed into an xarray dataset. This dataset will +always contain an x and y dimension, but can be extended with a time dimension and/or +an ensemble member dimension over the course of the process. + +The dataset can contain the following coordinate variables: + + +.. tabularcolumns:: |p{2cm}|L| + ++---------------+-------------------------------------------------------------------------------------------+ +| Coordinate | Description | ++===============+===========================================================================================+ +| y | y-coordinate in Cartesian system, with units determined by ``metadata["cartesian_unit"]`` | ++---------------+-------------------------------------------------------------------------------------------+ +| x | x-coordinate in Cartesian system, with units determined by ``metadata["cartesian_unit"]`` | ++---------------+-------------------------------------------------------------------------------------------+ +| lat | latitude coordinate in degrees | ++---------------+-------------------------------------------------------------------------------------------+ +| lon | longitude coordinate in degrees | ++---------------+-------------------------------------------------------------------------------------------+ +| time | forecast time in seconds since forecast start time | ++---------------+-------------------------------------------------------------------------------------------+ +| member | ensemble member number (integer) | ++---------------+-------------------------------------------------------------------------------------------+ + + +The dataset can contain the following data variables: + +.. tabularcolumns:: |p{2cm}|L| + ++-------------------+-----------------------------------------------------------------------------------------------------------+ +| Variable | Description | ++===================+===========================================================================================================+ +| precip_intensity, | precipitation data, based on the unit the data has it is stored in one of these 3 possible variables | +| precip_accum | precip_intensity if unit is ``mm/h``, precip_accum if unit is ``mm`` and reflectivity if unit is ``dBZ``, | +| or reflectivity | the attributes of this variable contain metadata relevant to this attribute (see below) | ++-------------------+-----------------------------------------------------------------------------------------------------------+ +| quality | value between 0 and 1 denoting the quality of the precipitation data, currently not used for anything | ++-------------------+-----------------------------------------------------------------------------------------------------------+ + +Some of the metadata in the metadata dictionary is not explicitely stored in the dataset, +but is still implicitly present. For example ``x1`` can easily be found by taking the first +value from the x coordinate variable. Metadata that is not implicitly present is explicitly +stored either in the datasets global attributes or as attributes of the precipitation variable. +Data that relates to the entire dataset is stored in the global attributes. The following data +is stored in the global attributes: + +.. tabularcolumns:: |p{2cm}|L| + ++------------------+----------------------------------------------------------+ +| Key | Value | ++==================+==========================================================+ +| projection | PROJ.4-compatible projection definition | ++------------------+----------------------------------------------------------+ +| institution | name of the institution who provides the data | ++------------------+----------------------------------------------------------+ +| precip_var | the name of the precipitation variable in this dataset | ++------------------+----------------------------------------------------------+ + +The following data is stored as attributes of the precipitation variable: + +.. tabularcolumns:: |p{2cm}|L| + ++------------------+----------------------------------------------------------+ +| Key | Value | ++==================+==========================================================+ +| units | the physical unit of the data: 'mm/h', 'mm' or 'dBZ' | ++------------------+----------------------------------------------------------+ +| transform | the transformation of the data: None, 'dB', 'Box-Cox' or | +| | others | ++------------------+----------------------------------------------------------+ +| accutime | the accumulation time in minutes of the data, float | ++------------------+----------------------------------------------------------+ +| threshold | the rain/no rain threshold with the same unit, | +| | transformation and accutime of the data. | ++------------------+----------------------------------------------------------+ +| zerovalue | the value assigned to the no rain pixels with the same | +| | unit, transformation and accutime of the data. | ++------------------+----------------------------------------------------------+ +| zr_a | the Z-R constant a in Z = a*R**b | ++------------------+----------------------------------------------------------+ +| zr_b | the Z-R exponent b in Z = a*R**b | ++------------------+----------------------------------------------------------+ + +Furthermore the dataset can contain some additional metadata to make the dataset +CF-compliant. + Available Importers ------------------- diff --git a/pysteps/tests/helpers.py b/pysteps/tests/helpers.py index 68f3f7527..24c58f8d1 100644 --- a/pysteps/tests/helpers.py +++ b/pysteps/tests/helpers.py @@ -14,6 +14,7 @@ import pysteps as stp from pysteps import io, rcparams from pysteps.utils import aggregate_fields_space +from pysteps.utils.dimension import clip_domain _reference_dates = dict() _reference_dates["bom"] = datetime(2018, 6, 16, 10, 0) @@ -53,7 +54,6 @@ def get_precipitation_fields( num_prev_files=0, num_next_files=0, return_raw=False, - metadata=False, upscale=None, source="mch", log_transform=True, @@ -100,9 +100,6 @@ def get_precipitation_fields( The pre-processing steps are: 1) Convert to mm/h, 2) Mask invalid values, 3) Log-transform the data [dBR]. - metadata: bool, optional - If True, also return file metadata. - upscale: float or None, optional Upscale fields in space during the pre-processing steps. If it is None, the precipitation field is not modified. @@ -127,8 +124,8 @@ def get_precipitation_fields( Returns ------- - reference_field : array - metadata : dict + dataset: xarray.Dataset + As described in the documentation of :py:mod:`pysteps.io.importers`. """ if source == "bom": @@ -186,41 +183,34 @@ def get_precipitation_fields( # Read the radar composites importer = io.get_method(importer_name, "importer") - ref_dataset = io.read_timeseries(fns, importer, **_importer_kwargs) + dataset = io.read_timeseries(fns, importer, **_importer_kwargs) if not return_raw: - if (num_prev_files == 0) and (num_next_files == 0): - # Remove time dimension - reference_field = np.squeeze(reference_field) + precip_var = dataset.attrs["precip_var"] # Convert to mm/h - ref_dataset = stp.utils.to_rainrate(ref_dataset) + dataset = stp.utils.to_rainrate(dataset) + precip_var = dataset.attrs["precip_var"] # Clip domain - ref_dataset = stp.utils.clip_domain(ref_dataset, clip) + dataset = clip_domain(dataset, clip) # Upscale data - reference_field, ref_metadata = aggregate_fields_space( - reference_field, ref_metadata, upscale - ) + dataset = aggregate_fields_space(dataset, upscale) # Mask invalid values - reference_field = np.ma.masked_invalid(reference_field) + valid_mask = np.isfinite(dataset[precip_var].values) if log_transform: # Log-transform the data [dBR] - reference_field, ref_metadata = stp.utils.dB_transform( - reference_field, ref_metadata, threshold=0.1, zerovalue=-15.0 - ) + dataset = stp.utils.dB_transform(dataset, threshold=0.1, zerovalue=-15.0) # Set missing values with the fill value - np.ma.set_fill_value(reference_field, ref_metadata["zerovalue"]) - reference_field.data[reference_field.mask] = ref_metadata["zerovalue"] - - if metadata: - return reference_field, ref_metadata + metadata = dataset[precip_var].attrs + zerovalue = metadata["zerovalue"] + dataset[precip_var].data[~valid_mask] = zerovalue - return reference_field + return dataset def smart_assert(actual_value, expected, tolerance=None): diff --git a/pysteps/tests/test_nowcasts_steps.py b/pysteps/tests/test_nowcasts_steps.py index 61af86ba5..adb6ea917 100644 --- a/pysteps/tests/test_nowcasts_steps.py +++ b/pysteps/tests/test_nowcasts_steps.py @@ -7,7 +7,6 @@ from pysteps import io, motion, nowcasts, verification from pysteps.tests.helpers import get_precipitation_fields - steps_arg_names = ( "n_ens_members", "n_cascade_levels", @@ -44,28 +43,29 @@ def test_steps_skill( ): """Tests STEPS nowcast skill.""" # inputs - precip_input, metadata = get_precipitation_fields( + dataset_input = get_precipitation_fields( num_prev_files=2, num_next_files=0, return_raw=False, metadata=True, upscale=2000, ) - precip_input = precip_input.filled() - precip_obs = get_precipitation_fields( + dataset_obs = get_precipitation_fields( num_prev_files=0, num_next_files=3, return_raw=False, upscale=2000 - )[1:, :, :] - precip_obs = precip_obs.filled() + ).isel(time=slice(1, None, None)) + precip_var = dataset_input.attrs["precip_var"] + metadata = dataset_input[precip_var].attrs + precip_data = dataset_input[precip_var].values pytest.importorskip("cv2") oflow_method = motion.get_method("LK") - retrieved_motion = oflow_method(precip_input) + retrieved_motion = oflow_method(precip_data) nowcast_method = nowcasts.get_method("steps") precip_forecast = nowcast_method( - precip_input, + precip_data, retrieved_motion, timesteps=timesteps, precip_thr=metadata["threshold"], @@ -86,7 +86,9 @@ def test_steps_skill( timesteps if isinstance(timesteps, int) else len(timesteps) ) - crps = verification.probscores.CRPS(precip_forecast[:, -1], precip_obs[-1]) + crps = verification.probscores.CRPS( + precip_forecast[:, -1], dataset_obs[precip_var].values[-1] + ) assert crps < max_crps, f"CRPS={crps:.2f}, required < {max_crps:.2f}" diff --git a/pysteps/tests/test_utils_dimension.py b/pysteps/tests/test_utils_dimension.py index ab753ed7d..2bbb63f58 100644 --- a/pysteps/tests/test_utils_dimension.py +++ b/pysteps/tests/test_utils_dimension.py @@ -4,63 +4,89 @@ import numpy as np import pytest +import xarray as xr from numpy.testing import assert_array_equal from pytest import raises +from pysteps.converters import convert_to_xarray_dataset from pysteps.utils import dimension +fillvalues_metadata = { + "x1": 0, + "x2": 4, + "y1": 0, + "y2": 4, + "xpixelsize": 1, + "ypixelsize": 1, + "zerovalue": 0, + "yorigin": "lower", + "unit": "mm/h", + "transform": None, + "accutime": 5, + "threshold": 1.0, + "projection": "+proj=stere +lat_0=90 +lon_0=0.0 +lat_ts=60.0 +a=6378.137 +b=6356.752 +x_0=0 +y_0=0", + "zr_a": 200, + "zr_b": 1.6, + "cartesian_unit": "km", + "institution": "KNMI", +} + test_data_not_trim = ( # "data, window_size, axis, method, expected" - (np.arange(6), 2, 0, "mean", np.array([0.5, 2.5, 4.5])), + ( + np.arange(12).reshape(2, 6), + 2, + "x", + "mean", + np.array([[0.5, 2.5, 4.5], [6.5, 8.5, 10.5]]), + ), ( np.arange(4 * 6).reshape(4, 6), (2, 3), - (0, 1), + ("y", "x"), "sum", np.array([[24, 42], [96, 114]]), ), ( np.arange(4 * 6).reshape(4, 6), (2, 2), - (0, 1), + ("y", "x"), "sum", np.array([[14, 22, 30], [62, 70, 78]]), ), ( np.arange(4 * 6).reshape(4, 6), 2, - (0, 1), + ("y", "x"), "sum", np.array([[14, 22, 30], [62, 70, 78]]), ), ( np.arange(4 * 6).reshape(4, 6), (2, 3), - (0, 1), + ("y", "x"), "mean", np.array([[4.0, 7.0], [16.0, 19.0]]), ), ( np.arange(4 * 6).reshape(4, 6), (2, 2), - (0, 1), + ("y", "x"), "mean", np.array([[3.5, 5.5, 7.5], [15.5, 17.5, 19.5]]), ), ( np.arange(4 * 6).reshape(4, 6), 2, - (0, 1), + ("y", "x"), "mean", np.array([[3.5, 5.5, 7.5], [15.5, 17.5, 19.5]]), ), ) -@pytest.mark.parametrize( - "data, window_size, axis, method, expected", test_data_not_trim -) -def test_aggregate_fields(data, window_size, axis, method, expected): +@pytest.mark.parametrize("data, window_size, dim, method, expected", test_data_not_trim) +def test_aggregate_fields(data, window_size, dim, method, expected): """ Test the aggregate_fields function. The windows size must divide exactly the data dimensions. @@ -68,23 +94,25 @@ def test_aggregate_fields(data, window_size, axis, method, expected): windows size does not divide the data dimensions. The length of each dimension should be larger than 2. """ + dataset = convert_to_xarray_dataset(data, None, fillvalues_metadata) - assert_array_equal( - dimension.aggregate_fields(data, window_size, axis=axis, method=method), - expected, - ) + actual = dimension.aggregate_fields(dataset, window_size, dim=dim, method=method) + assert_array_equal(actual["precip_intensity"].values, expected) # Test the trimming capabilities. - data = np.pad(data, (0, 1)) - assert_array_equal( - dimension.aggregate_fields( - data, window_size, axis=axis, method=method, trim=True - ), - expected, + if np.ndim(window_size) == 0: + data = np.pad(data, ((0, 0), (0, 1))) + else: + data = np.pad(data, (0, 1)) + dataset = convert_to_xarray_dataset(data, None, fillvalues_metadata) + + actual = dimension.aggregate_fields( + dataset, window_size, dim=dim, method=method, trim=True ) + assert_array_equal(actual["precip_intensity"].values, expected) with raises(ValueError): - dimension.aggregate_fields(data, window_size, axis=axis, method=method) + dimension.aggregate_fields(dataset, window_size, dim=dim, method=method) def test_aggregate_fields_errors(): @@ -93,80 +121,124 @@ def test_aggregate_fields_errors(): function. """ data = np.arange(4 * 6).reshape(4, 6) + dataset = convert_to_xarray_dataset(data, None, fillvalues_metadata) with raises(ValueError): - dimension.aggregate_fields(data, -1, axis=0) + dimension.aggregate_fields(dataset, -1, dim="y") with raises(ValueError): - dimension.aggregate_fields(data, 0, axis=0) + dimension.aggregate_fields(dataset, 0, dim="y") with raises(ValueError): - dimension.aggregate_fields(data, 1, method="invalid") + dimension.aggregate_fields(dataset, 1, method="invalid") with raises(TypeError): - dimension.aggregate_fields(data, (1, 1), axis=0) + dimension.aggregate_fields(dataset, (1, 1), dim="y") # aggregate_fields_time -timestamps = [dt.datetime.now() + dt.timedelta(minutes=t) for t in range(10)] -test_data = [ +now = dt.datetime.now() +timestamps = [now + dt.timedelta(minutes=t) for t in range(10)] +test_data_time = [ ( - np.ones((10, 1, 1)), + np.ones((2, 2)), {"unit": "mm/h", "timestamps": timestamps}, 2, False, - np.ones((5, 1, 1)), + np.ones((5, 2, 2)), ), ( - np.ones((10, 1, 1)), + np.ones((2, 2)), {"unit": "mm", "timestamps": timestamps}, 2, False, - 2 * np.ones((5, 1, 1)), + 2 * np.ones((5, 2, 2)), ), ] @pytest.mark.parametrize( - "R, metadata, time_window_min, ignore_nan, expected", test_data + "data, metadata, time_window_min, ignore_nan, expected", test_data_time ) -def test_aggregate_fields_time(R, metadata, time_window_min, ignore_nan, expected): +def test_aggregate_fields_time(data, metadata, time_window_min, ignore_nan, expected): """Test the aggregate_fields_time.""" + dataset_ref = convert_to_xarray_dataset( + data, None, {**fillvalues_metadata, **metadata} + ) + datasets = [] + for timestamp in metadata["timestamps"]: + dataset_ = dataset_ref.copy(deep=True) + dataset_ = dataset_.expand_dims(dim="time", axis=0) + dataset_ = dataset_.assign_coords(time=("time", [timestamp])) + datasets.append(dataset_) + + dataset = xr.concat(datasets, dim="time") assert_array_equal( - dimension.aggregate_fields_time(R, metadata, time_window_min, ignore_nan)[0], + dimension.aggregate_fields_time(dataset, time_window_min, ignore_nan)[ + "precip_intensity" if metadata["unit"] == "mm/h" else "precip_accum" + ].values, expected, ) # aggregate_fields_space -test_data = [ +test_data_space = [ ( - np.ones((1, 10, 10)), - {"unit": "mm/h", "xpixelsize": 1, "ypixelsize": 1}, + np.ones((10, 10)), + { + "unit": "mm/h", + "x1": 0, + "x2": 10, + "y1": 0, + "y2": 10, + "xpixelsize": 1, + "ypixelsize": 1, + }, 2, False, - np.ones((1, 5, 5)), + np.ones((5, 5)), ), ( - np.ones((1, 10, 10)), - {"unit": "mm", "xpixelsize": 1, "ypixelsize": 1}, + np.ones((10, 10)), + { + "unit": "mm", + "x1": 0, + "x2": 10, + "y1": 0, + "y2": 10, + "xpixelsize": 1, + "ypixelsize": 1, + }, 2, False, - np.ones((1, 5, 5)), + np.ones((5, 5)), ), ( - np.ones((1, 10, 10)), - {"unit": "mm/h", "xpixelsize": 1, "ypixelsize": 2}, - (2, 4), + np.ones((10, 10)), + { + "unit": "mm/h", + "x1": 0, + "x2": 10, + "y1": 0, + "y2": 20, + "xpixelsize": 1, + "ypixelsize": 2, + }, + (4, 2), False, - np.ones((1, 5, 5)), + np.ones((5, 5)), ), ] -@pytest.mark.parametrize("R, metadata, space_window, ignore_nan, expected", test_data) -def test_aggregate_fields_space(R, metadata, space_window, ignore_nan, expected): +@pytest.mark.parametrize( + "data, metadata, space_window, ignore_nan, expected", test_data_space +) +def test_aggregate_fields_space(data, metadata, space_window, ignore_nan, expected): """Test the aggregate_fields_space.""" + dataset = convert_to_xarray_dataset(data, None, {**fillvalues_metadata, **metadata}) assert_array_equal( - dimension.aggregate_fields_space(R, metadata, space_window, ignore_nan)[0], + dimension.aggregate_fields_space(dataset, space_window, ignore_nan)[ + "precip_intensity" if metadata["unit"] == "mm/h" else "precip_accum" + ].values, expected, ) @@ -174,64 +246,40 @@ def test_aggregate_fields_space(R, metadata, space_window, ignore_nan, expected) # clip_domain R = np.zeros((4, 4)) R[:2, :] = 1 -test_data = [ +test_data_clip_domain = [ ( R, - { - "x1": 0, - "x2": 4, - "y1": 0, - "y2": 4, - "xpixelsize": 1, - "ypixelsize": 1, - "zerovalue": 0, - "yorigin": "upper", - }, + {"yorigin": "lower"}, None, R, ), ( R, - { - "x1": 0, - "x2": 4, - "y1": 0, - "y2": 4, - "xpixelsize": 1, - "ypixelsize": 1, - "zerovalue": 0, - "yorigin": "lower", - }, + {"yorigin": "lower"}, (2, 4, 2, 4), np.zeros((2, 2)), ), ( R, - { - "x1": 0, - "x2": 4, - "y1": 0, - "y2": 4, - "xpixelsize": 1, - "ypixelsize": 1, - "zerovalue": 0, - "yorigin": "upper", - }, + {"yorigin": "upper"}, (2, 4, 2, 4), np.ones((2, 2)), ), ] -@pytest.mark.parametrize("R, metadata, extent, expected", test_data) +@pytest.mark.parametrize("R, metadata, extent, expected", test_data_clip_domain) def test_clip_domain(R, metadata, extent, expected): """Test the clip_domain.""" - assert_array_equal(dimension.clip_domain(R, metadata, extent)[0], expected) + dataset = convert_to_xarray_dataset(R, None, {**fillvalues_metadata, **metadata}) + assert_array_equal( + dimension.clip_domain(dataset, extent)["precip_intensity"].values, expected + ) # square_domain R = np.zeros((4, 2)) -test_data = [ +test_data_square = [ # square by padding ( R, @@ -258,7 +306,7 @@ def test_clip_domain(R, metadata, extent, expected): "y2": 4, "xpixelsize": 1, "ypixelsize": 1, - "orig_domain": (4, 2), + "orig_domain": (np.array([0.5, 1.5, 2.5, 3.5]), np.array([0.5, 1.5])), "square_method": "pad", }, "pad", @@ -275,7 +323,7 @@ def test_clip_domain(R, metadata, extent, expected): "y2": 3, "xpixelsize": 1, "ypixelsize": 1, - "orig_domain": (4, 2), + "orig_domain": (np.array([0.5, 1.5, 2.5, 3.5]), np.array([0.5, 1.5])), "square_method": "crop", }, "crop", @@ -285,9 +333,15 @@ def test_clip_domain(R, metadata, extent, expected): ] -@pytest.mark.parametrize("R, metadata, method, inverse, expected", test_data) -def test_square_domain(R, metadata, method, inverse, expected): +@pytest.mark.parametrize("data, metadata, method, inverse, expected", test_data_square) +def test_square_domain(data, metadata, method, inverse, expected): """Test the square_domain.""" + dataset = convert_to_xarray_dataset(data, None, {**fillvalues_metadata, **metadata}) + dataset["precip_intensity"].attrs = { + **dataset["precip_intensity"].attrs, + **metadata, + } assert_array_equal( - dimension.square_domain(R, metadata, method, inverse)[0], expected + dimension.square_domain(dataset, method, inverse)["precip_intensity"].values, + expected, ) diff --git a/pysteps/utils/conversion.py b/pysteps/utils/conversion.py index 68228e981..2ea6a3a12 100644 --- a/pysteps/utils/conversion.py +++ b/pysteps/utils/conversion.py @@ -70,8 +70,9 @@ def to_rainrate(dataset: xr.Dataset, zr_a=None, zr_b=None): Parameters ---------- - dataset: Dataset - Dataset to be (back-)transformed. + dataset: xarray.Dataset + Dataset to be (back-)transformed as described in the documentation of + :py:mod:`pysteps.io.importers`. Additionally, in case of conversion to/from reflectivity units, the zr_a and zr_b attributes are also required, @@ -83,7 +84,7 @@ def to_rainrate(dataset: xr.Dataset, zr_a=None, zr_b=None): Returns ------- - dataset: Dataset + dataset: xarray.Dataset Dataset containing the converted units. """ @@ -159,8 +160,9 @@ def to_raindepth(dataset: xr.Dataset, zr_a=None, zr_b=None): Parameters ---------- - dataset: Dataset - Dataset to be (back-)transformed. + dataset: xarray.Dataset + Dataset to be (back-)transformed as described in the documentation of + :py:mod:`pysteps.io.importers`. Additionally, in case of conversion to/from reflectivity units, the zr_a and zr_b attributes are also required, @@ -172,7 +174,7 @@ def to_raindepth(dataset: xr.Dataset, zr_a=None, zr_b=None): Returns ------- - dataset: Dataset + dataset: xarray.Dataset Dataset containing the converted units. """ @@ -248,8 +250,9 @@ def to_reflectivity(dataset: xr.Dataset, zr_a=None, zr_b=None): Parameters ---------- - dataset: Dataset - Dataset to be (back-)transformed. + dataset: xarray.Dataset + Dataset to be (back-)transformed as described in the documentation of + :py:mod:`pysteps.io.importers`. Additionally, in case of conversion to/from reflectivity units, the zr_a and zr_b attributes are also required, @@ -261,7 +264,7 @@ def to_reflectivity(dataset: xr.Dataset, zr_a=None, zr_b=None): Returns ------- - dataset: Dataset + dataset: xarray.Dataset Dataset containing the converted units. """ diff --git a/pysteps/utils/dimension.py b/pysteps/utils/dimension.py index 43b7e2ca5..efa459610 100644 --- a/pysteps/utils/dimension.py +++ b/pysteps/utils/dimension.py @@ -14,26 +14,34 @@ clip_domain square_domain """ - import numpy as np +import xarray as xr + +from pysteps.converters import compute_lat_lon _aggregation_methods = dict( sum=np.sum, mean=np.mean, nanmean=np.nanmean, nansum=np.nansum ) -def aggregate_fields_time(R, metadata, time_window_min, ignore_nan=False): +def aggregate_fields_time( + dataset: xr.Dataset, time_window_min, ignore_nan=False +) -> xr.Dataset: """Aggregate fields in time. + It attempts to aggregate the given dataset in the time direction in an integer + number of sections of length = ``time_window_min``. + If such a aggregation is not possible, an error is raised. + The data is aggregated by a method chosen based on the unit of the precipitation + data in the dataset. ``mean`` is used when the unit is ``mm/h`` and ``sum`` + is used when the unit is ``mm``. For other units an error is raised. + Parameters ---------- - R: array-like - Array of shape (t,m,n) or (l,t,m,n) containing - a time series of (ensemble) input fields. + dataset: xarray.Dataset + Dataset containing a time series of (ensemble) input fields + as described in the documentation of :py:mod:`pysteps.io.importers`. They must be evenly spaced in time. - metadata: dict - Metadata dictionary containing the timestamps and unit attributes as - described in the documentation of :py:mod:`pysteps.io.importers`. time_window_min: float or None The length in minutes of the time window that is used to aggregate the fields. @@ -45,12 +53,8 @@ def aggregate_fields_time(R, metadata, time_window_min, ignore_nan=False): Returns ------- - outputarray: array-like - The new array of aggregated fields of shape (k,m,n) or (l,k,m,n), where - k = t*delta/time_window_min and delta is the time interval between two - successive timestamps. - metadata: dict - The metadata with updated attributes. + dataset: xarray.Dataset + The new dataset. See also -------- @@ -58,40 +62,24 @@ def aggregate_fields_time(R, metadata, time_window_min, ignore_nan=False): pysteps.utils.dimension.aggregate_fields """ - R = R.copy() - metadata = metadata.copy() - if time_window_min is None: - return R, metadata - - unit = metadata["unit"] - timestamps = metadata["timestamps"] - if "leadtimes" in metadata: - leadtimes = metadata["leadtimes"] - - if len(R.shape) < 3: - raise ValueError("The number of dimension must be > 2") - if len(R.shape) == 3: - axis = 0 - if len(R.shape) == 4: - axis = 1 - if len(R.shape) > 4: - raise ValueError("The number of dimension must be <= 4") - - if R.shape[axis] != len(timestamps): - raise ValueError( - "The list of timestamps has length %i, " % len(timestamps) - + "but R contains %i frames" % R.shape[axis] - ) + return dataset + + precip_var = dataset.attrs["precip_var"] + metadata = dataset[precip_var].attrs + + unit = metadata["units"] + + timestamps = dataset["time"].values # assumes that frames are evenly spaced - delta = (timestamps[1] - timestamps[0]).seconds / 60 + delta = (timestamps[1] - timestamps[0]) / np.timedelta64(1, "m") if delta == time_window_min: - return R, metadata - if (R.shape[axis] * delta) % time_window_min: - raise ValueError("time_window_size does not equally split R") + return dataset + if time_window_min % delta: + raise ValueError("time_window_size does not equally split dataset") - nframes = int(time_window_min / delta) + window_size = int(time_window_min / delta) # specify the operator to be used to aggregate # the values within the time window @@ -100,55 +88,47 @@ def aggregate_fields_time(R, metadata, time_window_min, ignore_nan=False): elif unit == "mm": method = "sum" else: - raise ValueError( - "can only aggregate units of 'mm/h' or 'mm'" + " not %s" % unit - ) + raise ValueError(f"can only aggregate units of 'mm/h' or 'mm' not {unit}") if ignore_nan: method = "".join(("nan", method)) - R = aggregate_fields(R, nframes, axis=axis, method=method) - - metadata["accutime"] = time_window_min - metadata["timestamps"] = timestamps[nframes - 1 :: nframes] - if "leadtimes" in metadata: - metadata["leadtimes"] = leadtimes[nframes - 1 :: nframes] + return aggregate_fields(dataset, window_size, dim="time", method=method) - return R, metadata - -def aggregate_fields_space(R, metadata, space_window, ignore_nan=False): +def aggregate_fields_space( + dataset: xr.Dataset, space_window, ignore_nan=False +) -> xr.Dataset: """ Upscale fields in space. + It attempts to aggregate the given dataset in y and x direction in an integer + number of sections of length = ``(window_size_y, window_size_x)``. + If such a aggregation is not possible, an error is raised. + The data is aggregated by computing the mean. Only datasets with precipitation + data in the ``mm`` or ``mm/h`` unit are currently supported. + Parameters ---------- - R: array-like - Array of shape (m,n), (t,m,n) or (l,t,m,n) containing a single field or - a time series of (ensemble) input fields. - metadata: dict - Metadata dictionary containing the xpixelsize, ypixelsize and unit - attributes as described in the documentation of + dataset: xarray.Dataset + Dataset containing a single field or + a time series of (ensemble) input fields as described in the documentation of :py:mod:`pysteps.io.importers`. space_window: float, tuple or None The length of the space window that is used to upscale the fields. If a float is given, the same window size is used for the x- and y-directions. Separate window sizes are used for x- and y-directions if - a two-element tuple is given. The space_window unit is the same used in - the geographical projection of R and hence the same as for the xpixelsize - and ypixelsize attributes. The space spanned by the n- and m-dimensions - of R must be a multiple of space_window. If set to None, the function - returns a copy of the original R and metadata. + a two-element tuple is given (y, x). The space_window unit is the same + as the unit of x and y in the input dataset. The space spanned by the + n- and m-dimensions of the dataset content must be a multiple of space_window. + If set to None, the function returns a copy of the original dataset. ignore_nan: bool, optional If True, ignore nan values. Returns ------- - outputarray: array-like - The new array of aggregated fields of shape (k,j), (t,k,j) or (l,t,k,j), - where k = m*ypixelsize/space_window[1] and j = n*xpixelsize/space_window[0]. - metadata: dict - The metadata with updated attributes. + dataset: xarray.Dataset + The new dataset. See also -------- @@ -156,110 +136,85 @@ def aggregate_fields_space(R, metadata, space_window, ignore_nan=False): pysteps.utils.dimension.aggregate_fields """ - R = R.copy() - metadata = metadata.copy() - if space_window is None: - return R, metadata - - unit = metadata["unit"] - ypixelsize = metadata["ypixelsize"] - xpixelsize = metadata["xpixelsize"] - - if len(R.shape) < 2: - raise ValueError("The number of dimensions must be >= 2") - if len(R.shape) == 2: - axes = [0, 1] - if len(R.shape) == 3: - axes = [1, 2] - if len(R.shape) == 4: - axes = [2, 3] - if len(R.shape) > 4: - raise ValueError("The number of dimensions must be <= 4") + return dataset + + precip_var = dataset.attrs["precip_var"] + metadata = dataset[precip_var].attrs + + unit = metadata["units"] if np.isscalar(space_window): space_window = (space_window, space_window) # assumes that frames are evenly spaced - if ypixelsize == space_window[1] and xpixelsize == space_window[0]: - return R, metadata - - ysize = R.shape[axes[0]] * ypixelsize - xsize = R.shape[axes[1]] * xpixelsize - - if ( - abs(ysize / space_window[1] - round(ysize / space_window[1])) > 1e-10 - or abs(xsize / space_window[0] - round(xsize / space_window[0])) > 1e-10 - ): - raise ValueError("space_window does not equally split R") + ydelta = dataset["y"].values[1] - dataset["y"].values[0] + xdelta = dataset["x"].values[1] - dataset["x"].values[0] - nframes = [int(space_window[1] / ypixelsize), int(space_window[0] / xpixelsize)] + if space_window[0] % ydelta > 1e-10 or space_window[1] % xdelta > 1e-10: + raise ValueError("space_window does not equally split dataset") # specify the operator to be used to aggregate the values # within the space window if unit == "mm/h" or unit == "mm": method = "mean" else: - raise ValueError( - "can only aggregate units of 'mm/h' or 'mm' " + "not %s" % unit - ) + raise ValueError(f"can only aggregate units of 'mm/h' or 'mm' not {unit}") if ignore_nan: method = "".join(("nan", method)) - R = aggregate_fields(R, nframes[0], axis=axes[0], method=method) - R = aggregate_fields(R, nframes[1], axis=axes[1], method=method) - - metadata["ypixelsize"] = space_window[1] - metadata["xpixelsize"] = space_window[0] + window_size = (int(space_window[0] / ydelta), int(space_window[1] / xdelta)) - return R, metadata + return aggregate_fields(dataset, window_size, ["y", "x"], method) -def aggregate_fields(data, window_size, axis=0, method="mean", trim=False): +def aggregate_fields( + dataset: xr.Dataset, window_size, dim="x", method="mean", trim=False +) -> xr.Dataset: """Aggregate fields along a given direction. - It attempts to aggregate the given R axis in an integer number of sections + It attempts to aggregate the given dataset dim in an integer number of sections of length = ``window_size``. If such a aggregation is not possible, an error is raised unless ``trim`` - set to True, in which case the axis is trimmed (from the end) + set to True, in which case the dim is trimmed (from the end) to make it perfectly divisible". Parameters ---------- - data: array-like - Array of any shape containing the input fields. - window_size: int or tuple of ints + dataset: xarray.Dataset + Dataset containing the input fields as described in the documentation of + :py:mod:`pysteps.io.importers`. + window_size: int or array-like of ints The length of the window that is used to aggregate the fields. If a single integer value is given, the same window is used for - all the selected axis. + all the selected dim. If ``window_size`` is a 1D array-like, each element indicates the length of the window that is used - to aggregate the fields along each axis. In this case, + to aggregate the fields along each dim. In this case, the number of elements of 'window_size' must be the same as the elements - in the ``axis`` argument. - axis: int or array-like of ints - Axis or axes where to perform the aggregation. - If this is a tuple of ints, the aggregation is performed over multiple - axes, instead of a single axis + in the ``dim`` argument. + dim: str or array-like of strs + Dim or dims where to perform the aggregation. + If this is an array-like of strs, the aggregation is performed over multiple + dims, instead of a single dim method: string, optional Optional argument that specifies the operation to use to aggregate the values within the window. Default to mean operator. trim: bool In case that the ``data`` is not perfectly divisible by - ``window_size`` along the selected axis: + ``window_size`` along the selected dim: - trim=True: the data will be trimmed (from the end) along that - axis to make it perfectly divisible. + dim to make it perfectly divisible. - trim=False: a ValueError exception is raised. Returns ------- - new_array: array-like - The new aggregated array with shape[axis] = k, - where k = R.shape[axis] / window_size. + dataset: xarray.Dataset + The new dataset. See also -------- @@ -267,90 +222,60 @@ def aggregate_fields(data, window_size, axis=0, method="mean", trim=False): pysteps.utils.dimension.aggregate_fields_space """ - if np.ndim(axis) > 1: + if np.ndim(dim) > 1: raise TypeError( "Only integers or integer 1D arrays can be used for the " "'axis' argument." ) - if np.ndim(axis) == 1: - axis = np.asarray(axis) - if np.ndim(window_size) == 0: - window_size = (window_size,) * axis.size - - window_size = np.asarray(window_size, dtype="int") - - if window_size.shape != axis.shape: - raise ValueError( - "The 'window_size' and 'axis' shapes are incompatible." - f"window_size.shape: {str(window_size.shape)}, " - f"axis.shape: {str(axis.shape)}, " - ) - - new_data = data.copy() - for i in range(axis.size): - # Recursively call the aggregate_fields function - new_data = aggregate_fields( - new_data, window_size[i], axis=axis[i], method=method, trim=trim - ) - - return new_data + if np.ndim(dim) == 0: + dim = [dim] - if np.ndim(window_size) != 0: - raise TypeError( - "A single axis was selected for the aggregation but several" - f"of window_sizes were given: {str(window_size)}." - ) + if np.ndim(window_size) == 0: + window_size = [window_size for _ in dim] - data = np.asarray(data).copy() - orig_shape = data.shape + if len(window_size) != len(dim): + raise TypeError("The length of window size does not to match the length of dim") if method not in _aggregation_methods: raise ValueError( "Aggregation method not recognized. " f"Available methods: {str(list(_aggregation_methods.keys()))}" ) + for ws in window_size: + if ws <= 0: + raise ValueError("'window_size' must be strictly positive") - if window_size <= 0: - raise ValueError("'window_size' must be strictly positive") + for d, ws in zip(dim, window_size): + if (dataset.sizes[d] % ws) and (not trim): + raise ValueError( + f"Since 'trim' argument was set to False," + f"the 'window_size' {ws} must exactly divide" + f"the dimension along the selected axis:" + f"dataset.sizes[dim]={dataset.sizes[d]}" + ) - if (orig_shape[axis] % window_size) and (not trim): - raise ValueError( - f"Since 'trim' argument was set to False," - f"the 'window_size' {window_size} must exactly divide" - f"the dimension along the selected axis:" - f"data.shape[axis]={orig_shape[axis]}" + # FIXME: The aggregation method is applied to all DataArrays in the Dataset + # Fix to allow support for an aggregation method per DataArray + return ( + dataset.rolling(dict(zip(dim, window_size))) + .reduce(_aggregation_methods[method]) + .isel( + { + d: slice(ws - 1, dataset.sizes[d] - dataset.sizes[d] % ws, ws) + for d, ws in zip(dim, window_size) + } ) - - new_data = data.swapaxes(axis, 0) - if trim: - trim_size = data.shape[axis] % window_size - if trim_size > 0: - new_data = new_data[:-trim_size] - - new_data_shape = list(new_data.shape) - new_data_shape[0] //= window_size # Final shape - - new_data = new_data.reshape(new_data_shape[0], window_size, -1) - - new_data = _aggregation_methods[method](new_data, axis=1) - - new_data = new_data.reshape(new_data_shape).swapaxes(axis, 0) - - return new_data + ) -def clip_domain(R, metadata, extent=None): +def clip_domain(dataset: xr.Dataset, extent=None): """ Clip the field domain by geographical coordinates. Parameters ---------- - R: array-like - Array of shape (m,n) or (t,m,n) containing the input fields. - metadata: dict - Metadata dictionary containing the x1, x2, y1, y2, - xpixelsize, ypixelsize, - zerovalue and yorigin attributes as described in the documentation of + dataset: xarray.Dataset + Dataset containing the input fields as described in the documentation of :py:mod:`pysteps.io.importers`. extent: scalars (left, right, bottom, top), optional The extent of the bounding box in data coordinates to be used to clip @@ -362,107 +287,48 @@ def clip_domain(R, metadata, extent=None): Returns ------- - R: array-like - the clipped array - metadata: dict - the metadata with updated attributes. + dataset: xarray.Dataset + The clipped dataset """ + if extent is None: + return dataset + return dataset.sel(x=slice(extent[0], extent[1]), y=slice(extent[2], extent[3])) - R = R.copy() - R_shape = np.array(R.shape) - metadata = metadata.copy() - if extent is None: - return R, metadata - - if len(R.shape) < 2: - raise ValueError("The number of dimension must be > 1") - if len(R.shape) == 2: - R = R[None, None, :, :] - if len(R.shape) == 3: - R = R[None, :, :, :] - if len(R.shape) > 4: - raise ValueError("The number of dimension must be <= 4") - - # extract original domain coordinates - left = metadata["x1"] - right = metadata["x2"] - bottom = metadata["y1"] - top = metadata["y2"] - - # extract bounding box coordinates - left_ = extent[0] - right_ = extent[1] - bottom_ = extent[2] - top_ = extent[3] - - # compute its extent in pixels - dim_x_ = int((right_ - left_) / metadata["xpixelsize"]) - dim_y_ = int((top_ - bottom_) / metadata["ypixelsize"]) - R_ = np.ones((R.shape[0], R.shape[1], dim_y_, dim_x_)) * metadata["zerovalue"] - - # build set of coordinates for the original domain - y_coord = ( - np.linspace(bottom, top - metadata["ypixelsize"], R.shape[2]) - + metadata["ypixelsize"] / 2.0 - ) - x_coord = ( - np.linspace(left, right - metadata["xpixelsize"], R.shape[3]) - + metadata["xpixelsize"] / 2.0 +def _pad_domain( + dataset: xr.Dataset, dim_to_pad: str, idx_buffer: int, zerovalue: float +) -> xr.Dataset: + # assumes that frames are evenly spaced + delta = dataset[dim_to_pad].values[1] - dataset[dim_to_pad].values[0] + end_values = ( + dataset[dim_to_pad].values[0] - delta * idx_buffer, + dataset[dim_to_pad].values[-1] + delta * idx_buffer, ) - # build set of coordinates for the new domain - y_coord_ = ( - np.linspace(bottom_, top_ - metadata["ypixelsize"], R_.shape[2]) - + metadata["ypixelsize"] / 2.0 + dataset_ref = dataset + + # FIXME: The same zerovalue is used for all DataArrays in the Dataset + # Fix to allow support for a zerovalue per DataArray + dataset = dataset_ref.pad({dim_to_pad: idx_buffer}, constant_values=zerovalue) + dataset[dim_to_pad] = dataset_ref[dim_to_pad].pad( + {dim_to_pad: idx_buffer}, + mode="linear_ramp", + end_values={dim_to_pad: end_values}, ) - x_coord_ = ( - np.linspace(left_, right_ - metadata["xpixelsize"], R_.shape[3]) - + metadata["xpixelsize"] / 2.0 + dataset.lat.data[:], dataset.lon.data[:] = compute_lat_lon( + dataset.x.values, dataset.y.values, dataset.attrs["projection"] ) + return dataset - # origin='upper' reverses the vertical axes direction - if metadata["yorigin"] == "upper": - y_coord = y_coord[::-1] - y_coord_ = y_coord_[::-1] - - # extract original domain - idx_y = np.where(np.logical_and(y_coord < top_, y_coord > bottom_))[0] - idx_x = np.where(np.logical_and(x_coord < right_, x_coord > left_))[0] - - # extract new domain - idx_y_ = np.where(np.logical_and(y_coord_ < top, y_coord_ > bottom))[0] - idx_x_ = np.where(np.logical_and(x_coord_ < right, x_coord_ > left))[0] - - # compose the new array - R_[:, :, idx_y_[0] : (idx_y_[-1] + 1), idx_x_[0] : (idx_x_[-1] + 1)] = R[ - :, :, idx_y[0] : (idx_y[-1] + 1), idx_x[0] : (idx_x[-1] + 1) - ] - - # update coordinates - metadata["y1"] = bottom_ - metadata["y2"] = top_ - metadata["x1"] = left_ - metadata["x2"] = right_ - R_shape[-2] = R_.shape[-2] - R_shape[-1] = R_.shape[-1] - - return R_.reshape(R_shape), metadata - - -def square_domain(R, metadata, method="pad", inverse=False): +def square_domain(dataset: xr.Dataset, method="pad", inverse=False): """ Either pad or crop a field to obtain a square domain. Parameters ---------- - R: array-like - Array of shape (m,n) or (t,m,n) containing the input fields. - metadata: dict - Metadata dictionary containing the x1, x2, y1, y2, - xpixelsize, ypixelsize, - attributes as described in the documentation of + dataset: xarray.Dataset + Dataset containing the input fields as described in the documentation of :py:mod:`pysteps.io.importers`. method: {'pad', 'crop'}, optional Either pad or crop. @@ -477,123 +343,91 @@ def square_domain(R, metadata, method="pad", inverse=False): Returns ------- - R: array-like - the reshape dataset - metadata: dict - the metadata with updated attributes. + dataset: xarray.Dataset + the reshaped dataset """ - R = R.copy() - R_shape = np.array(R.shape) - metadata = metadata.copy() - - if not inverse: - if len(R.shape) < 2: - raise ValueError("The number of dimension must be > 1") - if len(R.shape) == 2: - R = R[None, None, :] - if len(R.shape) == 3: - R = R[None, :] - if len(R.shape) > 4: - raise ValueError("The number of dimension must be <= 4") - - if R.shape[2] == R.shape[3]: - return R.squeeze() - - orig_dim = R.shape - orig_dim_n = orig_dim[0] - orig_dim_t = orig_dim[1] - orig_dim_y = orig_dim[2] - orig_dim_x = orig_dim[3] + dataset = dataset.copy(deep=True) + precip_var = dataset.attrs["precip_var"] + precip_data = dataset[precip_var].values + + x_len = len(dataset.x.values) + y_len = len(dataset.y.values) + + if inverse: + if "orig_domain" not in dataset.attrs or "square_method" not in dataset.attrs: + raise ValueError("Attempting to inverse a non squared dataset") + method = dataset.attrs.pop("square_method") + orig_domain = dataset.attrs.pop("orig_domain") if method == "pad": - new_dim = np.max(orig_dim[2:]) - R_ = np.ones((orig_dim_n, orig_dim_t, new_dim, new_dim)) * R.min() - - if orig_dim_x < new_dim: - idx_buffer = int((new_dim - orig_dim_x) / 2.0) - R_[:, :, :, idx_buffer : (idx_buffer + orig_dim_x)] = R - metadata["x1"] -= idx_buffer * metadata["xpixelsize"] - metadata["x2"] += idx_buffer * metadata["xpixelsize"] - - elif orig_dim_y < new_dim: - idx_buffer = int((new_dim - orig_dim_y) / 2.0) - R_[:, :, idx_buffer : (idx_buffer + orig_dim_y), :] = R - metadata["y1"] -= idx_buffer * metadata["ypixelsize"] - metadata["y2"] += idx_buffer * metadata["ypixelsize"] - - elif method == "crop": - new_dim = np.min(orig_dim[2:]) - R_ = np.zeros((orig_dim_n, orig_dim_t, new_dim, new_dim)) - - if orig_dim_x > new_dim: - idx_buffer = int((orig_dim_x - new_dim) / 2.0) - R_ = R[:, :, :, idx_buffer : (idx_buffer + new_dim)] - metadata["x1"] += idx_buffer * metadata["xpixelsize"] - metadata["x2"] -= idx_buffer * metadata["xpixelsize"] - - elif orig_dim_y > new_dim: - idx_buffer = int((orig_dim_y - new_dim) / 2.0) - R_ = R[:, :, idx_buffer : (idx_buffer + new_dim), :] - metadata["y1"] += idx_buffer * metadata["ypixelsize"] - metadata["y2"] -= idx_buffer * metadata["ypixelsize"] - - else: - raise ValueError("Unknown type") - - metadata["orig_domain"] = (orig_dim_y, orig_dim_x) - metadata["square_method"] = method - - R_shape[-2] = R_.shape[-2] - R_shape[-1] = R_.shape[-1] - - return R_.reshape(R_shape), metadata - - elif inverse: - if len(R.shape) < 2: - raise ValueError("The number of dimension must be > 2") - if len(R.shape) == 2: - R = R[None, None, :] - if len(R.shape) == 3: - R = R[None, :] - if len(R.shape) > 4: - raise ValueError("The number of dimension must be <= 4") - - method = metadata.pop("square_method") - shape = metadata.pop("orig_domain") - - if R.shape[2] == shape[0] and R.shape[3] == shape[1]: - return R.squeeze(), metadata - - R_ = np.zeros((R.shape[0], R.shape[1], shape[0], shape[1])) + if x_len > len(orig_domain[1]): + extent = ( + orig_domain[1].min(), + orig_domain[1].max(), + dataset.y.values.min(), + dataset.y.values.max(), + ) + elif y_len > len(orig_domain[0]): + extent = ( + dataset.x.values.min(), + dataset.x.values.max(), + orig_domain[0].min(), + orig_domain[0].max(), + ) + else: + return dataset + return clip_domain(dataset, extent) + + if method == "crop": + if x_len < len(orig_domain[1]): + dim_to_pad = "x" + idx_buffer = int((len(orig_domain[1]) - x_len) / 2.0) + elif y_len < len(orig_domain[0]): + dim_to_pad = "y" + idx_buffer = int((len(orig_domain[0]) - y_len) / 2.0) + else: + return dataset + return _pad_domain(dataset, dim_to_pad, idx_buffer, np.nanmin(precip_data)) + + raise ValueError(f"Unknown square method: {method}") + + else: + if "orig_domain" in dataset.attrs and "square_method" in dataset.attrs: + raise ValueError("Attempting to square an already squared dataset") + dataset.attrs["orig_domain"] = (dataset.y.values, dataset.x.values) + dataset.attrs["square_method"] = method if method == "pad": - if R.shape[2] == shape[0]: - idx_buffer = int((R.shape[3] - shape[1]) / 2.0) - R_ = R[:, :, :, idx_buffer : (idx_buffer + shape[1])] - metadata["x1"] += idx_buffer * metadata["xpixelsize"] - metadata["x2"] -= idx_buffer * metadata["xpixelsize"] - - elif R.shape[3] == shape[1]: - idx_buffer = int((R.shape[2] - shape[0]) / 2.0) - R_ = R[:, :, idx_buffer : (idx_buffer + shape[0]), :] - metadata["y1"] += idx_buffer * metadata["ypixelsize"] - metadata["y2"] -= idx_buffer * metadata["ypixelsize"] - - elif method == "crop": - if R.shape[2] == shape[0]: - idx_buffer = int((shape[1] - R.shape[3]) / 2.0) - R_[:, :, :, idx_buffer : (idx_buffer + R.shape[3])] = R - metadata["x1"] -= idx_buffer * metadata["xpixelsize"] - metadata["x2"] += idx_buffer * metadata["xpixelsize"] - - elif R.shape[3] == shape[1]: - idx_buffer = int((shape[0] - R.shape[2]) / 2.0) - R_[:, :, idx_buffer : (idx_buffer + R.shape[2]), :] = R - metadata["y1"] -= idx_buffer * metadata["ypixelsize"] - metadata["y2"] += idx_buffer * metadata["ypixelsize"] - - R_shape[-2] = R_.shape[-2] - R_shape[-1] = R_.shape[-1] - - return R_.reshape(R_shape), metadata + if x_len > y_len: + dim_to_pad = "y" + idx_buffer = int((x_len - y_len) / 2.0) + elif y_len > x_len: + dim_to_pad = "x" + idx_buffer = int((y_len - x_len) / 2.0) + else: + return dataset + return _pad_domain(dataset, dim_to_pad, idx_buffer, np.nanmin(precip_data)) + + if method == "crop": + if x_len > y_len: + idx_buffer = int((x_len - y_len) / 2.0) + extent = ( + dataset.x.values[idx_buffer], + dataset.x.values[-idx_buffer - 1], + dataset.y.values.min(), + dataset.y.values.max(), + ) + elif y_len > x_len: + idx_buffer = int((y_len - x_len) / 2.0) + extent = ( + dataset.x.values.min(), + dataset.x.values.max(), + dataset.y.values[idx_buffer], + dataset.y.values[-idx_buffer - 1], + ) + else: + return dataset + return clip_domain(dataset, extent) + + raise ValueError(f"Unknown square method: {method}") diff --git a/pysteps/utils/transformation.py b/pysteps/utils/transformation.py index 1977583c6..3e48fe0d8 100644 --- a/pysteps/utils/transformation.py +++ b/pysteps/utils/transformation.py @@ -41,8 +41,9 @@ def boxcox_transform( Parameters ---------- - dataset: Dataset - Dataset to be transformed. + dataset: xarray.Dataset + Dataset to be transformed as described in the documentation of + :py:mod:`pysteps.io.importers`. Lambda: float, optional Parameter Lambda of the Box-Cox transformation. It is 0 by default, which produces the log transformation. @@ -62,7 +63,7 @@ def boxcox_transform( Returns ------- - dataset: Dataset + dataset: xarray.Dataset Dataset containing the (back-)transformed units. References @@ -146,8 +147,9 @@ def dB_transform( Parameters ---------- - dataset: Dataset - Dataset to be (back-)transformed. + dataset: xarray.Dataset + Dataset to be (back-)transformed as described in the documentation of + :py:mod:`pysteps.io.importers`. threshold: float, optional Optional value that is used for thresholding with the same units as in the dataset. If None, the threshold contained in metadata is used. @@ -161,7 +163,7 @@ def dB_transform( Returns ------- - dataset: Dataset + dataset: xarray.Dataset Dataset containing the (back-)transformed units. """ @@ -223,8 +225,9 @@ def NQ_transform(dataset: xr.Dataset, inverse: bool = False, **kwargs) -> xr.Dat Parameters ---------- - dataset: Dataset - Dataset to be transformed. + dataset: xarray.Dataset + Dataset to be transformed as described in the documentation of + :py:mod:`pysteps.io.importers`. inverse: bool, optional If set to True, it performs the inverse transform. False by default. @@ -238,7 +241,7 @@ def NQ_transform(dataset: xr.Dataset, inverse: bool = False, **kwargs) -> xr.Dat Returns ------- - dataset: Dataset + dataset: xarray.Dataset Dataset containing the (back-)transformed units. References @@ -309,14 +312,15 @@ def sqrt_transform(dataset: xr.Dataset, inverse: bool = False, **kwargs) -> xr.D Parameters ---------- - dataset: Dataset - Dataset to be transformed. + dataset: xarray.Dataset + Dataset to be transformed as described in the documentation of + :py:mod:`pysteps.io.importers`. inverse: bool, optional If set to True, it performs the inverse transform. False by default. Returns ------- - dataset: Dataset + dataset: xarray.Dataset Dataset containing the (back-)transformed units. """