diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 1f297d24d..ed8eef321 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -9,6 +9,7 @@ Contributors to this version: David Huard (:user:`huard`), Trevor James Smith (: New features and enhancements ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ * ``xclim.sdba.nbutils.quantile`` and its child functions are now faster. If the module `fastnanquantile` is installed, it is used as the backend for the computation of quantiles and yields even faster results. (:issue:`1255`, :pull:`1513`). +* New multivariate bias adjustment class `MBCn`, giving a faster and more accurate implementation of the 'MBCn' algorithm (:issue:`1551`, :pull:`1580`). Bug fixes ^^^^^^^^^ diff --git a/docs/notebooks/sdba.ipynb b/docs/notebooks/sdba.ipynb index 30b211d7b..d1030788d 100644 --- a/docs/notebooks/sdba.ipynb +++ b/docs/notebooks/sdba.ipynb @@ -439,7 +439,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Fourth example : Multivariate bias-adjustment with multiple steps (Cannon, 2018)\n", + "### Fourth example : Multivariate bias-adjustment (Cannon, 2018)\n", "\n", "This section replicates the \"MBCn\" algorithm described by [Cannon (2018)](https://doi.org/10.1007/s00382-017-3580-6). The method relies on some univariate algorithm, an adaption of the N-pdf transform of [PitiƩ et al. (2005)](https://ieeexplore.ieee.org/document/1544887/) and a final reordering step.\n", "\n", @@ -474,55 +474,16 @@ "\n", "dhist = dsim.sel(time=slice(\"1981\", \"2010\"))\n", "dsim = dsim.sel(time=slice(\"2041\", \"2070\"))\n", - "dref" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "##### Perform an initial univariate adjustment." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# additive for tasmax\n", - "QDMtx = sdba.QuantileDeltaMapping.train(\n", - " dref.tasmax, dhist.tasmax, nquantiles=20, kind=\"+\", group=\"time\"\n", - ")\n", - "# Adjust both hist and sim, we'll feed both to the Npdf transform.\n", - "scenh_tx = QDMtx.adjust(dhist.tasmax)\n", - "scens_tx = QDMtx.adjust(dsim.tasmax)\n", - "\n", - "# remove == 0 values in pr:\n", - "dref[\"pr\"] = sdba.processing.jitter_under_thresh(dref.pr, \"0.01 mm d-1\")\n", - "dhist[\"pr\"] = sdba.processing.jitter_under_thresh(dhist.pr, \"0.01 mm d-1\")\n", - "dsim[\"pr\"] = sdba.processing.jitter_under_thresh(dsim.pr, \"0.01 mm d-1\")\n", - "\n", - "# multiplicative for pr\n", - "QDMpr = sdba.QuantileDeltaMapping.train(\n", - " dref.pr, dhist.pr, nquantiles=20, kind=\"*\", group=\"time\"\n", - ")\n", - "# Adjust both hist and sim, we'll feed both to the Npdf transform.\n", - "scenh_pr = QDMpr.adjust(dhist.pr)\n", - "scens_pr = QDMpr.adjust(dsim.pr)\n", "\n", - "scenh = xr.Dataset(dict(tasmax=scenh_tx, pr=scenh_pr))\n", - "scens = xr.Dataset(dict(tasmax=scens_tx, pr=scens_pr))" + "# Stack variables : Dataset -> DataArray with `multivar` dimension\n", + "dref, dhist, dsim = (sdba.stack_variables(da) for da in (dref, dhist, dsim))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "##### Stack the variables to multivariate arrays and standardize them\n", - "The standardization process ensure the mean and standard deviation of each column (variable) is 0 and 1 respectively.\n", - "\n", - "`scenh` and `scens` are standardized together, so the two series are coherent. As we'll see further, we do not need to keep the mean and standard deviation, as we only keep the rank order information from the `NpdfTransform` output." 
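The stacking cell above keeps each variable's metadata as underscore-prefixed attributes on the `multivar` coordinate; the multivariate unit handling added to `BaseAdjustment._harmonize_units` further down in this diff reads the `_units` entry. A minimal sketch of what to expect (dimension order and unit strings are illustrative):

# Sketch: inspect the stacked DataArray produced by the cell above.
print(dref.dims)                      # e.g. ("multivar", "location", "time")
print(dref.multivar.values)           # e.g. array(["pr", "tasmax"], dtype=object)
print(dref.multivar.attrs["_units"])  # e.g. ["kg m-2 s-1", "K"]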
+ "##### Perform the multivariate adjustment (MBCn)." ] }, { @@ -531,91 +492,42 @@ "metadata": {}, "outputs": [], "source": [ - "# Stack the variables (tasmax and pr)\n", - "ref = sdba.processing.stack_variables(dref)\n", - "scenh = sdba.processing.stack_variables(scenh)\n", - "scens = sdba.processing.stack_variables(scens)\n", - "\n", - "# Standardize\n", - "ref, _, _ = sdba.processing.standardize(ref)\n", - "\n", - "allsim_std, _, _ = sdba.processing.standardize(xr.concat((scenh, scens), \"time\"))\n", - "scenh_std = allsim_std.sel(time=scenh.time)\n", - "scens_std = allsim_std.sel(time=scens.time)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "##### Perform the N-dimensional probability density function transform\n", + "ADJ = sdba.MBCn.train(\n", + " dref,\n", + " dhist,\n", + " base_kws={\"nquantiles\": 20, \"group\": \"time\"},\n", + " adj_kws={\"interp\": \"nearest\", \"extrapolation\": \"constant\"},\n", + " n_iter=20, # perform 20 iteration\n", + " n_escore=1000, # only send 1000 points to the escore metric\n", + ")\n", "\n", - "The NpdfTransform will iteratively randomly rotate our arrays in the \"variables\" space and apply the univariate adjustment before rotating it back. In Cannon (2018) and PitiƩ et al. (2005), it can be seen that the source array's joint distribution converges toward the target's joint distribution when a large number of iterations is done." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from xclim import set_options\n", - "\n", - "# See the advanced notebook for details on how this option work\n", - "with set_options(sdba_extra_output=True):\n", - " out = sdba.adjustment.NpdfTransform.adjust(\n", - " ref,\n", - " scenh_std,\n", - " scens_std,\n", - " base=sdba.QuantileDeltaMapping, # Use QDM as the univariate adjustment.\n", - " base_kws={\"nquantiles\": 20, \"group\": \"time\"},\n", - " n_iter=20, # perform 20 iteration\n", - " n_escore=1000, # only send 1000 points to the escore metric (it is really slow)\n", + "scenh, scens = (\n", + " ADJ.adjust(\n", + " sim=ds,\n", + " ref=dref,\n", + " hist=dhist,\n", + " base=sdba.QuantileDeltaMapping,\n", + " base_kws_vars={\n", + " \"pr\": {\n", + " \"kind\": \"*\",\n", + " \"jitter_under_thresh_value\": \"0.01 mm d-1\",\n", + " \"adapt_freq_thresh\": \"0.1 mm d-1\",\n", + " },\n", + " \"tasmax\": {\"kind\": \"+\"},\n", + " },\n", + " adj_kws={\"interp\": \"nearest\", \"extrapolation\": \"constant\"},\n", " )\n", - "\n", - "scenh_npdft = out.scenh.rename(time_hist=\"time\") # Bias-adjusted historical period\n", - "scens_npdft = out.scen # Bias-adjusted future period\n", - "extra = out.drop_vars([\"scenh\", \"scen\"])" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "##### Restoring the trend\n", - "\n", - "The NpdfT has given us new \"hist\" and \"sim\" arrays with a correct rank structure. However, the trend is lost in this process. We reorder the result of the initial adjustment according to the rank structure of the NpdfT outputs to get our final bias-adjusted series.\n", - "\n", - "`sdba.processing.reordering`: 'ref' the argument that provides the order, 'sim' is the argument to reorder." 
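The generator expression in the adjustment cell above runs the same `adjust` call on both `dhist` and `dsim`; spelled out explicitly (a sketch reusing only the objects defined in the notebook cells), it is equivalent to:

# Same call as the generator expression above, written out (sketch).
adjust_kwargs = dict(
    ref=dref,
    hist=dhist,
    base=sdba.QuantileDeltaMapping,
    base_kws_vars={
        "pr": {
            "kind": "*",
            "jitter_under_thresh_value": "0.01 mm d-1",
            "adapt_freq_thresh": "0.1 mm d-1",
        },
        "tasmax": {"kind": "+"},
    },
    adj_kws={"interp": "nearest", "extrapolation": "constant"},
)
scenh = ADJ.adjust(sim=dhist, **adjust_kwargs)  # adjusted historical period
scens = ADJ.adjust(sim=dsim, **adjust_kwargs)   # adjusted future period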
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "scenh = sdba.processing.reordering(scenh_npdft, scenh, group=\"time\")\n", - "scens = sdba.processing.reordering(scens_npdft, scens, group=\"time\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "scenh = sdba.processing.unstack_variables(scenh)\n", - "scens = sdba.processing.unstack_variables(scens)" + " for ds in (dhist, dsim)\n", + ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "##### There we are!\n", + "##### Let's trigger all the computations.\n", "\n", - "Let's trigger all the computations. The use of `dask.compute` allows the three DataArrays to be computed at the same time, avoiding repeating the common steps." + "The use of `dask.compute` allows the three DataArrays to be computed at the same time, avoiding repeating the common steps." ] }, { @@ -628,9 +540,7 @@ "from dask.diagnostics import ProgressBar\n", "\n", "with ProgressBar():\n", - " scenh, scens, escores = compute(\n", - " scenh.isel(location=2), scens.isel(location=2), extra.escores.isel(location=2)\n", - " )" + " scenh, scens, escores = compute(scenh, scens, ADJ.ds.escores)" ] }, { @@ -646,13 +556,15 @@ "metadata": {}, "outputs": [], "source": [ - "fig, ax = plt.subplots()\n", - "\n", - "dref.isel(location=2).tasmax.plot(ax=ax, label=\"Reference\")\n", - "scenh.tasmax.plot(ax=ax, label=\"Adjusted\", alpha=0.65)\n", - "dhist.isel(location=2).tasmax.plot(ax=ax, label=\"Simulated\")\n", - "\n", - "ax.legend()" + "fig, axs = plt.subplots(1, 2, figsize=(16, 4))\n", + "for da, label in zip((dref, scenh, dhist), (\"Reference\", \"Adjusted\", \"Simulated\")):\n", + " ds = sdba.unstack_variables(da).isel(location=2)\n", + " # time series - tasmax\n", + " ds.tasmax.plot(ax=axs[0], label=label, alpha=0.65 if label == \"Adjusted\" else 1)\n", + " # scatter plot\n", + " ds.plot.scatter(x=\"pr\", y=\"tasmax\", ax=axs[1], label=label)\n", + "axs[0].legend()\n", + "axs[1].legend()" ] }, { @@ -661,7 +573,7 @@ "metadata": {}, "outputs": [], "source": [ - "escores.plot()\n", + "escores.isel(location=2).plot()\n", "plt.title(\"E-scores for each iteration.\")\n", "plt.xlabel(\"iteration\")\n", "plt.ylabel(\"E-score\")" @@ -686,7 +598,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.2" + "version": "3.12.3" }, "toc": { "base_numbering": 1, diff --git a/tests/test_sdba/test_adjustment.py b/tests/test_sdba/test_adjustment.py index e381bdd63..d4ddfb865 100644 --- a/tests/test_sdba/test_adjustment.py +++ b/tests/test_sdba/test_adjustment.py @@ -6,14 +6,17 @@ import xarray as xr from scipy.stats import genpareto, norm, uniform +from xclim.core.calendar import stack_periods from xclim.core.options import set_options from xclim.core.units import convert_units_to from xclim.sdba import adjustment from xclim.sdba.adjustment import ( LOCI, + BaseAdjustment, DetrendedQuantileMapping, EmpiricalQuantileMapping, ExtremeValues, + MBCn, PrincipalComponents, QuantileDeltaMapping, Scaling, @@ -35,6 +38,33 @@ from xclim.testing.sdba_utils import nancov # noqa +class TestBaseAdjustment: + def test_harmonize_units(self, series, random): + n = 10 + u = random.random(n) + da = series(u, "tas") + da2 = da.copy() + da2 = convert_units_to(da2, "degC") + (da, da2), _ = BaseAdjustment._harmonize_units(da, da2) + assert da.units == da2.units + + @pytest.mark.parametrize("use_dask", [True, False]) + def 
test_harmonize_units_multivariate(self, series, random, use_dask): + n = 10 + u = random.random(n) + ds = xr.merge([series(u, "tas"), series(u * 100, "pr")]) + ds2 = ds.copy() + ds2["tas"] = convert_units_to(ds2["tas"], "degC") + ds2["pr"] = convert_units_to(ds2["pr"], "nm/day") + da, da2 = stack_variables(ds), stack_variables(ds2) + if use_dask: + da, da2 = da.chunk({"multivar": 1}), da2.chunk({"multivar": 1}) + + (da, da2), _ = BaseAdjustment._harmonize_units(da, da2) + ds, ds2 = unstack_variables(da), unstack_variables(da2) + assert (ds.tas.units == ds2.tas.units) & (ds.pr.units == ds2.pr.units) + + class TestLoci: @pytest.mark.parametrize("group,dec", (["time", 2], ["time.month", 1])) def test_time_and_from_ds(self, series, group, dec, tmp_path, random): @@ -554,6 +584,50 @@ def test_add_dims(self, use_dask, open_dataset): assert scen2.sel(location=["Kugluktuk", "Vancouver"]).isnull().all() +@pytest.mark.slow +class TestMBCn: + @pytest.mark.parametrize("use_dask", [True, False]) + @pytest.mark.parametrize("group, window", [["time", 1], ["time.dayofyear", 31]]) + @pytest.mark.parametrize("period_dim", [None, "period"]) + def test_simple(self, open_dataset, use_dask, group, window, period_dim): + with set_options(sdba_encode_cf=use_dask): + if use_dask: + chunks = {"location": -1} + else: + chunks = None + ref, dsim = ( + open_dataset( + f"sdba/{file}", + chunks=chunks, + drop_variables=["lat", "lon"], + ) + .isel(location=1, drop=True) + .expand_dims(location=["Amos"]) + for file in ["ahccd_1950-2013.nc", "CanESM2_1950-2100.nc"] + ) + ref, hist = ( + ds.sel(time=slice("1981", "2010")).isel(time=slice(365 * 4)) + for ds in [ref, dsim] + ) + dsim = dsim.sel(time=slice("1981", None)) + sim = (stack_periods(dsim).isel(period=slice(1, 2))).isel( + time=slice(365 * 4) + ) + + ref, hist, sim = (stack_variables(ds) for ds in [ref, hist, sim]) + + MBCN = MBCn.train( + ref, + hist, + base_kws=dict(nquantiles=50, group=Grouper(group, window)), + adj_kws=dict(interp="linear"), + ) + p = MBCN.adjust(sim=sim, ref=ref, hist=hist, period_dim=period_dim) + # 'does it run' test + p.load() + + class TestPrincipalComponents: @pytest.mark.parametrize( "group", (Grouper("time.month"), Grouper("time", add_dims=["lon"])) ) @@ -601,20 +675,17 @@ def _group_assert(ds, dim): @pytest.mark.parametrize("use_dask", [True, False]) @pytest.mark.parametrize("pcorient", ["full", "simple"]) def test_real_data(self, atmosds, use_dask, pcorient): - ref = stack_variables( - xr.Dataset( - {"tasmax": atmosds.tasmax, "tasmin": atmosds.tasmin, "tas": atmosds.tas} - ) - ).isel(location=3) - hist = stack_variables( - xr.Dataset( - { - "tasmax": 1.001 * atmosds.tasmax, - "tasmin": atmosds.tasmin - 0.25, - "tas": atmosds.tas + 1, - } - ) - ).isel(location=3) + ds0 = xr.Dataset( + {"tasmax": atmosds.tasmax, "tasmin": atmosds.tasmin, "tas": atmosds.tas} + ) + ref = stack_variables(ds0).isel(location=3) + hist0 = ds0 + with xr.set_options(keep_attrs=True): + hist0["tasmax"] = 1.001 * hist0.tasmax + hist0["tasmin"] = hist0.tasmin - 0.25 + hist0["tas"] = hist0.tas + 1 + + hist = stack_variables(hist0).isel(location=3) with xr.set_options(keep_attrs=True): sim = hist + 5 sim["time"] = sim.time + np.timedelta64(10, "Y").astype(" xr.Dataset: """Train step on one group.
+ Dataset variables: + ref : training target + hist : training data dim : str The dimension along which to compute the quantiles. kind : str The kind of correction to compute. See :py:func:`xclim.sdba.utils.get_correction`. quantiles : array-like The quantiles to compute. - adapt_freq_thresh : str | None + adapt_freq_thresh : str, optional Threshold for frequency adaptation. See :py:class:`xclim.sdba.processing.adapt_freq` for details. Default is None, meaning that frequency adaptation is not performed. + jitter_under_thresh_value : str, optional + Threshold under which to add uniform random noise to values, a quantity with units. + Default is None, meaning that jitter under thresh is not performed. Returns ------- xr.Dataset The dataset containing the adjustment factors, the quantiles over the training data, and the scaling factor. """ - hist = _adapt_freq_hist(ds, adapt_freq_thresh) if adapt_freq_thresh else ds.hist + ds["hist"] = ( + jitter_under_thresh(ds.hist, jitter_under_thresh_value) + if jitter_under_thresh_value + else ds.hist + ) + ds["hist"] = ( + _adapt_freq_hist(ds, adapt_freq_thresh) if adapt_freq_thresh else ds.hist + ) refn = u.apply_correction(ds.ref, u.invert(ds.ref.mean(dim), kind), kind) - histn = u.apply_correction(hist, u.invert(hist.mean(dim), kind), kind) + histn = u.apply_correction(ds.hist, u.invert(ds.hist.mean(dim), kind), kind) ref_q = nbu.quantile(refn, quantiles, dim) hist_q = nbu.quantile(histn, quantiles, dim) af = u.get_correction(hist_q, ref_q, kind) mu_ref = ds.ref.mean(dim) - mu_hist = hist.mean(dim) + mu_hist = ds.hist.mean(dim) scaling = u.get_correction(mu_hist, mu_ref, kind=kind) return xr.Dataset(data_vars=dict(af=af, hist_q=hist_q, scaling=scaling)) @@ -98,60 +110,376 @@ def eqm_train( kind: str, quantiles: np.ndarray, adapt_freq_thresh: str | None = None, + jitter_under_thresh_value: str | None = None, ) -> xr.Dataset: """EQM: Train step on one group. - Notes - ----- - Dataset variables: - ref : training target - hist : training data - Parameters ---------- ds : xr.Dataset - The dataset containing the training data. + Dataset variables: + ref : training target + hist : training data dim : str The dimension along which to compute the quantiles. kind : str The kind of correction to compute. See :py:func:`xclim.sdba.utils.get_correction`. quantiles : array-like The quantiles to compute. - adapt_freq_thresh : str | None + adapt_freq_thresh : str, optional Threshold for frequency adaptation. See :py:class:`xclim.sdba.processing.adapt_freq` for details. Default is None, meaning that frequency adaptation is not performed. + jitter_under_thresh_value : str, optional + Threshold under which to add uniform random noise to values, a quantity with units. + Default is None, meaning that jitter under thresh is not performed. Returns ------- xr.Dataset The dataset containing the adjustment factors and the quantiles over the training data. 
""" - hist = _adapt_freq_hist(ds, adapt_freq_thresh) if adapt_freq_thresh else ds.hist + ds["hist"] = ( + jitter_under_thresh(ds.hist, jitter_under_thresh_value) + if jitter_under_thresh_value + else ds.hist + ) + ds["hist"] = ( + _adapt_freq_hist(ds, adapt_freq_thresh) if adapt_freq_thresh else ds.hist + ) ref_q = nbu.quantile(ds.ref, quantiles, dim) - hist_q = nbu.quantile(hist, quantiles, dim) + hist_q = nbu.quantile(ds.hist, quantiles, dim) af = u.get_correction(hist_q, ref_q, kind) return xr.Dataset(data_vars=dict(af=af, hist_q=hist_q)) +def _npdft_train(ref, hist, rots, quantiles, method, extrap, n_escore, standardize): + r"""Npdf transform to correct a source `hist` into target `ref`. + + Perform a rotation, bias correct `hist` into `ref` with QuantileDeltaMapping, and rotate back. + Do this iteratively over all rotations `rots` and conserve adjustment factors `af_q` in each iteration. + + Notes + ----- + This function expects numpy inputs. The input arrays `ref,hist` are expected to be 2-dimensional arrays with shape: + `(len(nfeature), len(time))`, where `nfeature` is the dimension which is mixed by the multivariate bias adjustment + (e.g. a `multivar` dimension), i.e. `pts_dims[0]` in :py:func:`mbcn_train`. `rots` are rotation matrices with shape + `(len(iterations), len(nfeature), len(nfeature))`. + """ + if standardize: + ref = (ref - np.nanmean(ref, axis=-1, keepdims=True)) / ( + np.nanstd(ref, axis=-1, keepdims=True) + ) + hist = (hist - np.nanmean(hist, axis=-1, keepdims=True)) / ( + np.nanstd(hist, axis=-1, keepdims=True) + ) + af_q = np.zeros((len(rots), ref.shape[0], len(quantiles))) + escores = np.zeros(len(rots)) * np.NaN + if n_escore > 0: + ref_step, hist_step = ( + int(np.ceil(arr.shape[1] / n_escore)) for arr in [ref, hist] + ) + for ii in range(len(rots)): + rot = rots[0] if ii == 0 else rots[ii] @ rots[ii - 1].T + ref, hist = rot @ ref, rot @ hist + # loop over variables + for iv in range(ref.shape[0]): + ref_q, hist_q = nbu._quantile(ref[iv], quantiles), nbu._quantile( + hist[iv], quantiles + ) + af_q[ii, iv] = ref_q - hist_q + af = u._interp_on_quantiles_1D( + u._rank_bn(hist[iv]), + quantiles, + af_q[ii, iv], + method=method, + extrap=extrap, + ) + hist[iv] = hist[iv] + af + if n_escore > 0: + escores[ii] = nbu._escore(ref[:, ::ref_step], hist[:, ::hist_step]) + hist = rots[-1].T @ hist + return af_q, escores + + +def mbcn_train( + ds: xr.Dataset, + rot_matrices: xr.DataArray, + pts_dims: Sequence[str], + quantiles: np.ndarray, + gw_idxs: xr.DataArray, + interp: str, + extrapolation: str, + n_escore: int, +) -> xr.Dataset: + """Npdf transform training. + + Adjusting factors obtained for each rotation in the npdf transform and conserved to be applied in + the adjusting step in :py:func:`mcbn_adjust`. + + Parameters + ---------- + ds : xr.Dataset + Dataset variables: + ref : training target + hist : training data + rot_matrices : xr.DataArray + The rotation matrices as a 3D array ('iterations', , ), with shape (n_iter, , ). + pts_dims : sequence of str + The name of the "multivariate" dimension and its primed counterpart. Defaults to "multivar", which + is the normal case when using :py:func:`xclim.sdba.base.stack_variables`, and "multivar_prime", + quantiles : array-like + The quantiles to compute. + gw_idxs : xr.DataArray + Indices of the times in each windowed time group + interp : str + The interpolation method to use. + extrapolation : str + The extrapolation method to use. 
+ n_escore : int + Number of elements to include in the e_score test (0 for all, < 0 to skip) + + Returns + ------- + xr.Dataset + The dataset containing the adjustment factors and the quantiles over the training data + (only the npdf transform of mbcn). + """ + # unpack data + ref = ds.ref + hist = ds.hist + gr_dim = gw_idxs.attrs["group_dim"] + + # npdf training core + af_q_l = [] + escores_l = [] + + # loop over time blocks + for ib in range(gw_idxs[gr_dim].size): + # indices in a given time block + indices = gw_idxs[{gr_dim: ib}].fillna(-1).astype(int).values + ind = indices[indices >= 0] + + # npdft training : multiple rotations on standardized datasets + # keep track of adjustment factors in each rotation for later use + af_q, escores = xr.apply_ufunc( + _npdft_train, + ref[{"time": ind}], + hist[{"time": ind}], + rot_matrices, + quantiles, + input_core_dims=[ + [pts_dims[0], "time"], + [pts_dims[0], "time"], + ["iterations", pts_dims[1], pts_dims[0]], + ["quantiles"], + ], + output_core_dims=[ + ["iterations", pts_dims[1], "quantiles"], + ["iterations"], + ], + dask="parallelized", + output_dtypes=[hist.dtype, hist.dtype], + kwargs={ + "method": interp, + "extrap": extrapolation, + "n_escore": n_escore, + "standardize": True, + }, + vectorize=True, + ) + af_q_l.append(af_q.expand_dims({gr_dim: [ib]})) + escores_l.append(escores.expand_dims({gr_dim: [ib]})) + af_q = xr.concat(af_q_l, dim=gr_dim) + escores = xr.concat(escores_l, dim=gr_dim) + out = xr.Dataset(dict(af_q=af_q, escores=escores)).assign_coords( + {"quantiles": quantiles, gr_dim: gw_idxs[gr_dim].values} + ) + return out + + +def _npdft_adjust(sim, af_q, rots, quantiles, method, extrap): + """Npdf transform adjusting. + + Adjusting factors `af_q` obtained in the training step are applied on the simulated data `sim` at each iterated + rotation, see :py:func:`_npdft_train`. + + This function expects numpy inputs. `sim` can be a 2-d array with shape: `(len(nfeature), len(time))`, or + a 3-d array with shape: `(len(period), len(nfeature), len(time))`, allowing to adjust multiple climatological periods + all at once. `nfeature` is the dimension which is mixed by the multivariate bias adjustment + (e.g. a `multivar` dimension), i.e. `pts_dims[0]` in :py:func:`mbcn_train`. `rots` are rotation matrices with shape + `(len(iterations), len(nfeature), len(nfeature))`. + """ + # add dummy dim if period_dim absent to uniformize the function below + # This could be done at higher level, not sure where is best + if dummy_dim_added := (len(sim.shape) == 2): + sim = sim[:, np.newaxis, :] + + # adjust npdft + for ii in range(len(rots)): + rot = rots[0] if ii == 0 else rots[ii] @ rots[ii - 1].T + sim = np.einsum("ij,j...->i...", rot, sim) + # loop over variables + for iv in range(sim.shape[0]): + af = u._interp_on_quantiles_1D_multi( + u._rank_bn(sim[iv], axis=-1), + quantiles, + af_q[ii, iv], + method=method, + extrap=extrap, + ) + sim[iv] = sim[iv] + af + + rot = rots[-1].T + sim = np.einsum("ij,j...->i...", rot, sim) + if dummy_dim_added: + sim = sim[:, 0, :] + + return sim + + +def mbcn_adjust( + ref: xr.Dataset, + hist: xr.Dataset, + sim: xr.Dataset, + ds: xr.Dataset, + pts_dims: Sequence[str], + interp: str, + extrapolation: str, + base: Callable, + base_kws_vars: dict, + adj_kws: dict, + period_dim: str | None, +) -> xr.DataArray: + """Perform the adjustment portion MBCn multivariate bias correction technique. 
+ + The function :py:func:`mbcn_train` pre-computes the adjustment factors for each rotation + in the npdf portion of the MBCn algorithm. The rest of adjustment is performed here + in `mbcn_adjust``. + + Parameters + ---------- + ref : xr.DataArray + training target + hist : xr.DataArray + training data + sim : xr.DataArray + data to adjust (stacked with multivariate dimension) + ds : xr.Dataset + Dataset variables: + rot_matrices : Rotation matrices used in the training step. + af_q : Adjustment factors obtained in the training step for the npdf transform + g_idxs : Indices of the times in each time group + gw_idxs: Indices of the times in each windowed time group + pts_dims : [str, str] + The name of the "multivariate" dimension and its primed counterpart. Defaults to "multivar", which + is the normal case when using :py:func:`xclim.sdba.base.stack_variables`, and "multivar_prime" + interp : str + Interpolation method for the npdf transform (same as in the training step) + extrapolation : str + Extrapolation method for the npdf transform (same as in the training step) + base : BaseAdjustment + Bias-adjustment class used for the univariate bias correction. + base_kws_vars : Dict + Options for univariate training for the scenario that is reordered with the output of npdf transform. + The arguments are those expected by TrainAdjust classes along with + - kinds : Dict of correction kinds for each variable (e.g. {"pr":"*", "tasmax":"+"}) + adj_kws : Dict + Options for univariate adjust for the scenario that is reordered with the output of npdf transform + period_dim : str, optional + Name of the period dimension used when stacking time periods of `sim` using :py:func:`xclim.core.calendar.stack_periods`. + If specified, the interpolation of the npdf transform is performed only once and applied on all periods simultaneously. + This should be more performant, but also more memory intensive. Defaults to `None`: No optimization will be attempted. + + Returns + ------- + xr.Dataset + The adjusted data. + """ + # unpacking training parameters + rot_matrices = ds.rot_matrices + af_q = ds.af_q + quantiles = af_q.quantiles + g_idxs = ds.g_idxs + gw_idxs = ds.gw_idxs + gr_dim = gw_idxs.attrs["group_dim"] + win = gw_idxs.attrs["group"][1] + + # this way of handling was letting open the possibility to perform + # interpolation for multiple periods in the simulation all at once + # in principle, avoiding redundancy. Need to test this on small data + # to confirm it works, and on big data to check performance. + dims = ["time"] if period_dim is None else [period_dim, "time"] + + # mbcn core + scen_mbcn = xr.zeros_like(sim) + for ib in range(gw_idxs[gr_dim].size): + # indices in a given time block (with and without the window) + indices_gw = gw_idxs[{gr_dim: ib}].fillna(-1).astype(int).values + ind_gw = indices_gw[indices_gw >= 0] + indices_g = g_idxs[{gr_dim: ib}].fillna(-1).astype(int).values + ind_g = indices_g[indices_g >= 0] + + # 1. univariate adjustment of sim -> scen + # the kind may differ depending on the variables + scen_block = xr.zeros_like(sim[{"time": ind_gw}]) + for iv, v in enumerate(sim[pts_dims[0]].values): + sl = {"time": ind_gw, pts_dims[0]: iv} + with set_options(sdba_extra_output=False): + ADJ = base.train( + ref[sl], hist[sl], **base_kws_vars[v], skip_input_checks=True + ) + scen_block[{pts_dims[0]: iv}] = ADJ.adjust( + sim[sl], **adj_kws, skip_input_checks=True + ) + + # 2. 
npdft adjustment of sim + npdft_block = xr.apply_ufunc( + _npdft_adjust, + standardize(sim[{"time": ind_gw}].copy(), dim="time")[0], + af_q[{gr_dim: ib}], + rot_matrices, + quantiles, + input_core_dims=[ + [pts_dims[0]] + dims, + ["iterations", pts_dims[1], "quantiles"], + ["iterations", pts_dims[1], pts_dims[0]], + ["quantiles"], + ], + output_core_dims=[ + [pts_dims[0]] + dims, + ], + dask="parallelized", + output_dtypes=[sim.dtype], + kwargs={"method": interp, "extrap": extrapolation}, + vectorize=True, + ) + + # 3. reorder scen according to npdft results + reordered = reordering(ref=npdft_block, sim=scen_block) + if win > 1: + # keep central value of window (intersecting indices in gw_idxs and g_idxs) + scen_mbcn[{"time": ind_g}] = reordered[{"time": np.in1d(ind_gw, ind_g)}] + else: + scen_mbcn[{"time": ind_g}] = reordered + + return scen_mbcn.to_dataset(name="scen") + + @map_blocks(reduces=[Grouper.PROP, "quantiles"], scen=[]) def qm_adjust( ds: xr.Dataset, *, group: Grouper, interp: str, extrapolation: str, kind: str ) -> xr.Dataset: """QM (DQM and EQM): Adjust step on one block. - Notes - ----- - Dataset variables: - af : Adjustment factors - hist_q : Quantiles over the training data - sim : Data to adjust. - Parameters ---------- ds : xr.Dataset - The dataset containing the data to adjust. + Dataset variables: + af : Adjustment factors + hist_q : Quantiles over the training data + sim : Data to adjust. group : Grouper The grouper object. interp : str @@ -192,18 +520,14 @@ def dqm_adjust( ) -> xr.Dataset: """DQM adjustment on one block. - Notes - ----- - Dataset variables: - scaling : Scaling factor between ref and hist - af : Adjustment factors - hist_q : Quantiles over the training data - sim : Data to adjust - Parameters ---------- ds : xr.Dataset - The dataset containing the data to adjust. + Dataset variables: + scaling : Scaling factor between ref and hist + af : Adjustment factors + hist_q : Quantiles over the training data + sim : Data to adjust group : Grouper The grouper object. interp : str @@ -255,12 +579,13 @@ def dqm_adjust( def qdm_adjust(ds: xr.Dataset, *, group, interp, extrapolation, kind) -> xr.Dataset: """QDM: Adjust process on one block. - Notes - ----- - Dataset variables: - af : Adjustment factors - hist_q : Quantiles over the training data - sim : Data to adjust. + Parameters + ---------- + ds : xr.Dataset + Dataset variables: + af : Adjustment factors + hist_q : Quantiles over the training data + sim : Data to adjust. """ sim_q = group.apply(u.rank, ds.sim, main_only=True, pct=True) af = u.interp_on_quantiles( @@ -283,11 +608,12 @@ def qdm_adjust(ds: xr.Dataset, *, group, interp, extrapolation, kind) -> xr.Data def loci_train(ds: xr.Dataset, *, group, thresh) -> xr.Dataset: """LOCI: Train on one block. - Notes - ----- - Dataset variables: - ref : training target - hist : training data + Parameters + ---------- + ds : xr.Dataset + Dataset variables: + ref : training target + hist : training data """ s_thresh = group.apply( u.map_cdf, ds.rename(hist="x", ref="y"), y_value=thresh @@ -308,11 +634,12 @@ def loci_train(ds: xr.Dataset, *, group, thresh) -> xr.Dataset: def loci_adjust(ds: xr.Dataset, *, group, thresh, interp) -> xr.Dataset: """LOCI: Adjust on one block. 
- Notes - ----- - Dataset variables: - hist_thresh : Hist's equivalent thresh from ref - sim : Data to adjust + Parameters + ---------- + ds : xr.Dataset + Dataset variables: + hist_thresh : Hist's equivalent thresh from ref + sim : Data to adjust """ sth = u.broadcast(ds.hist_thresh, ds.sim, group=group, interp=interp) factor = u.broadcast(ds.af, ds.sim, group=group, interp=interp) @@ -328,11 +655,12 @@ def loci_adjust(ds: xr.Dataset, *, group, thresh, interp) -> xr.Dataset: def scaling_train(ds: xr.Dataset, *, dim, kind) -> xr.Dataset: """Scaling: Train on one group. - Notes - ----- - Dataset variables: - ref : training target - hist : training data + Parameters + ---------- + ds : xr.Dataset + Dataset variables: + ref : training target + hist : training data """ mhist = ds.hist.mean(dim) mref = ds.ref.mean(dim) @@ -345,11 +673,12 @@ def scaling_train(ds: xr.Dataset, *, dim, kind) -> xr.Dataset: def scaling_adjust(ds: xr.Dataset, *, group, interp, kind) -> xr.Dataset: """Scaling: Adjust on one block. - Notes - ----- - Dataset variables: - af : Adjustment factors. - sim : Data to adjust. + Parameters + ---------- + ds : xr.Dataset + Dataset variables: + af : Adjustment factors. + sim : Data to adjust. """ af = u.broadcast(ds.af, ds.sim, group=group, interp=interp) scen: xr.DataArray = u.apply_correction(ds.sim, af, kind).rename("scen") diff --git a/xclim/sdba/adjustment.py b/xclim/sdba/adjustment.py index a01fc81e5..632397f59 100644 --- a/xclim/sdba/adjustment.py +++ b/xclim/sdba/adjustment.py @@ -28,6 +28,8 @@ extremes_train, loci_adjust, loci_train, + mbcn_adjust, + mbcn_train, npdf_transform, qdm_adjust, qm_adjust, @@ -35,6 +37,7 @@ scaling_train, ) from .base import Grouper, ParametrizableWithDataset, parse_group +from .processing import grouped_time_indexes from .utils import ( ADDITIVE, best_pc_orientation_full, @@ -50,6 +53,7 @@ "DetrendedQuantileMapping", "EmpiricalQuantileMapping", "ExtremeValues", + "MBCn", "NpdfTransform", "PrincipalComponents", "QuantileDeltaMapping", @@ -121,13 +125,61 @@ def _check_inputs(cls, *inputs, group): ) @classmethod - def _harmonize_units(cls, *inputs, target: str | None = None): + def _harmonize_units(cls, *inputs, target: dict[str] | str | None = None): """Convert all inputs to the same units. If the target unit is not given, the units of the first input are used. Returns the converted inputs and the target units. """ + + def _harmonize_units_multivariate( + *inputs, dim, target: dict[str] | None = None + ): + def _convert_units_to(inda, dim, target): + varss = inda[dim].values + input_units = { + v: inda[dim].attrs["_units"][iv] for iv, v in enumerate(varss) + } + if input_units == target: + return inda + input_standard_names = { + v: inda[dim].attrs["_standard_name"][iv] + for iv, v in enumerate(varss) + } + for iv, v in enumerate(varss): + inda.attrs["units"] = input_units[v] + inda.attrs["standard_name"] = input_standard_names[v] + inda[{dim: iv}] = convert_units_to( + inda[{dim: iv}], target[v], context="infer" + ) + inda[dim].attrs["_units"][iv] = target[v] + inda.attrs["units"] = "" + inda.attrs.pop("standard_name") + return inda + + if target is None: + if "_units" not in inputs[0][dim].attrs or any( + [u is None for u in inputs[0][dim].attrs["_units"]] + ): + error_msg = ( + "Units are missing in some or all of the stacked variables." + "The dataset stacked with `stack_variables` given as input should include units for every variable." 
+ ) + raise ValueError(error_msg) + + target = { + v: inputs[0][dim].attrs["_units"][iv] + for iv, v in enumerate(inputs[0][dim].values) + } + return ( + _convert_units_to(inda, dim=dim, target=target) for inda in inputs + ), target + + for _dim, _crd in inputs[0].coords.items(): + if _crd.attrs.get("is_variables"): + return _harmonize_units_multivariate(*inputs, dim=_dim, target=target) + if target is None: target = inputs[0].units @@ -179,12 +231,10 @@ def train(cls, ref: DataArray, hist: DataArray, **kwargs) -> TrainAdjust: skip_checks = kwargs.pop("skip_input_checks", False) if not skip_checks: - (ref, hist), train_units = cls._harmonize_units(ref, hist) - if "group" in kwargs: cls._check_inputs(ref, hist, group=kwargs["group"]) - hist = convert_units_to(hist, ref) + (ref, hist), train_units = cls._harmonize_units(ref, hist) else: train_units = "" @@ -214,12 +264,11 @@ def adjust(self, sim: DataArray, *args, **kwargs): """ skip_checks = kwargs.pop("skip_input_checks", False) if not skip_checks: - (sim, *args), _ = self._harmonize_units(sim, *args, target=self.train_units) - if "group" in self: self._check_inputs(sim, *args, group=self.group) - sim = convert_units_to(sim, self.train_units) + (sim, *args), _ = self._harmonize_units(sim, *args, target=self.train_units) + out = self._adjust(sim, *args, **kwargs) if isinstance(out, xr.DataArray): @@ -236,7 +285,12 @@ def adjust(self, sim: DataArray, *args, **kwargs): infostr = f"{str(self)}.adjust(sim, {params})" scen.attrs["history"] = update_history(f"Bias-adjusted with {infostr}", sim) scen.attrs["bias_adjustment"] = infostr - scen.attrs["units"] = self.train_units + + _is_multivariate = any( + [_crd.attrs.get("is_variables") for _crd in sim.coords.values()] + ) + if _is_multivariate is False: + scen.attrs["units"] = self.train_units if OPTIONS[SDBA_EXTRA_OUTPUT]: return out @@ -311,7 +365,12 @@ def adjust( infostr = f"{cls.__name__}.adjust(ref, hist, sim, {params})" scen.attrs["history"] = update_history(f"Bias-adjusted with {infostr}", sim) scen.attrs["bias_adjustment"] = infostr - scen.attrs["units"] = ref.units + + _is_multivariate = any( + [_crd.attrs.get("is_variables") for _crd in sim.coords.values()] + ) + if _is_multivariate is False: + scen.attrs["units"] = ref.units if OPTIONS[SDBA_EXTRA_OUTPUT]: return out @@ -370,6 +429,7 @@ def _train( kind: str = ADDITIVE, group: str | Grouper = "time", adapt_freq_thresh: str | None = None, + jitter_under_thresh_value: str | None = None, ) -> tuple[xr.Dataset, dict[str, Any]]: if np.isscalar(nquantiles): quantiles = equally_spaced_nodes(nquantiles).astype(ref.dtype) @@ -382,6 +442,7 @@ def _train( kind=kind, quantiles=quantiles, adapt_freq_thresh=adapt_freq_thresh, + jitter_under_thresh_value=jitter_under_thresh_value, ) ds.af.attrs.update( @@ -468,6 +529,7 @@ def _train( kind: str = ADDITIVE, group: str | Grouper = "time", adapt_freq_thresh: str | None = None, + jitter_under_thresh_value: str | None = None, ): if group.prop not in ["group", "dayofyear"]: warn( @@ -485,6 +547,7 @@ def _train( quantiles=quantiles, kind=kind, adapt_freq_thresh=adapt_freq_thresh, + jitter_under_thresh_value=jitter_under_thresh_value, ) ds.af.attrs.update( @@ -1204,6 +1267,274 @@ def _adjust( return out +class MBCn(TrainAdjust): + r"""Multivariate bias correction function using the N-dimensional probability density function transform. 
+ +    A multivariate bias-adjustment algorithm described by :cite:t:`sdba-cannon_multivariate_2018` + based on a color-correction algorithm described by :cite:t:`sdba-pitie_n-dimensional_2005`. + + This algorithm in itself, when used with QuantileDeltaMapping, is NOT trend-preserving. + The full MBCn algorithm includes a reordering step provided here by :py:func:`xclim.sdba.processing.reordering`. + + See notes for an explanation of the algorithm. + + Attributes + ---------- + Train step + + ref : xr.DataArray + Reference dataset. + hist : xr.DataArray + Historical dataset. + base_kws : dict, optional + Arguments passed to the training step of the npdf transform. + adj_kws : dict, optional + Arguments passed to the adjustment step of the npdf transform. + n_escore : int + The number of elements to send to the escore function. Pass 0 to include all elements and -1 (the default) + to skip computing the escore completely. + Small numbers result in less significant scores, but the execution time goes up quickly with large values. + n_iter : int + The number of iterations to perform. Defaults to 20. + pts_dim : str + The name of the "multivariate" dimension. Defaults to "multivar", which is the + normal case when using :py:func:`xclim.sdba.processing.stack_variables`. + rot_matrices : xr.DataArray, optional + The rotation matrices as a 3D array ('iterations', <pts_dim>, <pts_dim>_prime), with shape (n_iter, N, N), + where N is the number of variables. + If left empty, random rotation matrices will be automatically generated. + + Adjust step + + ref : xr.DataArray + Target reference dataset, also needed for the univariate bias correction preceding the npdf transform. + hist : xr.DataArray + Source historical dataset, also needed for the univariate bias correction preceding the npdf transform. + sim : xr.DataArray + Source dataset to adjust. + base : BaseAdjustment + Bias-adjustment class used for the univariate bias correction. + base_kws_vars : dict, optional + Arguments passed, per variable, to the training of the univariate bias correction. + adj_kws : dict, optional + Arguments passed to the adjustment step of the univariate bias correction. + period_dim : str, optional + Name of the period dimension used when stacking time periods of `sim` using :py:func:`xclim.core.calendar.stack_periods`. + If specified, the interpolation of the npdf transform is performed only once and applied on all periods simultaneously. + This should be more performant, but also more memory intensive. + + Notes + ----- + The historical reference (:math:`T`, for "target"), simulated historical (:math:`H`) and simulated projected (:math:`S`) + datasets are constructed by stacking the timeseries of N variables together. The algorithm is broken into the + following steps: + + Training (only npdf transform training) + + 1. Standardize `ref` and `hist` (see ``xclim.sdba.processing.standardize``). + + 2. Rotate the datasets in the N-dimensional variable space with :math:`\mathbf{R}`, a random rotation NxN matrix. + + .. math:: + + \tilde{\mathbf{T}} = \mathbf{T}\mathbf{R} \\ + \tilde{\mathbf{H}} = \mathbf{H}\mathbf{R} + + 3. QuantileDeltaMapping is used to perform bias adjustment :math:`\mathcal{F}` on the rotated datasets. + The adjustment factor is conserved for later use in the adjusting step. The adjustments are made in additive mode, + for each variable :math:`i`. + + .. math:: + + \hat{\mathbf{H}}_i, \hat{\mathbf{S}}_i = \mathcal{F}\left(\tilde{\mathbf{T}}_i, \tilde{\mathbf{H}}_i, \tilde{\mathbf{S}}_i\right) + + 4. The bias-adjusted datasets are rotated back. + + .. math:: + + \mathbf{H}' = \hat{\mathbf{H}}\mathbf{R} \\ + \mathbf{S}' = \hat{\mathbf{S}}\mathbf{R} + + 5. Repeat steps 2-4 ``n_iter`` times, i.e. the number of randomly generated rotation matrices. + + Adjusting + + 1. Perform the same steps as in training, with `ref, hist` replaced with `sim`. Step 3 of the training is modified: here we + simply reuse the adjustment factors previously found in the training step to bias correct the standardized `sim` directly. + + 2. Using the original (unstandardized) `ref, hist, sim`, perform a univariate bias adjustment using the ``base`` class + on `sim`. + + 3. Reorder the dataset found in step 2 according to the ranks of the dataset found in step 1. + + The original algorithm :cite:p:`sdba-pitie_n-dimensional_2005` stops the iteration when some distance score converges. + Following :cite:t:`sdba-cannon_multivariate_2018` and the MBCn implementation in :cite:t:`sdba-cannon_mbc_2020`, we + instead fix the number of iterations. + + As done by :cite:t:`sdba-cannon_multivariate_2018`, the distance score chosen is the "Energy distance" from + :cite:t:`sdba-szekely_testing_2004` (see :py:func:`xclim.sdba.processing.escore`). + + The random matrices are generated following a method laid out by :cite:t:`sdba-mezzadri_how_2007`. + + Only "time" and "time.dayofyear" (with a suitable window) are implemented as possible values for `group`. + + References + ---------- + :cite:cts:`sdba-cannon_multivariate_2018,sdba-cannon_mbc_2020,sdba-pitie_n-dimensional_2005,sdba-mezzadri_how_2007,sdba-szekely_testing_2004` + """ + + @classmethod + def _train( + cls, + ref: xr.DataArray, + hist: xr.DataArray, + *, + base_kws: dict[str, Any] | None = None, + adj_kws: dict[str, Any] | None = None, + n_escore: int = -1, + n_iter: int = 20, + pts_dim: str = "multivar", + rot_matrices: xr.DataArray | None = None, + ): + # set default values for non-specified parameters + base_kws = base_kws if base_kws is not None else {} + adj_kws = adj_kws if adj_kws is not None else {} + base_kws.setdefault("nquantiles", 20) + base_kws.setdefault("group", Grouper("time", 1)) + adj_kws.setdefault("interp", "nearest") + adj_kws.setdefault("extrapolation", "constant") + + if np.isscalar(base_kws["nquantiles"]): + base_kws["nquantiles"] = equally_spaced_nodes(base_kws["nquantiles"]) + if isinstance(base_kws["group"], str): + base_kws["group"] = Grouper(base_kws["group"], 1) + if base_kws["group"].name == "time.month": + raise NotImplementedError( + "Received `group==time.month` in `base_kws`. Monthly grouping is not currently supported in the MBCn class." + ) + # prepare the rotation matrices + if rot_matrices is not None: + if pts_dim != rot_matrices.attrs["crd_dim"]: + raise ValueError( + f"`crd_dim` attribute of `rot_matrices` ({rot_matrices.attrs['crd_dim']}) does not correspond to `pts_dim` ({pts_dim})."
+ ) + else: + rot_dim = xr.core.utils.get_temp_dimname( + set(ref.dims).union(hist.dims), pts_dim + "_prime" + ) + rot_matrices = rand_rot_matrix( + ref[pts_dim], num=n_iter, new_dim=rot_dim + ).rename(matrices="iterations") + pts_dims = [rot_matrices.attrs[d] for d in ["crd_dim", "new_dim"]] + + # time indices corresponding to group and windowed group + # used to divide datasets as map_blocks or groupby would do + _, gw_idxs = grouped_time_indexes(ref.time, base_kws["group"]) + + # training, obtain adjustment factors of the npdf transform + ds = xr.Dataset(dict(ref=ref, hist=hist)) + params = { + "quantiles": base_kws["nquantiles"], + "interp": adj_kws["interp"], + "extrapolation": adj_kws["extrapolation"], + "pts_dims": pts_dims, + "n_escore": n_escore, + } + out = mbcn_train(ds, rot_matrices=rot_matrices, gw_idxs=gw_idxs, **params) + params["group"] = base_kws["group"] + + # postprocess + out["rot_matrices"] = rot_matrices + + out.af_q.attrs.update( + standard_name="Adjustment factors", + long_name="Quantile mapping adjustment factors", + ) + return out, params + + def _adjust( + self, + sim: xr.DataArray, + ref: xr.DataArray, + hist: xr.DataArray, + *, + base: TrainAdjust = QuantileDeltaMapping, + base_kws_vars: dict[str, Any] | None = None, + adj_kws: dict[str, Any] | None = None, + period_dim=None, + ): + # set default values for non-specified parameters + base_kws_vars = base_kws_vars or {} + pts_dim = self.pts_dims[0] + for v in sim[pts_dim].values: + base_kws_vars.setdefault(v, {}) + base_kws_vars[v].setdefault("group", self.group) + if isinstance(base_kws_vars[v]["group"], str): + base_kws_vars[v]["group"] = Grouper(base_kws_vars[v]["group"], 1) + if base_kws_vars[v]["group"] != self.group: + raise ValueError( + f"`group` input in _train and _adjust must be the same." 
+ f"Got {self.group} and {base_kws_vars[v]['group']}" + ) + base_kws_vars[v].pop("group") + + base_kws_vars[v].setdefault("nquantiles", self.ds.af_q.quantiles.values) + if np.isscalar(base_kws_vars[v]["nquantiles"]): + base_kws_vars[v]["nquantiles"] = equally_spaced_nodes( + base_kws_vars[v]["nquantiles"] + ) + if "is_variables" in sim[pts_dim].attrs: + if self.train_units == "": + _, units = self._harmonize_units(sim) + else: + units = self.train_units + + if "jitter_under_thresh_value" in base_kws_vars[v]: + base_kws_vars[v]["jitter_under_thresh_value"] = str( + convert_units_to( + base_kws_vars[v]["jitter_under_thresh_value"], + units[v], + context="hydro", + ) + ) + if "adapt_freq_thresh" in base_kws_vars[v]: + base_kws_vars[v]["adapt_freq_thresh"] = str( + convert_units_to( + base_kws_vars[v]["adapt_freq_thresh"], + units[v], + context="hydro", + ) + ) + + adj_kws = adj_kws or {} + adj_kws.setdefault("interp", self.interp) + adj_kws.setdefault("extrapolation", self.extrapolation) + + g_idxs, gw_idxs = grouped_time_indexes(ref.time, self.group) + ds = self.ds.copy() + ds["g_idxs"] = g_idxs + ds["gw_idxs"] = gw_idxs + + # adjust (adjust for npft transform, train/adjust for univariate bias correction) + out = mbcn_adjust( + ref=ref, + hist=hist, + sim=sim, + ds=ds, + pts_dims=self.pts_dims, + interp=self.interp, + extrapolation=self.extrapolation, + base=base, + base_kws_vars=base_kws_vars, + adj_kws=adj_kws, + period_dim=period_dim, + ) + + return out + + try: import SBCK except ImportError: # noqa: S110 diff --git a/xclim/sdba/nbutils.py b/xclim/sdba/nbutils.py index 29a1ae9ad..fe942b11d 100644 --- a/xclim/sdba/nbutils.py +++ b/xclim/sdba/nbutils.py @@ -208,7 +208,8 @@ def _wrapper_quantile1d(arr, q): return out -def _quantile(arr, q, nreduce): +def _quantile(arr, q, nreduce=None): + nreduce = nreduce or arr.ndim if arr.ndim == nreduce: out = _nan_quantile_1d(arr.flatten(), q) else: diff --git a/xclim/sdba/processing.py b/xclim/sdba/processing.py index ecd2128f6..d7d023d5d 100644 --- a/xclim/sdba/processing.py +++ b/xclim/sdba/processing.py @@ -28,6 +28,7 @@ "adapt_freq", "escore", "from_additive_space", + "grouped_time_indexes", "jitter", "jitter_over_thresh", "jitter_under_thresh", @@ -822,3 +823,65 @@ def unstack_variables(da: xr.DataArray, dim: str | None = None) -> xr.Dataset: ds[var.item()].attrs[name[1:]] = attr return ds + + +def grouped_time_indexes(times, group): + """Time indexes for every group blocks + + Time indexes can be used to implement a pseudo-"numpy.groupies" approach to grouping. + + Parameters + ---------- + times : xr.DataArray + Time dimension in the dataset of interest. + group : str or Grouper + Grouping information, see base.Grouper + + Returns + ------- + g_idxs : xr.DataArray + Time indexes of the blocks (only using `group.name` and not `group.window`). + gw_idxs : xr.DataArray + Time indexes of the blocks (built with a rolling window of `group.window` if any). + """ + + def _get_group_complement(da, group): + # complement of "dayofyear": "year", etc. + gr = group if isinstance(group, str) else group.name + if gr == "time.dayofyear": + return da.time.dt.year + if gr == "time.month": + return da.time.dt.strftime("%Y-%d") + + # does not work with group == "time.month" + group = group if isinstance(group, Grouper) else Grouper(group) + gr, win = group.name, group.window + # get time indices (0,1,2,...) 
for each block + timeind = xr.DataArray(np.arange(times.size), coords={"time": times}) + win_dim0, win_dim = ( + get_temp_dimname(timeind.dims, lab) for lab in ["win_dim0", "win_dim"] + ) + if gr == "time.dayofyear": + # time indices for each block with window = 1 + g_idxs = timeind.groupby(gr).apply( + lambda da: da.assign_coords(time=_get_group_complement(da, gr)).rename( + {"time": "year"} + ) + ) + # time indices for each block with general window + da = timeind.rolling(time=win, center=True).construct(window_dim=win_dim0) + gw_idxs = da.groupby(gr).apply( + lambda da: da.assign_coords(time=_get_group_complement(da, gr)).stack( + {win_dim: ["time", win_dim0]} + ) + ) + gw_idxs = gw_idxs.transpose(..., win_dim) + elif gr == "time": + gw_idxs = timeind.rename({"time": win_dim}).expand_dims({win_dim0: [-1]}) + g_idxs = gw_idxs.copy() + else: + raise NotImplementedError(f"Grouping {gr} not implemented.") + gw_idxs.attrs["group"] = (gr, win) + gw_idxs.attrs["time_dim"] = win_dim + gw_idxs.attrs["group_dim"] = [d for d in g_idxs.dims if d != win_dim][0] + return g_idxs, gw_idxs diff --git a/xclim/sdba/utils.py b/xclim/sdba/utils.py index 89f451fba..512e7e8b4 100644 --- a/xclim/sdba/utils.py +++ b/xclim/sdba/utils.py @@ -9,6 +9,7 @@ from typing import Callable from warnings import warn +import bottleneck as bn import numpy as np import xarray as xr from boltons.funcutils import wraps @@ -310,6 +311,39 @@ def add_cyclic_bounds( return ensure_chunk_size(qmf, **{att: -1}) +def _interp_on_quantiles_1D_multi(newxs, oldx, oldy, method, extrap): # noqa: N802 + # Perform multiple interpolations with a single call of interp1d. + # This should be used when `oldx` is common for many data arrays (`newxs`) + # that we want to interpolate on. For instance, with QuantileDeltaMapping, we simply + # interpolate on quantiles that always remain the same. + if len(newxs.shape) == 1: + return _interp_on_quantiles_1D(newxs, oldx, oldy, method, extrap) + mask_old = np.isnan(oldy) | np.isnan(oldx) + if extrap == "constant": + fill_value = ( + oldy[~np.isnan(oldy)][0], + oldy[~np.isnan(oldy)][-1], + ) + else: # extrap == 'nan' + fill_value = np.NaN + + finterp1d = interp1d( + oldx[~mask_old], + oldy[~mask_old], + kind=method, + bounds_error=False, + fill_value=fill_value, + ) + + out = np.zeros_like(newxs) + for ii in range(newxs.shape[0]): + mask_new = np.isnan(newxs[ii, :]) + y1 = newxs[ii, :].copy() * np.NaN + y1[~mask_new] = finterp1d(newxs[ii, ~mask_new]) + out[ii, :] = y1.flatten() + return out + + def _interp_on_quantiles_1D(newx, oldx, oldy, method, extrap): # noqa: N802 mask_new = np.isnan(newx) mask_old = np.isnan(oldy) | np.isnan(oldx) @@ -328,7 +362,6 @@ def _interp_on_quantiles_1D(newx, oldx, oldy, method, extrap): # noqa: N802 ) else: # extrap == 'nan' fill_value = np.NaN - out[~mask_new] = interp1d( oldx[~mask_old], oldy[~mask_old], @@ -532,6 +565,14 @@ def rank( return rnk +def _rank_bn(arr, axis=None): + """Ranking on a specific axis""" + rnk = bn.nanrankdata(arr, axis=axis) + rnk = rnk / np.nanmax(rnk, axis=axis, keepdims=True) + mx, mn = 1, np.nanmin(rnk, axis=axis, keepdims=True) + return mx * (rnk - mn) / (mx - mn) + + def pc_matrix(arr: np.ndarray | dsk.Array) -> np.ndarray | dsk.Array: """Construct a Principal Component matrix. 
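A small usage sketch for the new `grouped_time_indexes` helper above; the calendar, period and window are illustrative, and the window dimension names of the outputs are chosen internally by the function:

# Sketch: indices returned by grouped_time_indexes for a day-of-year grouping.
import numpy as np
import xarray as xr

from xclim.sdba.base import Grouper
from xclim.sdba.processing import grouped_time_indexes

time = xr.cftime_range("1981-01-01", "2010-12-31", freq="D", calendar="noleap")
da = xr.DataArray(np.arange(time.size), dims=("time",), coords={"time": time})

g_idxs, gw_idxs = grouped_time_indexes(da.time, Grouper("time.dayofyear", 31))
# g_idxs  : for each day-of-year block, the integer positions of that exact day in every year
# gw_idxs : the same blocks, built from the 31-day rolling window around each day-of-year
print(gw_idxs.attrs["group"], gw_idxs.attrs["group_dim"])  # ("time.dayofyear", 31) and the block dimension name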
@@ -856,9 +897,11 @@ def rand_rot_matrix( num = np.diag(R) denum = np.abs(num) lam = np.diag(num / denum) # "lambda" - return xr.DataArray( - Q @ lam, dims=(dim, new_dim), coords={dim: crd, new_dim: crd2} - ).astype("float32") + return ( + xr.DataArray(Q @ lam, dims=(dim, new_dim), coords={dim: crd, new_dim: crd2}) + .astype("float32") + .assign_attrs({"crd_dim": dim, "new_dim": new_dim}) + ) def copy_all_attrs(ds: xr.Dataset | xr.DataArray, ref: xr.Dataset | xr.DataArray):
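Because `rand_rot_matrix` now records `crd_dim` and `new_dim` as attributes of its output, rotation matrices can be pre-computed and handed to `MBCn.train`. A short sketch, assuming `ref` and `hist` are DataArrays stacked along a "multivar" dimension as in the notebook above:

# Sketch: supplying pre-computed rotation matrices to MBCn.train.
from xclim.sdba.adjustment import MBCn
from xclim.sdba.utils import rand_rot_matrix

rot = rand_rot_matrix(ref["multivar"], num=20, new_dim="multivar_prime").rename(
    matrices="iterations"
)
print(rot.attrs["crd_dim"], rot.attrs["new_dim"])  # "multivar", "multivar_prime"

# MBCn.train checks rot.attrs["crd_dim"] against its `pts_dim` argument (default "multivar").
ADJ = MBCn.train(ref, hist, rot_matrices=rot)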