diff --git a/docs/release-notes/3410.feature.md b/docs/release-notes/3410.feature.md new file mode 100644 index 0000000000..d95ad201ba --- /dev/null +++ b/docs/release-notes/3410.feature.md @@ -0,0 +1 @@ +Add sampling probabilities/mask parameter `p` to {func}`~scanpy.pp.sample` {smaller}`P Angerer` diff --git a/src/scanpy/get/_aggregated.py b/src/scanpy/get/_aggregated.py index 13ca54b5c4..53a18bb47c 100644 --- a/src/scanpy/get/_aggregated.py +++ b/src/scanpy/get/_aggregated.py @@ -263,8 +263,7 @@ def aggregate( if axis is None: axis = 1 if varm else 0 axis, axis_name = _resolve_axis(axis) - if mask is not None: - mask = _check_mask(adata, mask, axis_name) + mask = _check_mask(adata, mask, axis_name) data = adata.X if sum(p is not None for p in [varm, obsm, layer]) > 1: raise TypeError("Please only provide one (or none) of varm, obsm, or layer") diff --git a/src/scanpy/get/get.py b/src/scanpy/get/get.py index f3172ed45e..c36ddde8f8 100644 --- a/src/scanpy/get/get.py +++ b/src/scanpy/get/get.py @@ -2,11 +2,12 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, TypeVar import numpy as np import pandas as pd from anndata import AnnData +from numpy.typing import NDArray from packaging.version import Version from scipy.sparse import spmatrix @@ -16,7 +17,11 @@ from anndata._core.sparse_dataset import BaseCompressedSparseDataset from anndata._core.views import ArrayView - from numpy.typing import NDArray + from scipy.sparse import csc_matrix, csr_matrix + + from .._compat import DaskArray + + CSMatrix = csr_matrix | csc_matrix # -------------------------------------------------------------------------------- # Plotting data helpers @@ -485,11 +490,16 @@ def _set_obs_rep( raise AssertionError(msg) +M = TypeVar("M", bound=NDArray[np.bool_] | NDArray[np.floating] | pd.Series | None) + + def _check_mask( - data: AnnData | np.ndarray, - mask: NDArray[np.bool_] | str, + data: AnnData | np.ndarray | CSMatrix | DaskArray, + mask: str | M, dim: Literal["obs", "var"], -) -> NDArray[np.bool_]: # Could also be a series, but should be one or the other + *, + allow_probabilities: bool = False, +) -> M: # Could also be a series, but should be one or the other """ Validate mask argument Params @@ -497,30 +507,45 @@ def _check_mask( data Annotated data matrix or numpy array. mask - The mask. Either an appropriatley sized boolean array, or name of a column which will be used to mask. + Mask (or probabilities if `allow_probabilities=True`). + Either an appropriatley sized array, or name of a column. dim The dimension being masked. + allow_probabilities + Whether to allow probabilities as `mask` """ + if mask is None: + return mask + desc = "mask/probabilities" if allow_probabilities else "mask" + if isinstance(mask, str): if not isinstance(data, AnnData): - msg = "Cannot refer to mask with string without providing anndata object as argument" + msg = f"Cannot refer to {desc} with string without providing anndata object as argument" raise ValueError(msg) annot: pd.DataFrame = getattr(data, dim) if mask not in annot.columns: msg = ( f"Did not find `adata.{dim}[{mask!r}]`. " - f"Either add the mask first to `adata.{dim}`" - "or consider using the mask argument with a boolean array." + f"Either add the {desc} first to `adata.{dim}`" + f"or consider using the {desc} argument with an array." ) raise ValueError(msg) mask_array = annot[mask].to_numpy() else: if len(mask) != data.shape[0 if dim == "obs" else 1]: - raise ValueError("The shape of the mask do not match the data.") + msg = f"The shape of the {desc} do not match the data." + raise ValueError(msg) mask_array = mask - if not pd.api.types.is_bool_dtype(mask_array.dtype): - raise ValueError("Mask array must be boolean.") + is_bool = pd.api.types.is_bool_dtype(mask_array.dtype) + if not allow_probabilities and not is_bool: + msg = "Mask array must be boolean." + raise ValueError(msg) + elif allow_probabilities and not ( + is_bool or pd.api.types.is_float_dtype(mask_array.dtype) + ): + msg = f"{desc} array must be boolean or floating point." + raise ValueError(msg) return mask_array diff --git a/src/scanpy/plotting/_tools/scatterplots.py b/src/scanpy/plotting/_tools/scatterplots.py index e2564eb17f..b54897678f 100644 --- a/src/scanpy/plotting/_tools/scatterplots.py +++ b/src/scanpy/plotting/_tools/scatterplots.py @@ -150,8 +150,7 @@ def embedding( # Checking the mask format and if used together with groups if groups is not None and mask_obs is not None: raise ValueError("Groups and mask arguments are incompatible.") - if mask_obs is not None: - mask_obs = _check_mask(adata, mask_obs, "obs") + mask_obs = _check_mask(adata, mask_obs, "obs") # Figure out if we're using raw if use_raw is None: diff --git a/src/scanpy/preprocessing/_scale.py b/src/scanpy/preprocessing/_scale.py index d7123d5f65..bac08f246b 100644 --- a/src/scanpy/preprocessing/_scale.py +++ b/src/scanpy/preprocessing/_scale.py @@ -164,8 +164,8 @@ def scale_array( ): if copy: X = X.copy() + mask_obs = _check_mask(X, mask_obs, "obs") if mask_obs is not None: - mask_obs = _check_mask(X, mask_obs, "obs") scale_rv = scale_array( X[mask_obs, :], zero_center=zero_center, diff --git a/src/scanpy/preprocessing/_simple.py b/src/scanpy/preprocessing/_simple.py index 29c267c3f4..821615676a 100644 --- a/src/scanpy/preprocessing/_simple.py +++ b/src/scanpy/preprocessing/_simple.py @@ -30,7 +30,7 @@ sanitize_anndata, view_to_actual, ) -from ..get import _get_obs_rep, _set_obs_rep +from ..get import _check_mask, _get_obs_rep, _set_obs_rep from ._distributed import materialize_as_ndarray from ._utils import _to_dense @@ -838,6 +838,7 @@ def sample( copy: Literal[False] = False, replace: bool = False, axis: Literal["obs", 0, "var", 1] = "obs", + p: str | NDArray[np.bool_] | NDArray[np.floating] | None = None, ) -> None: ... @overload def sample( @@ -849,6 +850,7 @@ def sample( copy: Literal[True], replace: bool = False, axis: Literal["obs", 0, "var", 1] = "obs", + p: str | NDArray[np.bool_] | NDArray[np.floating] | None = None, ) -> AnnData: ... @overload def sample( @@ -860,6 +862,7 @@ def sample( copy: bool = False, replace: bool = False, axis: Literal["obs", 0, "var", 1] = "obs", + p: str | NDArray[np.bool_] | NDArray[np.floating] | None = None, ) -> tuple[A, NDArray[np.int64]]: ... def sample( data: AnnData | np.ndarray | CSMatrix | DaskArray, @@ -870,6 +873,7 @@ def sample( copy: bool = False, replace: bool = False, axis: Literal["obs", 0, "var", 1] = "obs", + p: str | NDArray[np.bool_] | NDArray[np.floating] | None = None, ) -> AnnData | None | tuple[np.ndarray | CSMatrix | DaskArray, NDArray[np.int64]]: """\ Sample observations or variables with or without replacement. @@ -881,6 +885,7 @@ def sample( Rows correspond to cells and columns to genes. fraction Sample to this `fraction` of the number of observations or variables. + (All of them, even if there are `0`s/`False`s in `p`.) This can be larger than 1.0, if `replace=True`. See `axis` and `replace`. n @@ -894,6 +899,10 @@ def sample( If True, samples are drawn with replacement. axis Sample `obs`\\ ervations (axis 0) or `var`\\ iables (axis 1). + p + Drawing probabilities (floats) or mask (bools). + Either an `axis`-sized array, or the name of a column. + If `p` is an array of probabilities, it must sum to 1. Returns ------- @@ -910,6 +919,9 @@ def sample( msg = "Inplace sampling (`copy=False`) is not implemented for backed objects." raise NotImplementedError(msg) axis, axis_name = _resolve_axis(axis) + p = _check_mask(data, p, dim=axis_name, allow_probabilities=True) + if p is not None and p.dtype == bool: + p = p.astype(np.float64) / p.sum() old_n = data.shape[axis] match (fraction, n): case (None, None): @@ -933,7 +945,7 @@ def sample( # actually do subsampling rng = np.random.default_rng(rng) - indices = rng.choice(old_n, size=n, replace=replace) + indices = rng.choice(old_n, size=n, replace=replace, p=p) # overload 1: inplace AnnData subset if not copy and isinstance(data, AnnData): diff --git a/src/scanpy/tools/_rank_genes_groups.py b/src/scanpy/tools/_rank_genes_groups.py index aa4428dad1..2c214fcfdd 100644 --- a/src/scanpy/tools/_rank_genes_groups.py +++ b/src/scanpy/tools/_rank_genes_groups.py @@ -594,8 +594,7 @@ def rank_genes_groups( >>> # to visualize the results >>> sc.pl.rank_genes_groups(adata) """ - if mask_var is not None: - mask_var = _check_mask(adata, mask_var, "var") + mask_var = _check_mask(adata, mask_var, "var") if use_raw is None: use_raw = adata.raw is not None diff --git a/tests/test_preprocessing.py b/tests/test_preprocessing.py index 36283e7ed0..6282c5ccf4 100644 --- a/tests/test_preprocessing.py +++ b/tests/test_preprocessing.py @@ -29,6 +29,8 @@ from collections.abc import Callable from typing import Any, Literal + from numpy.typing import NDArray + CSMatrix = sp.csc_matrix | sp.csr_matrix @@ -144,31 +146,55 @@ def test_normalize_per_cell(): assert adata.X.sum(axis=1).tolist() == adata_sparse.X.sum(axis=1).A1.tolist() +def _random_probs(n: int, frac_zero: float) -> NDArray[np.float64]: + """ + Generate a random probability distribution of `n` values between 0 and 1. + """ + probs = np.random.randint(0, 10000, n).astype(np.float64) + probs[probs < np.quantile(probs, frac_zero)] = 0 + probs /= probs.sum() + np.testing.assert_almost_equal(probs.sum(), 1) + return probs + + @pytest.mark.parametrize("array_type", ARRAY_TYPES) @pytest.mark.parametrize("which", ["copy", "inplace", "array"]) @pytest.mark.parametrize( - ("axis", "fraction", "n", "replace", "expected"), + ("axis", "f_or_n", "replace"), + [ + pytest.param(0, 40, False, id="obs-40-no_replace"), + pytest.param(0, 0.1, False, id="obs-0.1-no_replace"), + pytest.param(0, 201, True, id="obs-201-replace"), + pytest.param(0, 1, True, id="obs-1-replace"), + pytest.param(1, 10, False, id="var-10-no_replace"), + pytest.param(1, 11, True, id="var-11-replace"), + pytest.param(1, 2.0, True, id="var-2.0-replace"), + ], +) +@pytest.mark.parametrize( + "ps", [ - pytest.param(0, None, 40, False, 40, id="obs-40-no_replace"), - pytest.param(0, 0.1, None, False, 20, id="obs-0.1-no_replace"), - pytest.param(0, None, 201, True, 201, id="obs-201-replace"), - pytest.param(0, None, 1, True, 1, id="obs-1-replace"), - pytest.param(1, None, 10, False, 10, id="var-10-no_replace"), - pytest.param(1, None, 11, True, 11, id="var-11-replace"), - pytest.param(1, 2.0, None, True, 20, id="var-2.0-replace"), + dict(obs=None, var=None), + dict(obs=np.tile([True, False], 100), var=np.tile([True, False], 5)), + dict(obs=_random_probs(200, 0.3), var=_random_probs(10, 0.7)), ], + ids=["all", "mask", "p"], ) def test_sample( *, + request: pytest.FixtureRequest, array_type: Callable[[np.ndarray], np.ndarray | CSMatrix], which: Literal["copy", "inplace", "array"], axis: Literal[0, 1], - fraction: float | None, - n: int | None, + f_or_n: float | int, # noqa: PYI041 replace: bool, - expected: int, + ps: dict[Literal["obs", "var"], NDArray[np.bool_] | None], ): adata = AnnData(array_type(np.ones((200, 10)))) + p = ps["obs" if axis == 0 else "var"] + expected = int(adata.shape[axis] * f_or_n) if isinstance(f_or_n, float) else f_or_n + if p is not None and not replace and expected > (n_possible := (p != 0).sum()): + request.applymarker(pytest.xfail(f"Can’t draw {expected} out of {n_possible}")) # ignoring this warning declaratively is a pain so do it here if find_spec("dask"): @@ -182,12 +208,13 @@ def test_sample( ) rv = sc.pp.sample( adata.X if which == "array" else adata, - fraction, - n=n, + f_or_n if isinstance(f_or_n, float) else None, + n=f_or_n if isinstance(f_or_n, int) else None, replace=replace, axis=axis, # `copy` only effects AnnData inputs copy=dict(copy=True, inplace=False, array=False)[which], + p=p, ) match which: @@ -232,6 +259,12 @@ def test_sample( r"`fraction=-0\.3` needs to be nonnegative", id="frac<0", ), + pytest.param( + dict(n=3, p=np.ones(200, dtype=np.int32)), + ValueError, + r"mask/probabilities array must be boolean or floating point", + id="type(p)", + ), ], ) def test_sample_error(args: dict[str, Any], exc: type[Exception], pattern: str):