From 86d656daa5553aa39804e21f8daab735cddbf6c0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?G=C3=B6k=C3=A7en=20Eraslan?= Date: Thu, 19 Dec 2024 08:48:10 -0800 Subject: [PATCH] Add replace option to subsample and rename function to sample (#943) --- docs/api/deprecated.md | 1 + docs/api/preprocessing.md | 2 +- docs/release-notes/943.feature.md | 1 + pyproject.toml | 4 +- src/scanpy/_compat.py | 41 ++++- src/scanpy/preprocessing/__init__.py | 4 +- .../preprocessing/_deprecated/sampling.py | 60 +++++++ src/scanpy/preprocessing/_simple.py | 165 +++++++++++++----- tests/test_package_structure.py | 1 + tests/test_preprocessing.py | 144 ++++++++++++--- tests/test_utils.py | 42 ++++- 11 files changed, 391 insertions(+), 74 deletions(-) create mode 100644 docs/release-notes/943.feature.md create mode 100644 src/scanpy/preprocessing/_deprecated/sampling.py diff --git a/docs/api/deprecated.md b/docs/api/deprecated.md index 4511f4b3a7..d09c1af405 100644 --- a/docs/api/deprecated.md +++ b/docs/api/deprecated.md @@ -11,4 +11,5 @@ pp.filter_genes_dispersion pp.normalize_per_cell + pp.subsample ``` diff --git a/docs/api/preprocessing.md b/docs/api/preprocessing.md index 4b17567a6b..36e732a6dc 100644 --- a/docs/api/preprocessing.md +++ b/docs/api/preprocessing.md @@ -31,7 +31,7 @@ For visual quality control, see {func}`~scanpy.pl.highest_expr_genes` and pp.normalize_total pp.regress_out pp.scale - pp.subsample + pp.sample pp.downsample_counts ``` diff --git a/docs/release-notes/943.feature.md b/docs/release-notes/943.feature.md new file mode 100644 index 0000000000..4f5474d762 --- /dev/null +++ b/docs/release-notes/943.feature.md @@ -0,0 +1 @@ +{func}`~scanpy.pp.sample` supports both upsampling and downsampling of observations and variables. {func}`~scanpy.pp.subsample` is now deprecated. 
{smaller}`G Eraslan` & {smaller}`P Angerer` diff --git a/pyproject.toml b/pyproject.toml index f1495442fe..b4b8abd1b1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,7 +47,7 @@ classifiers = [ ] dependencies = [ "anndata>=0.8", - "numpy>=1.23", + "numpy>=1.24", "matplotlib>=3.6", "pandas >=1.5", "scipy>=1.8", @@ -60,7 +60,7 @@ dependencies = [ "networkx>=2.7", "natsort", "joblib", - "numba>=0.56", + "numba>=0.57", "umap-learn>=0.5,!=0.5.0", "pynndescent>=0.5", "packaging>=21.3", diff --git a/src/scanpy/_compat.py b/src/scanpy/_compat.py index d2c69a9e37..9ea7780b0d 100644 --- a/src/scanpy/_compat.py +++ b/src/scanpy/_compat.py @@ -4,7 +4,7 @@ import sys import warnings from dataclasses import dataclass, field -from functools import cache, partial, wraps +from functools import WRAPPER_ASSIGNMENTS, cache, partial, wraps from importlib.util import find_spec from pathlib import Path from typing import TYPE_CHECKING, Literal, ParamSpec, TypeVar, cast, overload @@ -224,3 +224,42 @@ def _numba_threading_layer() -> Layer: f" ({available=}, {numba.config.THREADING_LAYER_PRIORITY=})" ) raise ValueError(msg) + + +def _legacy_numpy_gen( + random_state: _LegacyRandom | None = None, +) -> np.random.Generator: + """Return a random generator that behaves like the legacy one.""" + + if random_state is not None: + if isinstance(random_state, np.random.RandomState): + np.random.set_state(random_state.get_state(legacy=False)) + return _FakeRandomGen(random_state) + np.random.seed(random_state) + return _FakeRandomGen(np.random.RandomState(np.random.get_bit_generator())) + + +class _FakeRandomGen(np.random.Generator): + _state: np.random.RandomState + + def __init__(self, random_state: np.random.RandomState) -> None: + self._state = random_state + + @classmethod + def _delegate(cls) -> None: + for name, meth in np.random.Generator.__dict__.items(): + if name.startswith("_") or not callable(meth): + continue + + def mk_wrapper(name: str): + # Old pytest versions try to run the 
doctests + @wraps(meth, assigned=set(WRAPPER_ASSIGNMENTS) - {"__doc__"}) + def wrapper(self: _FakeRandomGen, *args, **kwargs): + return getattr(self._state, name)(*args, **kwargs) + + return wrapper + + setattr(cls, name, mk_wrapper(name)) + + +_FakeRandomGen._delegate() diff --git a/src/scanpy/preprocessing/__init__.py b/src/scanpy/preprocessing/__init__.py index 8c396d8640..4307cbb6c9 100644 --- a/src/scanpy/preprocessing/__init__.py +++ b/src/scanpy/preprocessing/__init__.py @@ -3,6 +3,7 @@ from ..neighbors import neighbors from ._combat import combat from ._deprecated.highly_variable_genes import filter_genes_dispersion +from ._deprecated.sampling import subsample from ._highly_variable_genes import highly_variable_genes from ._normalization import normalize_total from ._pca import pca @@ -17,8 +18,8 @@ log1p, normalize_per_cell, regress_out, + sample, sqrt, - subsample, ) __all__ = [ @@ -40,6 +41,7 @@ "log1p", "normalize_per_cell", "regress_out", + "sample", "scale", "sqrt", "subsample", diff --git a/src/scanpy/preprocessing/_deprecated/sampling.py b/src/scanpy/preprocessing/_deprecated/sampling.py new file mode 100644 index 0000000000..02619a2364 --- /dev/null +++ b/src/scanpy/preprocessing/_deprecated/sampling.py @@ -0,0 +1,60 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from ..._compat import _legacy_numpy_gen, old_positionals +from .._simple import sample + +if TYPE_CHECKING: + import numpy as np + from anndata import AnnData + from numpy.typing import NDArray + from scipy.sparse import csc_matrix, csr_matrix + + from ..._compat import _LegacyRandom + + CSMatrix = csr_matrix | csc_matrix + + +@old_positionals("n_obs", "random_state", "copy") +def subsample( + data: AnnData | np.ndarray | CSMatrix, + fraction: float | None = None, + *, + n_obs: int | None = None, + random_state: _LegacyRandom = 0, + copy: bool = False, +) -> AnnData | tuple[np.ndarray | CSMatrix, NDArray[np.int64]] | None: + """\ + Subsample to a fraction of 
the number of observations. + + .. deprecated:: 1.11.0 + + Use :func:`~scanpy.pp.sample` instead. + + Parameters + ---------- + data + The (annotated) data matrix of shape `n_obs` × `n_vars`. + Rows correspond to cells and columns to genes. + fraction + Subsample to this `fraction` of the number of observations. + n_obs + Subsample to this number of observations. + random_state + Random seed to change subsampling. + copy + If an :class:`~anndata.AnnData` is passed, + determines whether a copy is returned. + + Returns + ------- + Returns `X[obs_indices], obs_indices` if data is array-like, otherwise + subsamples the passed :class:`~anndata.AnnData` (`copy == False`) or + returns a subsampled copy of it (`copy == True`). + """ + + rng = _legacy_numpy_gen(random_state) + return sample( + data=data, fraction=fraction, n=n_obs, rng=rng, copy=copy, replace=False, axis=0 + ) diff --git a/src/scanpy/preprocessing/_simple.py b/src/scanpy/preprocessing/_simple.py index eaf9648690..29c267c3f4 100644 --- a/src/scanpy/preprocessing/_simple.py +++ b/src/scanpy/preprocessing/_simple.py @@ -8,20 +8,21 @@ import warnings from functools import singledispatch from itertools import repeat -from typing import TYPE_CHECKING, TypeVar +from typing import TYPE_CHECKING, TypeVar, overload import numba import numpy as np from anndata import AnnData from pandas.api.types import CategoricalDtype -from scipy.sparse import csr_matrix, issparse, isspmatrix_csr, spmatrix +from scipy.sparse import csc_matrix, csr_matrix, issparse, isspmatrix_csr, spmatrix from sklearn.utils import check_array, sparsefuncs from .. 
import logging as logg -from .._compat import deprecated, njit, old_positionals +from .._compat import DaskArray, deprecated, njit, old_positionals from .._settings import settings as sett from .._utils import ( _check_array_function_arguments, + _resolve_axis, axis_sum, is_backed_type, raise_not_implemented_error_if_backed_type, @@ -33,15 +34,11 @@ from ._distributed import materialize_as_ndarray from ._utils import _to_dense -# install dask if available try: import dask.array as da except ImportError: da = None -# backwards compat -from ._deprecated.highly_variable_genes import filter_genes_dispersion # noqa: F401 - if TYPE_CHECKING: from collections.abc import Collection, Iterable, Sequence from numbers import Number @@ -50,7 +47,13 @@ import pandas as pd from numpy.typing import NDArray - from .._compat import DaskArray, _LegacyRandom + from .._compat import _LegacyRandom + from .._utils import RNGLike, SeedLike + + +CSMatrix = csr_matrix | csc_matrix + +A = TypeVar("A", bound=np.ndarray | CSMatrix | DaskArray) @old_positionals( @@ -825,17 +828,51 @@ def _regress_out_chunk( return np.vstack(responses_chunk_list) -@old_positionals("n_obs", "random_state", "copy") -def subsample( - data: AnnData | np.ndarray | spmatrix, +@overload +def sample( + data: AnnData, fraction: float | None = None, *, - n_obs: int | None = None, - random_state: _LegacyRandom = 0, + n: int | None = None, + rng: RNGLike | SeedLike | None = 0, + copy: Literal[False] = False, + replace: bool = False, + axis: Literal["obs", 0, "var", 1] = "obs", +) -> None: ... +@overload +def sample( + data: AnnData, + fraction: float | None = None, + *, + n: int | None = None, + rng: RNGLike | SeedLike | None = None, + copy: Literal[True], + replace: bool = False, + axis: Literal["obs", 0, "var", 1] = "obs", +) -> AnnData: ... 
+@overload +def sample( + data: A, + fraction: float | None = None, + *, + n: int | None = None, + rng: RNGLike | SeedLike | None = None, copy: bool = False, -) -> AnnData | tuple[np.ndarray | spmatrix, NDArray[np.int64]] | None: + replace: bool = False, + axis: Literal["obs", 0, "var", 1] = "obs", +) -> tuple[A, NDArray[np.int64]]: ... +def sample( + data: AnnData | np.ndarray | CSMatrix | DaskArray, + fraction: float | None = None, + *, + n: int | None = None, + rng: RNGLike | SeedLike | None = None, + copy: bool = False, + replace: bool = False, + axis: Literal["obs", 0, "var", 1] = "obs", +) -> AnnData | None | tuple[np.ndarray | CSMatrix | DaskArray, NDArray[np.int64]]: """\ - Subsample to a fraction of the number of observations. + Sample observations or variables with or without replacement. Parameters ---------- @@ -843,49 +880,81 @@ def subsample( The (annotated) data matrix of shape `n_obs` × `n_vars`. Rows correspond to cells and columns to genes. fraction - Subsample to this `fraction` of the number of observations. - n_obs - Subsample to this number of observations. + Sample to this `fraction` of the number of observations or variables. + This can be larger than 1.0, if `replace=True`. + See `axis` and `replace`. + n + Sample to this number of observations or variables. See `axis`. random_state Random seed to change subsampling. copy If an :class:`~anndata.AnnData` is passed, determines whether a copy is returned. + replace + If True, samples are drawn with replacement. + axis + Sample `obs`\\ ervations (axis 0) or `var`\\ iables (axis 1). Returns ------- - Returns `X[obs_indices], obs_indices` if data is array-like, otherwise - subsamples the passed :class:`~anndata.AnnData` (`copy == False`) or - returns a subsampled copy of it (`copy == True`). + If `isinstance(data, AnnData)` and `copy=False`, + this function returns `None`. 
Otherwise: + + `data[indices, :]` | `data[:, indices]` (depending on `axis`) + If `data` is array-like or `copy=True`, returns the subset. + `indices` : numpy.ndarray + If `data` is array-like, also returns the indices into the original. """ - np.random.seed(random_state) - old_n_obs = data.n_obs if isinstance(data, AnnData) else data.shape[0] - if n_obs is not None: - new_n_obs = n_obs - elif fraction is not None: - if fraction > 1 or fraction < 0: - raise ValueError(f"`fraction` needs to be within [0, 1], not {fraction}") - new_n_obs = int(fraction * old_n_obs) - logg.debug(f"... subsampled to {new_n_obs} data points") - else: - raise ValueError("Either pass `n_obs` or `fraction`.") - obs_indices = np.random.choice(old_n_obs, size=new_n_obs, replace=False) - if isinstance(data, AnnData): - if data.isbacked: - if copy: - return data[obs_indices].to_memory() - else: - raise NotImplementedError( - "Inplace subsampling is not implemented for backed objects." - ) + # parameter validation + if not copy and isinstance(data, AnnData) and data.isbacked: + msg = "Inplace sampling (`copy=False`) is not implemented for backed objects." + raise NotImplementedError(msg) + axis, axis_name = _resolve_axis(axis) + old_n = data.shape[axis] + match (fraction, n): + case (None, None): + msg = "Either `fraction` or `n` must be set." + raise TypeError(msg) + case (None, _): + pass + case (_, None): + if fraction < 0: + msg = f"`{fraction=}` needs to be nonnegative." + raise ValueError(msg) + if not replace and fraction > 1: + msg = f"If `replace=False`, `{fraction=}` needs to be within [0, 1]." + raise ValueError(msg) + n = int(fraction * old_n) + logg.debug(f"... sampled to {n} {axis_name}") + case _: + msg = "Providing both `fraction` and `n` is not allowed." 
+ raise TypeError(msg) + del fraction + + # actually do subsampling + rng = np.random.default_rng(rng) + indices = rng.choice(old_n, size=n, replace=replace) + + # overload 1: inplace AnnData subset + if not copy and isinstance(data, AnnData): + if axis_name == "obs": + data._inplace_subset_obs(indices) else: - if copy: - return data[obs_indices].copy() - else: - data._inplace_subset_obs(obs_indices) - else: - X = data - return X[obs_indices], obs_indices + data._inplace_subset_var(indices) + return None + + subset = data[indices] if axis_name == "obs" else data[:, indices] + + # overload 2: copy AnnData subset + if copy and isinstance(data, AnnData): + assert isinstance(subset, AnnData) + return subset.to_memory() if data.isbacked else subset.copy() + + # overload 3: return array and indices + assert isinstance(subset, np.ndarray | CSMatrix | DaskArray), type(subset) + if copy: + subset = subset.copy() + return subset, indices @renamed_arg("target_counts", "counts_per_cell") diff --git a/tests/test_package_structure.py b/tests/test_package_structure.py index 834c06d8b4..3541c561a5 100644 --- a/tests/test_package_structure.py +++ b/tests/test_package_structure.py @@ -138,6 +138,7 @@ class ExpectedSig(TypedDict): copy_sigs["sc.pp.filter_cells"] = None # unclear `inplace` situation copy_sigs["sc.pp.filter_genes"] = None # unclear `inplace` situation copy_sigs["sc.pp.subsample"] = None # returns indices along matrix +copy_sigs["sc.pp.sample"] = None # returns indices along matrix # partial exceptions: “data” instead of “adata” copy_sigs["sc.pp.log1p"]["first_name"] = "data" copy_sigs["sc.pp.normalize_per_cell"]["first_name"] = "data" diff --git a/tests/test_preprocessing.py b/tests/test_preprocessing.py index b8f5115b01..36283e7ed0 100644 --- a/tests/test_preprocessing.py +++ b/tests/test_preprocessing.py @@ -1,7 +1,10 @@ from __future__ import annotations +import warnings +from importlib.util import find_spec from itertools import product from pathlib import Path 
+from typing import TYPE_CHECKING import numpy as np import pandas as pd @@ -22,6 +25,13 @@ from testing.scanpy._helpers.data import pbmc3k, pbmc68k_reduced from testing.scanpy._pytest.params import ARRAY_TYPES +if TYPE_CHECKING: + from collections.abc import Callable + from typing import Any, Literal + + CSMatrix = sp.csc_matrix | sp.csr_matrix + + HERE = Path(__file__).parent DATA_PATH = HERE / "_data" @@ -134,34 +144,128 @@ def test_normalize_per_cell(): assert adata.X.sum(axis=1).tolist() == adata_sparse.X.sum(axis=1).A1.tolist() -def test_subsample(): - adata = AnnData(np.ones((200, 10))) - sc.pp.subsample(adata, n_obs=40) - assert adata.n_obs == 40 - sc.pp.subsample(adata, fraction=0.1) - assert adata.n_obs == 4 +@pytest.mark.parametrize("array_type", ARRAY_TYPES) +@pytest.mark.parametrize("which", ["copy", "inplace", "array"]) +@pytest.mark.parametrize( + ("axis", "fraction", "n", "replace", "expected"), + [ + pytest.param(0, None, 40, False, 40, id="obs-40-no_replace"), + pytest.param(0, 0.1, None, False, 20, id="obs-0.1-no_replace"), + pytest.param(0, None, 201, True, 201, id="obs-201-replace"), + pytest.param(0, None, 1, True, 1, id="obs-1-replace"), + pytest.param(1, None, 10, False, 10, id="var-10-no_replace"), + pytest.param(1, None, 11, True, 11, id="var-11-replace"), + pytest.param(1, 2.0, None, True, 20, id="var-2.0-replace"), + ], +) +def test_sample( + *, + array_type: Callable[[np.ndarray], np.ndarray | CSMatrix], + which: Literal["copy", "inplace", "array"], + axis: Literal[0, 1], + fraction: float | None, + n: int | None, + replace: bool, + expected: int, +): + adata = AnnData(array_type(np.ones((200, 10)))) + + # ignoring this warning declaratively is a pain so do it here + if find_spec("dask"): + import dask.array as da + + warnings.filterwarnings("ignore", category=da.PerformanceWarning) + # can’t guarantee that duplicates are drawn when `replace=True`, + # so we just ignore the warning instead of using `with pytest.warns(...)` + 
warnings.filterwarnings( + "ignore" if replace else "error", r".*names are not unique", UserWarning + ) + rv = sc.pp.sample( + adata.X if which == "array" else adata, + fraction, + n=n, + replace=replace, + axis=axis, + # `copy` only affects AnnData inputs + copy=dict(copy=True, inplace=False, array=False)[which], + ) + match which: + case "copy": + subset = rv + assert rv is not adata + assert adata.shape == (200, 10) + case "inplace": + subset = adata + assert rv is None + case "array": + subset, indices = rv + assert len(indices) == expected + assert adata.shape == (200, 10) + case _: + pytest.fail(f"Unknown `{which=}`") -def test_subsample_copy(): + assert subset.shape == ((expected, 10) if axis == 0 else (200, expected)) + + +@pytest.mark.parametrize( + ("args", "exc", "pattern"), + [ + pytest.param( + dict(), TypeError, r"Either `fraction` or `n` must be set", id="empty" + ), + pytest.param( + dict(n=10, fraction=0.2), + TypeError, + r"Providing both `fraction` and `n` is not allowed", + id="both", + ), + pytest.param( + dict(fraction=2), + ValueError, + r"If `replace=False`, `fraction=2` needs to be", + id="frac>1", + ), + pytest.param( + dict(fraction=-0.3), + ValueError, + r"`fraction=-0\.3` needs to be nonnegative", + id="frac<0", + ), + ], +) +def test_sample_error(args: dict[str, Any], exc: type[Exception], pattern: str): adata = AnnData(np.ones((200, 10))) - assert sc.pp.subsample(adata, n_obs=40, copy=True).shape == (40, 10) - assert sc.pp.subsample(adata, fraction=0.1, copy=True).shape == (20, 10) + with pytest.raises(exc, match=pattern): + sc.pp.sample(adata, **args) -def test_subsample_copy_backed(tmp_path): - A = np.random.rand(200, 10).astype(np.float32) - adata_m = AnnData(A.copy()) - adata_d = AnnData(A.copy()) - filename = tmp_path / "test.h5ad" - adata_d.filename = filename - # This should not throw an error - assert sc.pp.subsample(adata_d, n_obs=40, copy=True).shape == (40, 10) +def test_sample_backwards_compat(): + expected = np.array( + 
[26, 86, 2, 55, 75, 93, 16, 73, 54, 95, 53, 92, 78, 13, 7, 30, 22, 24, 33, 8] + ) + legacy_result, indices = sc.pp.subsample(np.arange(100), n_obs=20) + assert np.array_equal(indices, legacy_result), "arange choices should match indices" + assert np.array_equal(legacy_result, expected) + + +def test_sample_copy_backed(tmp_path): + adata_m = AnnData(np.random.rand(200, 10).astype(np.float32)) + adata_d = adata_m.copy() + adata_d.filename = tmp_path / "test.h5ad" + + assert sc.pp.sample(adata_d, n=40, copy=True).shape == (40, 10) np.testing.assert_array_equal( - sc.pp.subsample(adata_m, n_obs=40, copy=True).X, - sc.pp.subsample(adata_d, n_obs=40, copy=True).X, + sc.pp.sample(adata_m, n=40, copy=True, rng=0).X, + sc.pp.sample(adata_d, n=40, copy=True, rng=0).X, ) + + +def test_sample_copy_backed_error(tmp_path): + adata_d = AnnData(np.random.rand(200, 10).astype(np.float32)) + adata_d.filename = tmp_path / "test.h5ad" with pytest.raises(NotImplementedError): - sc.pp.subsample(adata_d, n_obs=40, copy=False) + sc.pp.sample(adata_d, n=40, copy=False) @pytest.mark.parametrize("array_type", ARRAY_TYPES) diff --git a/tests/test_utils.py b/tests/test_utils.py index f8a38a5f9d..81369a6938 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -2,6 +2,7 @@ from operator import mul, truediv from types import ModuleType +from typing import TYPE_CHECKING import numpy as np import pytest @@ -9,7 +10,7 @@ from packaging.version import Version from scipy.sparse import csr_matrix, issparse -from scanpy._compat import DaskArray, pkg_version +from scanpy._compat import DaskArray, _legacy_numpy_gen, pkg_version from scanpy._utils import ( axis_mul_or_truediv, axis_sum, @@ -26,6 +27,9 @@ ARRAY_TYPES_SPARSE_DASK_UNSUPPORTED, ) +if TYPE_CHECKING: + from typing import Any + def test_descend_classes_and_funcs(): # create module hierarchy @@ -247,3 +251,39 @@ def test_is_constant_dask(request: pytest.FixtureRequest, axis, expected, block_ x = da.from_array(np.array(x_data), 
chunks=2).map_blocks(block_type) result = is_constant(x, axis=axis).compute() np.testing.assert_array_equal(expected, result) + + +@pytest.mark.parametrize("seed", [0, 1, 1256712675]) +@pytest.mark.parametrize("pass_seed", [True, False], ids=["pass_seed", "set_seed"]) +@pytest.mark.parametrize("func", ["choice"]) +def test_legacy_numpy_gen(*, seed: int, pass_seed: bool, func: str): + np.random.seed(seed) + state_before = np.random.get_state(legacy=False) + + arrs: dict[bool, np.ndarray] = {} + states_after: dict[bool, dict[str, Any]] = {} + for direct in [True, False]: + if not pass_seed: + np.random.seed(seed) + arrs[direct] = _mk_random(func, direct=direct, seed=seed if pass_seed else None) + states_after[direct] = np.random.get_state(legacy=False) + + np.testing.assert_array_equal(arrs[True], arrs[False]) + np.testing.assert_equal( + *states_after.values(), err_msg="both should affect global state the same" + ) + # they should affect the global state + with pytest.raises(AssertionError): + np.testing.assert_equal(states_after[True], state_before) + + +def _mk_random(func: str, *, direct: bool, seed: int | None) -> np.ndarray: + if direct and seed is not None: + np.random.seed(seed) + gen = np.random if direct else _legacy_numpy_gen(seed) + match func: + case "choice": + arr = np.arange(1000) + return gen.choice(arr, size=(100, 100)) + case _: + pytest.fail(f"Unknown {func=}")