Skip to content

Commit

Permalink
Add sample probabilities (#3410)
Browse files Browse the repository at this point in the history
  • Loading branch information
flying-sheep authored Dec 20, 2024
1 parent ac4c629 commit 397d703
Show file tree
Hide file tree
Showing 8 changed files with 102 additions and 34 deletions.
1 change: 1 addition & 0 deletions docs/release-notes/3410.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add sampling probabilities/mask parameter `p` to {func}`~scanpy.pp.sample` {smaller}`P Angerer`
3 changes: 1 addition & 2 deletions src/scanpy/get/_aggregated.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,8 +263,7 @@ def aggregate(
if axis is None:
axis = 1 if varm else 0
axis, axis_name = _resolve_axis(axis)
if mask is not None:
mask = _check_mask(adata, mask, axis_name)
mask = _check_mask(adata, mask, axis_name)
data = adata.X
if sum(p is not None for p in [varm, obsm, layer]) > 1:
raise TypeError("Please only provide one (or none) of varm, obsm, or layer")
Expand Down
49 changes: 37 additions & 12 deletions src/scanpy/get/get.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@

from __future__ import annotations

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, TypeVar

import numpy as np
import pandas as pd
from anndata import AnnData
from numpy.typing import NDArray
from packaging.version import Version
from scipy.sparse import spmatrix

Expand All @@ -16,7 +17,11 @@

from anndata._core.sparse_dataset import BaseCompressedSparseDataset
from anndata._core.views import ArrayView
from numpy.typing import NDArray
from scipy.sparse import csc_matrix, csr_matrix

from .._compat import DaskArray

CSMatrix = csr_matrix | csc_matrix

# --------------------------------------------------------------------------------
# Plotting data helpers
Expand Down Expand Up @@ -485,42 +490,62 @@ def _set_obs_rep(
raise AssertionError(msg)


M = TypeVar("M", bound=NDArray[np.bool_] | NDArray[np.floating] | pd.Series | None)


def _check_mask(
data: AnnData | np.ndarray,
mask: NDArray[np.bool_] | str,
data: AnnData | np.ndarray | CSMatrix | DaskArray,
mask: str | M,
dim: Literal["obs", "var"],
) -> NDArray[np.bool_]: # Could also be a series, but should be one or the other
*,
allow_probabilities: bool = False,
) -> M: # Could also be a series, but should be one or the other
"""
Validate mask argument
Params
------
data
Annotated data matrix or numpy array.
mask
The mask. Either an appropriatley sized boolean array, or name of a column which will be used to mask.
Mask (or probabilities if `allow_probabilities=True`).
Either an appropriatley sized array, or name of a column.
dim
The dimension being masked.
allow_probabilities
Whether to allow probabilities as `mask`
"""
if mask is None:
return mask
desc = "mask/probabilities" if allow_probabilities else "mask"

if isinstance(mask, str):
if not isinstance(data, AnnData):
msg = "Cannot refer to mask with string without providing anndata object as argument"
msg = f"Cannot refer to {desc} with string without providing anndata object as argument"
raise ValueError(msg)

annot: pd.DataFrame = getattr(data, dim)
if mask not in annot.columns:
msg = (
f"Did not find `adata.{dim}[{mask!r}]`. "
f"Either add the mask first to `adata.{dim}`"
"or consider using the mask argument with a boolean array."
f"Either add the {desc} first to `adata.{dim}`"
f"or consider using the {desc} argument with an array."
)
raise ValueError(msg)
mask_array = annot[mask].to_numpy()
else:
if len(mask) != data.shape[0 if dim == "obs" else 1]:
raise ValueError("The shape of the mask do not match the data.")
msg = f"The shape of the {desc} do not match the data."
raise ValueError(msg)
mask_array = mask

if not pd.api.types.is_bool_dtype(mask_array.dtype):
raise ValueError("Mask array must be boolean.")
is_bool = pd.api.types.is_bool_dtype(mask_array.dtype)
if not allow_probabilities and not is_bool:
msg = "Mask array must be boolean."
raise ValueError(msg)
elif allow_probabilities and not (
is_bool or pd.api.types.is_float_dtype(mask_array.dtype)
):
msg = f"{desc} array must be boolean or floating point."
raise ValueError(msg)

return mask_array
3 changes: 1 addition & 2 deletions src/scanpy/plotting/_tools/scatterplots.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,7 @@ def embedding(
# Checking the mask format and if used together with groups
if groups is not None and mask_obs is not None:
raise ValueError("Groups and mask arguments are incompatible.")
if mask_obs is not None:
mask_obs = _check_mask(adata, mask_obs, "obs")
mask_obs = _check_mask(adata, mask_obs, "obs")

# Figure out if we're using raw
if use_raw is None:
Expand Down
2 changes: 1 addition & 1 deletion src/scanpy/preprocessing/_scale.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,8 +164,8 @@ def scale_array(
):
if copy:
X = X.copy()
mask_obs = _check_mask(X, mask_obs, "obs")
if mask_obs is not None:
mask_obs = _check_mask(X, mask_obs, "obs")
scale_rv = scale_array(
X[mask_obs, :],
zero_center=zero_center,
Expand Down
16 changes: 14 additions & 2 deletions src/scanpy/preprocessing/_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
sanitize_anndata,
view_to_actual,
)
from ..get import _get_obs_rep, _set_obs_rep
from ..get import _check_mask, _get_obs_rep, _set_obs_rep
from ._distributed import materialize_as_ndarray
from ._utils import _to_dense

Expand Down Expand Up @@ -838,6 +838,7 @@ def sample(
copy: Literal[False] = False,
replace: bool = False,
axis: Literal["obs", 0, "var", 1] = "obs",
p: str | NDArray[np.bool_] | NDArray[np.floating] | None = None,
) -> None: ...
@overload
def sample(
Expand All @@ -849,6 +850,7 @@ def sample(
copy: Literal[True],
replace: bool = False,
axis: Literal["obs", 0, "var", 1] = "obs",
p: str | NDArray[np.bool_] | NDArray[np.floating] | None = None,
) -> AnnData: ...
@overload
def sample(
Expand All @@ -860,6 +862,7 @@ def sample(
copy: bool = False,
replace: bool = False,
axis: Literal["obs", 0, "var", 1] = "obs",
p: str | NDArray[np.bool_] | NDArray[np.floating] | None = None,
) -> tuple[A, NDArray[np.int64]]: ...
def sample(
data: AnnData | np.ndarray | CSMatrix | DaskArray,
Expand All @@ -870,6 +873,7 @@ def sample(
copy: bool = False,
replace: bool = False,
axis: Literal["obs", 0, "var", 1] = "obs",
p: str | NDArray[np.bool_] | NDArray[np.floating] | None = None,
) -> AnnData | None | tuple[np.ndarray | CSMatrix | DaskArray, NDArray[np.int64]]:
"""\
Sample observations or variables with or without replacement.
Expand All @@ -881,6 +885,7 @@ def sample(
Rows correspond to cells and columns to genes.
fraction
Sample to this `fraction` of the number of observations or variables.
(All of them, even if there are `0`s/`False`s in `p`.)
This can be larger than 1.0, if `replace=True`.
See `axis` and `replace`.
n
Expand All @@ -894,6 +899,10 @@ def sample(
If True, samples are drawn with replacement.
axis
Sample `obs`\\ ervations (axis 0) or `var`\\ iables (axis 1).
p
Drawing probabilities (floats) or mask (bools).
Either an `axis`-sized array, or the name of a column.
If `p` is an array of probabilities, it must sum to 1.
Returns
-------
Expand All @@ -910,6 +919,9 @@ def sample(
msg = "Inplace sampling (`copy=False`) is not implemented for backed objects."
raise NotImplementedError(msg)
axis, axis_name = _resolve_axis(axis)
p = _check_mask(data, p, dim=axis_name, allow_probabilities=True)
if p is not None and p.dtype == bool:
p = p.astype(np.float64) / p.sum()
old_n = data.shape[axis]
match (fraction, n):
case (None, None):
Expand All @@ -933,7 +945,7 @@ def sample(

# actually do subsampling
rng = np.random.default_rng(rng)
indices = rng.choice(old_n, size=n, replace=replace)
indices = rng.choice(old_n, size=n, replace=replace, p=p)

# overload 1: inplace AnnData subset
if not copy and isinstance(data, AnnData):
Expand Down
3 changes: 1 addition & 2 deletions src/scanpy/tools/_rank_genes_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,8 +594,7 @@ def rank_genes_groups(
>>> # to visualize the results
>>> sc.pl.rank_genes_groups(adata)
"""
if mask_var is not None:
mask_var = _check_mask(adata, mask_var, "var")
mask_var = _check_mask(adata, mask_var, "var")

if use_raw is None:
use_raw = adata.raw is not None
Expand Down
59 changes: 46 additions & 13 deletions tests/test_preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
from collections.abc import Callable
from typing import Any, Literal

from numpy.typing import NDArray

CSMatrix = sp.csc_matrix | sp.csr_matrix


Expand Down Expand Up @@ -144,31 +146,55 @@ def test_normalize_per_cell():
assert adata.X.sum(axis=1).tolist() == adata_sparse.X.sum(axis=1).A1.tolist()


def _random_probs(n: int, frac_zero: float) -> NDArray[np.float64]:
"""
Generate a random probability distribution of `n` values between 0 and 1.
"""
probs = np.random.randint(0, 10000, n).astype(np.float64)
probs[probs < np.quantile(probs, frac_zero)] = 0
probs /= probs.sum()
np.testing.assert_almost_equal(probs.sum(), 1)
return probs


@pytest.mark.parametrize("array_type", ARRAY_TYPES)
@pytest.mark.parametrize("which", ["copy", "inplace", "array"])
@pytest.mark.parametrize(
("axis", "fraction", "n", "replace", "expected"),
("axis", "f_or_n", "replace"),
[
pytest.param(0, 40, False, id="obs-40-no_replace"),
pytest.param(0, 0.1, False, id="obs-0.1-no_replace"),
pytest.param(0, 201, True, id="obs-201-replace"),
pytest.param(0, 1, True, id="obs-1-replace"),
pytest.param(1, 10, False, id="var-10-no_replace"),
pytest.param(1, 11, True, id="var-11-replace"),
pytest.param(1, 2.0, True, id="var-2.0-replace"),
],
)
@pytest.mark.parametrize(
"ps",
[
pytest.param(0, None, 40, False, 40, id="obs-40-no_replace"),
pytest.param(0, 0.1, None, False, 20, id="obs-0.1-no_replace"),
pytest.param(0, None, 201, True, 201, id="obs-201-replace"),
pytest.param(0, None, 1, True, 1, id="obs-1-replace"),
pytest.param(1, None, 10, False, 10, id="var-10-no_replace"),
pytest.param(1, None, 11, True, 11, id="var-11-replace"),
pytest.param(1, 2.0, None, True, 20, id="var-2.0-replace"),
dict(obs=None, var=None),
dict(obs=np.tile([True, False], 100), var=np.tile([True, False], 5)),
dict(obs=_random_probs(200, 0.3), var=_random_probs(10, 0.7)),
],
ids=["all", "mask", "p"],
)
def test_sample(
*,
request: pytest.FixtureRequest,
array_type: Callable[[np.ndarray], np.ndarray | CSMatrix],
which: Literal["copy", "inplace", "array"],
axis: Literal[0, 1],
fraction: float | None,
n: int | None,
f_or_n: float | int, # noqa: PYI041
replace: bool,
expected: int,
ps: dict[Literal["obs", "var"], NDArray[np.bool_] | None],
):
adata = AnnData(array_type(np.ones((200, 10))))
p = ps["obs" if axis == 0 else "var"]
expected = int(adata.shape[axis] * f_or_n) if isinstance(f_or_n, float) else f_or_n
if p is not None and not replace and expected > (n_possible := (p != 0).sum()):
request.applymarker(pytest.xfail(f"Can’t draw {expected} out of {n_possible}"))

# ignoring this warning declaratively is a pain so do it here
if find_spec("dask"):
Expand All @@ -182,12 +208,13 @@ def test_sample(
)
rv = sc.pp.sample(
adata.X if which == "array" else adata,
fraction,
n=n,
f_or_n if isinstance(f_or_n, float) else None,
n=f_or_n if isinstance(f_or_n, int) else None,
replace=replace,
axis=axis,
# `copy` only effects AnnData inputs
copy=dict(copy=True, inplace=False, array=False)[which],
p=p,
)

match which:
Expand Down Expand Up @@ -232,6 +259,12 @@ def test_sample(
r"`fraction=-0\.3` needs to be nonnegative",
id="frac<0",
),
pytest.param(
dict(n=3, p=np.ones(200, dtype=np.int32)),
ValueError,
r"mask/probabilities array must be boolean or floating point",
id="type(p)",
),
],
)
def test_sample_error(args: dict[str, Any], exc: type[Exception], pattern: str):
Expand Down

0 comments on commit 397d703

Please sign in to comment.