Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Backport PR #2950 on branch 1.10.x (Check that aggregate is only called on anndata) #2952

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 22 additions & 6 deletions scanpy/get/_aggregated.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,6 @@
return X**power if isinstance(X, np.ndarray) else X.power(power)


@singledispatch
def aggregate(
adata: AnnData,
by: str | Collection[str],
Expand Down Expand Up @@ -232,6 +231,11 @@

Note that this filters out any combination of groups that wasn't present in the original data.
"""
if not isinstance(adata, AnnData):
raise NotImplementedError(
"sc.get.aggregate is currently only implemented for AnnData input, "
f"was passed {type(adata)}."
)
if axis is None:
axis = 1 if varm else 0
axis, axis_name = _resolve_axis(axis)
Expand Down Expand Up @@ -260,7 +264,7 @@
dim_df = getattr(adata, axis_name)
categorical, new_label_df = _combine_categories(dim_df, by)
# Actual computation
layers = aggregate(
layers = _aggregate(
data,
by=categorical,
func=func,
Expand Down Expand Up @@ -288,13 +292,25 @@
return result


@aggregate.register(pd.DataFrame)
@singledispatch
def _aggregate(
data,
by: pd.Categorical,
func: AggType | Iterable[AggType],
*,
mask: NDArray[np.bool_] | None = None,
dof: int = 1,
):
raise NotImplementedError(f"Data type {type(data)} not supported for aggregation")

Check warning on line 304 in scanpy/get/_aggregated.py

View check run for this annotation

Codecov / codecov/patch

scanpy/get/_aggregated.py#L304

Added line #L304 was not covered by tests


@_aggregate.register(pd.DataFrame)
def aggregate_df(data, by, func, *, mask=None, dof=1):
return aggregate(data.values, by, func, mask=mask, dof=dof)
return _aggregate(data.values, by, func, mask=mask, dof=dof)


@aggregate.register(np.ndarray)
@aggregate.register(sparse.spmatrix)
@_aggregate.register(np.ndarray)
@_aggregate.register(sparse.spmatrix)
def aggregate_array(
data,
by: pd.Categorical,
Expand Down
6 changes: 6 additions & 0 deletions scanpy/tests/test_aggregated.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,3 +454,9 @@ def test_aggregate_obsm_labels():
)
result = sc.get.aggregate(adata, by="labels", func="sum", obsm="entry")
assert_equal(expected, result)


def test_dispatch_not_implemented():
adata = sc.datasets.blobs()
with pytest.raises(NotImplementedError):
sc.get.aggregate(adata.X, adata.obs["blobs"], "sum")