-
Notifications
You must be signed in to change notification settings - Fork 14
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add fit_curve and predict_curve (#139)
* add prototype for fit_curve * add proper parameter parsing * use partial function in test * add assert to test * cleanups in fit_curve * minor edits * progress with parameter passing * revert changes to process decorator * add to tests * progress on predict * get predict to work * remove comment * fix up output datacube * add assertions * preserve dimension order * add cast to datetime if appropriate * keep attrs * fix typo * add ignore_nodata * bump submodule
- Loading branch information
1 parent
c4aacb2
commit 3a55003
Showing
4 changed files
with
188 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,2 @@ | ||
from .curve_fitting import * | ||
from .random_forest import * |
125 changes: 125 additions & 0 deletions
125
openeo_processes_dask/process_implementations/ml/curve_fitting.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
from typing import Callable, Optional | ||
|
||
import numpy as np | ||
import pandas as pd | ||
import xarray as xr | ||
from numpy.typing import ArrayLike | ||
|
||
from openeo_processes_dask.process_implementations.cubes import apply_dimension | ||
from openeo_processes_dask.process_implementations.data_model import RasterCube | ||
from openeo_processes_dask.process_implementations.exceptions import ( | ||
DimensionNotAvailable, | ||
) | ||
|
||
__all__ = ["fit_curve", "predict_curve"] | ||
|
||
|
||
def fit_curve(
    data: RasterCube,
    parameters: list,
    function: Callable,
    dimension: str,
    ignore_nodata: bool = True,
):
    """Fit the model ``function`` to ``data`` along ``dimension``.

    Parameters
    ----------
    data : RasterCube
        Cube to fit curves over; must contain ``dimension``.
    parameters : list
        Initial guesses for the curve parameters, in positional order.
    function : Callable
        The model callable, invoked through the ``@process`` argument convention.
    dimension : str
        Name of the dimension along which each curve is fitted.
    ignore_nodata : bool
        When True (default), missing values are skipped during fitting.

    Returns
    -------
    RasterCube with ``dimension`` replaced by a ``param`` dimension holding
    the fitted coefficients; input attrs are preserved.

    Raises
    ------
    DimensionNotAvailable
        If ``dimension`` is not present in ``data.dims``.
    """
    if dimension not in data.dims:
        raise DimensionNotAvailable(
            f"Provided dimension ({dimension}) not found in data.dims: {data.dims}"
        )

    original_dim_order = list(data.dims)

    # The OpenEO spec supplies a plain list of initial guesses, while
    # xarray's curvefit wants them keyed by name — synthesise names locally.
    named_guesses = {f"param_{i}": v for i, v in enumerate(parameters)}

    # curvefit cannot fit across chunk boundaries along the fit dimension,
    # so collapse that dimension into a single chunk first.
    contiguous = data.chunk({dimension: -1})

    def _inject_positional(inner):
        # Route arguments through the @process decorator convention:
        # positional arg 0 is x, the remaining positions are the parameters.
        def routed(*args, **kwargs):
            return inner(
                *args,
                **kwargs,
                positional_parameters={"x": 0, "parameters": slice(1, None)},
            )

        return routed

    target_dim_order = list(original_dim_order)
    target_dim_order[original_dim_order.index(dimension)] = "param"

    fitted = contiguous.curvefit(
        dimension,
        _inject_positional(function),
        p0=named_guesses,
        param_names=list(named_guesses.keys()),
        skipna=ignore_nodata,
    )
    # curvefit also reports covariance terms, which the OpenEO process does
    # not expose — drop them before reshaping.
    fitted = fitted.drop_dims(["cov_i", "cov_j"])
    fitted = fitted.to_array().squeeze().transpose(*target_dim_order)

    # Carry the source cube's metadata over to the result.
    fitted.attrs = data.attrs

    return fitted
|
||
|
||
def predict_curve(
    parameters: RasterCube,
    function: Callable,
    dimension: str,
    labels: ArrayLike,
):
    """Evaluate a fitted model at the given ``labels`` along ``dimension``.

    Parameters
    ----------
    parameters : RasterCube
        Cube of fitted coefficients with a ``param`` dimension
        (as produced by ``fit_curve``).
    function : Callable
        The model callable, invoked through the ``@process`` argument convention.
    dimension : str
        Name of the output dimension that replaces ``param``.
    labels : ArrayLike
        Points at which to evaluate the model; datetime-like labels are
        converted to integers for evaluation and restored as a
        ``DatetimeIndex`` coordinate on the result.

    Returns
    -------
    RasterCube of float64 predictions with ``param`` replaced by ``dimension``.
    """
    original_dim_order = list(parameters.dims)

    # Prefer a datetime interpretation of the labels; fall back to
    # whatever dtype numpy infers when that parse fails.
    try:
        labels = np.asarray(labels, dtype=np.datetime64)
    except ValueError:
        labels = np.asarray(labels)

    restore_datetime_index = np.issubdtype(labels.dtype, np.datetime64)
    if restore_datetime_index:
        # The model is evaluated on plain numbers, so cast to int for now.
        labels = labels.astype(int)

    def _inject_arguments(inner):
        # Needed to pipe the arguments correctly through @process:
        # position 0 carries the parameters, x is passed by name.
        def routed(*args, **kwargs):
            return inner(
                *args,
                positional_parameters={"parameters": 0},
                named_parameters={"x": labels},
                **kwargs,
            )

        return routed

    target_dim_order = list(original_dim_order)
    target_dim_order[original_dim_order.index("param")] = dimension

    predictions = xr.apply_ufunc(
        _inject_arguments(function),
        parameters,
        vectorize=True,
        input_core_dims=[["param"]],
        output_core_dims=[[dimension]],
        dask="parallelized",
        output_dtypes=[np.float64],
        dask_gufunc_kwargs={
            "allow_rechunk": True,
            "output_sizes": {dimension: len(labels)},
        },
    )
    predictions = predictions.transpose(*target_dim_order)

    # NOTE(review): labels is an ndarray here, so .data is its buffer
    # (memoryview); xarray coerces it back to an array — confirm this is
    # intentional rather than a leftover from an xarray-typed `labels`.
    predictions = predictions.assign_coords({dimension: labels.data})

    if restore_datetime_index:
        # Reinstate real timestamps on the output coordinate.
        predictions[dimension] = pd.DatetimeIndex(predictions[dimension].values)

    return predictions
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters