From 9279b379ffb22ba5fcf7331e9f87dbb18bd61434 Mon Sep 17 00:00:00 2001 From: "Lumberbot (aka Jack)" <39504233+meeseeksmachine@users.noreply.github.com> Date: Mon, 25 Mar 2024 12:23:06 +0100 Subject: [PATCH] Backport PR #2950: Check that aggregate is only called on anndata (#2952) Co-authored-by: Isaac Virshup --- scanpy/get/_aggregated.py | 28 ++++++++++++++++++++++------ scanpy/tests/test_aggregated.py | 6 ++++++ 2 files changed, 28 insertions(+), 6 deletions(-) diff --git a/scanpy/get/_aggregated.py b/scanpy/get/_aggregated.py index 6d60328791..5530059c2c 100644 --- a/scanpy/get/_aggregated.py +++ b/scanpy/get/_aggregated.py @@ -156,7 +156,6 @@ def _power(X: Array, power: float | int) -> Array: return X**power if isinstance(X, np.ndarray) else X.power(power) -@singledispatch def aggregate( adata: AnnData, by: str | Collection[str], @@ -232,6 +231,11 @@ def aggregate( Note that this filters out any combination of groups that wasn't present in the original data. """ + if not isinstance(adata, AnnData): + raise NotImplementedError( + "sc.get.aggregate is currently only implemented for AnnData input, " + f"was passed {type(adata)}." + ) if axis is None: axis = 1 if varm else 0 axis, axis_name = _resolve_axis(axis) @@ -260,7 +264,7 @@ def aggregate( dim_df = getattr(adata, axis_name) categorical, new_label_df = _combine_categories(dim_df, by) # Actual computation - layers = aggregate( + layers = _aggregate( data, by=categorical, func=func, @@ -288,13 +292,25 @@ def aggregate( return result -@aggregate.register(pd.DataFrame) +@singledispatch +def _aggregate( + data, + by: pd.Categorical, + func: AggType | Iterable[AggType], + *, + mask: NDArray[np.bool_] | None = None, + dof: int = 1, +): + raise NotImplementedError(f"Data type {type(data)} not supported for aggregation") + + +@_aggregate.register(pd.DataFrame) def aggregate_df(data, by, func, *, mask=None, dof=1): - return aggregate(data.values, by, func, mask=mask, dof=dof) + return _aggregate(data.values, by, func, mask=mask, dof=dof) -@aggregate.register(np.ndarray) -@aggregate.register(sparse.spmatrix) +@_aggregate.register(np.ndarray) +@_aggregate.register(sparse.spmatrix) def aggregate_array( data, by: pd.Categorical, diff --git a/scanpy/tests/test_aggregated.py b/scanpy/tests/test_aggregated.py index d0e01bb604..99d36947c4 100644 --- a/scanpy/tests/test_aggregated.py +++ b/scanpy/tests/test_aggregated.py @@ -454,3 +454,9 @@ def test_aggregate_obsm_labels(): ) result = sc.get.aggregate(adata, by="labels", func="sum", obsm="entry") assert_equal(expected, result) + + +def test_dispatch_not_implemented(): + adata = sc.datasets.blobs() + with pytest.raises(NotImplementedError): + sc.get.aggregate(adata.X, adata.obs["blobs"], "sum")