diff --git a/xarray/__init__.py b/xarray/__init__.py index e3b7ec469e9..b49ab1848b7 100644 --- a/xarray/__init__.py +++ b/xarray/__init__.py @@ -30,6 +30,7 @@ where, ) from xarray.core.concat import concat +from xarray.core.coordinate_transform import CoordinateTransform from xarray.core.coordinates import Coordinates from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset @@ -109,6 +110,7 @@ "CFTimeIndex", "Context", "Coordinates", + "CoordinateTransform", "DataArray", "Dataset", "DataTree", diff --git a/xarray/core/coordinate_transform.py b/xarray/core/coordinate_transform.py new file mode 100644 index 00000000000..40043da46bc --- /dev/null +++ b/xarray/core/coordinate_transform.py @@ -0,0 +1,78 @@ +from collections.abc import Hashable, Iterable, Mapping +from typing import Any + +import numpy as np + + +class CoordinateTransform: + """Abstract coordinate transform with dimension & coordinate names.""" + + coord_names: tuple[Hashable, ...] + dims: tuple[str, ...] + dim_size: dict[str, int] + dtype: Any + + def __init__( + self, + coord_names: Iterable[Hashable], + dim_size: Mapping[str, int], + dtype: Any = None, + ): + self.coord_names = tuple(coord_names) + self.dims = tuple(dim_size) + self.dim_size = dict(dim_size) + + if dtype is None: + dtype = np.dtype(np.float64) + self.dtype = dtype + + def forward(self, dim_positions: dict[str, Any]) -> dict[Hashable, Any]: + """Perform grid -> world coordinate transformation. + + Parameters + ---------- + dim_positions : dict + Grid location(s) along each dimension (axis). + + Returns + ------- + coord_labels : dict + World coordinate labels. + + """ + # TODO: cache the results in order to avoid re-computing + # all labels when accessing the values of each coordinate one at a time + raise NotImplementedError + + def reverse(self, coord_labels: dict[Hashable, Any]) -> dict[str, Any]: + """Perform world -> grid coordinate reverse transformation. + + Parameters + ---------- + labels : dict + World coordinate labels. + + Returns + ------- + dim_positions : dict + Grid relative location(s) along each dimension (axis). + + """ + raise NotImplementedError + + def equals(self, other: "CoordinateTransform") -> bool: + """Check equality with another CoordinateTransform of the same kind.""" + raise NotImplementedError + + def generate_coords(self, dims: tuple[str] | None = None) -> dict[Hashable, Any]: + """Compute all coordinate labels at once.""" + if dims is None: + dims = self.dims + + positions = np.meshgrid( + *[np.arange(self.dim_size[d]) for d in dims], + indexing="ij", + ) + dim_positions = {dim: positions[i] for i, dim in enumerate(dims)} + + return self.forward(dim_positions) diff --git a/xarray/core/coordinates.py b/xarray/core/coordinates.py index a6dec863aec..af622aaca8b 100644 --- a/xarray/core/coordinates.py +++ b/xarray/core/coordinates.py @@ -14,7 +14,9 @@ from xarray.core import formatting from xarray.core.alignment import Aligner +from xarray.core.coordinate_transform import CoordinateTransform from xarray.core.indexes import ( + CoordinateTransformIndex, Index, Indexes, PandasIndex, @@ -356,7 +358,7 @@ def _construct_direct( def from_pandas_multiindex(cls, midx: pd.MultiIndex, dim: Hashable) -> Self: """Wrap a pandas multi-index as Xarray coordinates (dimension + levels). - The returned coordinates can be directly assigned to a + The returned coordinate variables can be directly assigned to a :py:class:`~xarray.Dataset` or :py:class:`~xarray.DataArray` via the ``coords`` argument of their constructor. @@ -380,6 +382,28 @@ def from_pandas_multiindex(cls, midx: pd.MultiIndex, dim: Hashable) -> Self: return cls(coords=variables, indexes=indexes) + @classmethod + def from_transform(cls, transform: CoordinateTransform) -> Self: + """Wrap a coordinate transform as Xarray (lazy) coordinates. + + The returned coordinate variables can be directly assigned to a + :py:class:`~xarray.Dataset` or :py:class:`~xarray.DataArray` via the + ``coords`` argument of their constructor. + + Parameters + ---------- + transform : :py:class:`CoordinateTransform` + Xarray coordinate transform object. + + Returns + ------- + coords : Coordinates + A collection of Xarray indexed coordinates created from the transform. + + """ + index = CoordinateTransformIndex(transform) + return index.create_coords() + @property def _names(self) -> set[Hashable]: return self._data._coord_names diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index 5abc2129e3e..4f2bba20844 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -10,7 +10,9 @@ import pandas as pd from xarray.core import formatting, nputils, utils +from xarray.core.coordinate_transform import CoordinateTransform from xarray.core.indexing import ( + CoordinateTransformIndexingAdapter, IndexSelResult, PandasIndexingAdapter, PandasMultiIndexingAdapter, @@ -24,6 +26,7 @@ ) if TYPE_CHECKING: + from xarray.core.coordinate import Coordinates from xarray.core.types import ErrorOptions, JoinOptions, Self from xarray.core.variable import Variable @@ -1372,6 +1375,137 @@ def rename(self, name_dict, dims_dict): ) +class CoordinateTransformIndex(Index): + """Helper class for creating Xarray indexes based on coordinate transforms. + + - wraps a :py:class:`CoordinateTransform` instance + - takes care of creating the index (lazy) coordinates + - supports point-wise label-based selection + - supports exact alignment only, by comparing indexes based on their transform + (not on their explicit coordinate labels) + + """ + + transform: CoordinateTransform + + def __init__( + self, + transform: CoordinateTransform, + ): + self.transform = transform + + def create_variables( + self, variables: Mapping[Any, Variable] | None = None + ) -> IndexVars: + from xarray.core.variable import Variable + + new_variables = {} + + for name in self.transform.coord_names: + # copy attributes, if any + attrs: Mapping[Hashable, Any] | None + + if variables is not None and name in variables: + var = variables[name] + attrs = var.attrs + else: + attrs = None + + data = CoordinateTransformIndexingAdapter(self.transform, name) + new_variables[name] = Variable(self.transform.dims, data, attrs=attrs) + + return new_variables + + def create_coordinates(self) -> Coordinates: + # TODO: remove this alias before merging https://github.com/pydata/xarray/pull/9543! + # (we keep it there so it doesn't break the code of those who are experimenting with this) + return self.create_coords() + + def create_coords(self) -> Coordinates: + # TODO: move this in xarray.Index base class? + from xarray.core.coordinates import Coordinates + + variables = self.create_variables() + indexes = {name: self for name in variables} + return Coordinates(coords=variables, indexes=indexes) + + def isel( + self, indexers: Mapping[Any, int | slice | np.ndarray | Variable] + ) -> Self | None: + # TODO: support returning a new index (e.g., possible to re-calculate the + # the transform or calculate another transform on a reduced dimension space) + return None + + def sel( + self, labels: dict[Any, Any], method=None, tolerance=None + ) -> IndexSelResult: + from xarray.core.dataarray import DataArray + from xarray.core.variable import Variable + + if method != "nearest": + raise ValueError( + "CoordinateTransformIndex only supports selection with method='nearest'" + ) + + labels_set = set(labels) + coord_names_set = set(self.transform.coord_names) + + missing_labels = coord_names_set - labels_set + if missing_labels: + missing_labels_str = ",".join([f"{name}" for name in missing_labels]) + raise ValueError(f"missing labels for coordinate(s): {missing_labels_str}.") + + label0_obj = next(iter(labels.values())) + dim_size0 = getattr(label0_obj, "sizes", None) + + is_xr_obj = [ + isinstance(label, DataArray | Variable) for label in labels.values() + ] + if not all(is_xr_obj): + raise TypeError( + "CoordinateTransformIndex only supports advanced (point-wise) indexing " + "with either xarray.DataArray or xarray.Variable objects." + ) + dim_size = [getattr(label, "sizes", None) for label in labels.values()] + if any([ds != dim_size0 for ds in dim_size]): + raise ValueError( + "CoordinateTransformIndex only supports advanced (point-wise) indexing " + "with xarray.DataArray or xarray.Variable objects of macthing dimensions." + ) + + coord_labels = { + name: labels[name].values for name in self.transform.coord_names + } + dim_positions = self.transform.reverse(coord_labels) + + results = {} + dims0 = tuple(dim_size0) + for dim, pos in dim_positions.items(): + # TODO: rounding the decimal positions is not always the behavior we expect + # (there are different ways to represent implicit intervals) + # we should probably make this customizable. + pos = np.round(pos).astype("int") + if isinstance(label0_obj, Variable): + xr_pos = Variable(dims0, pos) + else: + # dataarray + xr_pos = DataArray(pos, dims=dims0) + results[dim] = xr_pos + + return IndexSelResult(results) + + def equals(self, other: Self) -> bool: + return self.transform.equals(other.transform) + + def rename( + self, + name_dict: Mapping[Any, Hashable], + dims_dict: Mapping[Any, Hashable], + ) -> Self: + # TODO: maybe update self.transform coord_names, dim_size and dims attributes + return self + + def create_default_index_implicit( dim_variable: Variable, all_variables: Mapping | Iterable[Hashable] | None = None, diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index 67912908a2b..04677bb8d60 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -15,6 +15,7 @@ import pandas as pd from xarray.core import duck_array_ops +from xarray.core.coordinate_transform import CoordinateTransform from xarray.core.nputils import NumpyVIndexAdapter from xarray.core.options import OPTIONS from xarray.core.types import T_Xarray @@ -1303,6 +1304,42 @@ def _decompose_outer_indexer( return (BasicIndexer(tuple(backend_indexer)), OuterIndexer(tuple(np_indexer))) +def _posify_indices(indices: np.typing.ArrayLike, size: int) -> np.ndarray: + """Convert negative indices by their equivalent positive indices. + + Note: the resulting indices may still be out of bounds (< 0 or >= size). + + """ + return np.where(indices < 0, size + indices, indices) + + +def _check_bounds(indices, size): + """Check if the given indices are all within the array boundaries.""" + if np.any((indices < 0) | (indices >= size)): + raise IndexError("out of bounds index") + + +def _arrayize_outer_indexer(indexer: OuterIndexer, shape) -> OuterIndexer: + """Return a similar oindex with after replacing slices by arrays and + negative indices by their corresponding positive indices. + + Also check if array indices are within bounds. + + """ + new_key = [] + + for axis, value in enumerate(indexer.tuple): + size = shape[axis] + if isinstance(value, slice): + value = _expand_slice(value, size) + else: + value = _posify_indices(value, size) + _check_bounds(value, size) + new_key.append(value) + + return OuterIndexer(tuple(new_key)) + + def _arrayize_vectorized_indexer( indexer: VectorizedIndexer, shape: _Shape ) -> VectorizedIndexer: @@ -1921,3 +1958,116 @@ def copy(self, deep: bool = True) -> Self: # see PandasIndexingAdapter.copy array = self.array.copy(deep=True) if deep else self.array return type(self)(array, self._dtype, self.level) + + +class CoordinateTransformIndexingAdapter(ExplicitlyIndexedNDArrayMixin): + """Wrap a CoordinateTransform as a lazy coordinate array. + + Supports explicit indexing (both outer and vectorized). + + """ + + _transform: CoordinateTransform + _coord_name: Hashable + _dims: tuple[str, ...] + + def __init__( + self, + transform: CoordinateTransform, + coord_name: Hashable, + dims: tuple[str] | None = None, + ): + self._transform = transform + self._coord_name = coord_name + self._dims = dims or transform.dims + + @property + def dtype(self) -> np.dtype: + return self._transform.dtype + + @property + def shape(self): + return tuple(self._transform.dim_size.values()) + + def get_duck_array(self) -> np.ndarray: + all_coords = self._transform.generate_coords(dims=self._dims) + return np.asarray(all_coords[self._coord_name]) + + def _oindex_get(self, indexer: OuterIndexer): + expanded_indexer_ = OuterIndexer(expanded_indexer(indexer.tuple, self.ndim)) + array_indexer = _arrayize_outer_indexer(expanded_indexer_, self.shape) + + positions = np.meshgrid(*array_indexer.tuple, indexing="ij") + dim_positions = { + dim: pos for dim, pos in zip(self._dims, positions, strict=False) + } + + result = self._transform.forward(dim_positions) + return np.asarray(result[self._coord_name]).squeeze() + + def _oindex_set(self, indexer: OuterIndexer, value: Any) -> None: + raise TypeError( + "setting values is not supported on coordinate transform arrays." + ) + + def _vindex_get(self, indexer: VectorizedIndexer): + expanded_indexer_ = VectorizedIndexer( + expanded_indexer(indexer.tuple, self.ndim) + ) + array_indexer = _arrayize_vectorized_indexer(expanded_indexer_, self.shape) + + dim_positions = {} + for i, (dim, pos) in enumerate( + zip(self._dims, array_indexer.tuple, strict=False) + ): + pos = _posify_indices(pos, self.shape[i]) + _check_bounds(pos, self.shape[i]) + dim_positions[dim] = pos + + result = self._transform.forward(dim_positions) + return np.asarray(result[self._coord_name]) + + def _vindex_set(self, indexer: VectorizedIndexer, value: Any) -> None: + raise TypeError( + "setting values is not supported on coordinate transform arrays." + ) + + def __getitem__(self, indexer: ExplicitIndexer): + # TODO: make it lazy (i.e., re-calculate and re-wrap the transform) when possible? + self._check_and_raise_if_non_basic_indexer(indexer) + + # also works with basic indexing + return self._oindex_get(OuterIndexer(indexer.tuple)) + + def __setitem__(self, indexer: ExplicitIndexer, value: Any) -> None: + raise TypeError( + "setting values is not supported on coordinate transform arrays." + ) + + def transpose(self, order): + new_dims = tuple([self._dims[i] for i in order]) + return type(self)(self._transform, self._coord_name, new_dims) + + def __repr__(self: Any) -> str: + return f"{type(self).__name__}(transform={self._transform!r})" + + def _get_array_subset(self) -> np.ndarray: + threshold = max(100, OPTIONS["display_values_threshold"] + 2) + if self.size > threshold: + pos = threshold // 2 + flat_indices = np.concatenate( + [np.arange(0, pos), np.arange(self.size - pos, self.size)] + ) + subset = self.vindex[ + VectorizedIndexer(np.unravel_index(flat_indices, self.shape)) + ] + else: + subset = self + + return np.asarray(subset) + + def _repr_inline_(self, max_width: int) -> str: + """Good to see some labels even for a lazy coordinate.""" + from xarray.core.formatting import format_array_flat + + return format_array_flat(self._get_array_subset(), max_width) diff --git a/xarray/indexes/__init__.py b/xarray/indexes/__init__.py index b1bf7a1af11..e2857b8602b 100644 --- a/xarray/indexes/__init__.py +++ b/xarray/indexes/__init__.py @@ -3,6 +3,11 @@ """ -from xarray.core.indexes import Index, PandasIndex, PandasMultiIndex +from xarray.core.indexes import ( + CoordinateTransformIndex, + Index, + PandasIndex, + PandasMultiIndex, +) -__all__ = ["Index", "PandasIndex", "PandasMultiIndex"] +__all__ = ["CoordinateTransformIndex", "Index", "PandasIndex", "PandasMultiIndex"]