Skip to content

Commit

Permalink
Backport PR scverse#2950: Check that aggregate is only called on anndata
Browse files: browse the repository at this point in the history
  • Loading branch information
ivirshup authored and meeseeksmachine committed Mar 25, 2024
1 parent 60a0042 commit 9322058
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 6 deletions.
28 changes: 22 additions & 6 deletions scanpy/get/_aggregated.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,6 @@ def _power(X: Array, power: float | int) -> Array:
return X**power if isinstance(X, np.ndarray) else X.power(power)


@singledispatch
def aggregate(
adata: AnnData,
by: str | Collection[str],
Expand Down Expand Up @@ -232,6 +231,11 @@ def aggregate(
Note that this filters out any combination of groups that wasn't present in the original data.
"""
if not isinstance(adata, AnnData):
raise NotImplementedError(
"sc.get.aggregate is currently only implemented for AnnData input, "
f"was passed {type(adata)}."
)
if axis is None:
axis = 1 if varm else 0
axis, axis_name = _resolve_axis(axis)
Expand Down Expand Up @@ -260,7 +264,7 @@ def aggregate(
dim_df = getattr(adata, axis_name)
categorical, new_label_df = _combine_categories(dim_df, by)
# Actual computation
layers = aggregate(
layers = _aggregate(
data,
by=categorical,
func=func,
Expand Down Expand Up @@ -288,13 +292,25 @@ def aggregate(
return result


@aggregate.register(pd.DataFrame)
@singledispatch
def _aggregate(
    data,
    by: pd.Categorical,
    func: AggType | Iterable[AggType],
    *,
    mask: NDArray[np.bool_] | None = None,
    dof: int = 1,
):
    """Fallback of the single-dispatch aggregation routine.

    This base implementation is what runs when ``data`` is of a type for
    which no specialised implementation has been registered. It performs no
    computation and always fails, naming the offending type.

    Raises
    ------
    NotImplementedError
        Always; the message contains ``type(data)``.
    """
    unsupported = type(data)
    raise NotImplementedError(
        f"Data type {unsupported} not supported for aggregation"
    )


# Implementation registered for pandas DataFrame input: it unwraps the frame
# to its underlying array (`data.values`) and dispatches again on that array,
# forwarding `by`, `func`, `mask` and `dof` unchanged.
@_aggregate.register(pd.DataFrame)
def aggregate_df(data, by, func, *, mask=None, dof=1):
# NOTE(review): this rendered diff carries no +/- markers, so two consecutive
# return statements appear here. Presumably the `aggregate(...)` line is the
# removed version and the `_aggregate(...)` line is its replacement -- confirm
# against the raw diff before treating either one as the live code.
return aggregate(data.values, by, func, mask=mask, dof=dof)
return _aggregate(data.values, by, func, mask=mask, dof=dof)


@aggregate.register(np.ndarray)
@aggregate.register(sparse.spmatrix)
@_aggregate.register(np.ndarray)
@_aggregate.register(sparse.spmatrix)
def aggregate_array(
data,
by: pd.Categorical,
Expand Down
6 changes: 6 additions & 0 deletions scanpy/tests/test_aggregated.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,3 +454,9 @@ def test_aggregate_obsm_labels():
)
result = sc.get.aggregate(adata, by="labels", func="sum", obsm="entry")
assert_equal(expected, result)


def test_dispatch_not_implemented():
    """Calling ``sc.get.aggregate`` on a bare matrix must raise NotImplementedError."""
    blobs = sc.datasets.blobs()
    # Pass the raw data matrix and a label column positionally instead of the
    # AnnData object itself; the call is expected to be rejected.
    with pytest.raises(NotImplementedError):
        sc.get.aggregate(blobs.X, blobs.obs["blobs"], "sum")

0 comments on commit 9322058

Please sign in to comment.