From 1590770d3731db8fe3632306a0ad616990c06d36 Mon Sep 17 00:00:00 2001 From: "Lumberbot (aka Jack)" <39504233+meeseeksmachine@users.noreply.github.com> Date: Thu, 11 Jan 2024 15:58:57 +0100 Subject: [PATCH] Backport PR #1266: (fix): cache `indptr` for backed sparse matrices (#1296) Co-authored-by: Ilan Gold --- anndata/_core/sparse_dataset.py | 33 +++++++++++++++++---- anndata/tests/helpers.py | 20 +++++++++++++ anndata/tests/test_backed_sparse.py | 45 ++++++++++++++++++++++++++++- docs/release-notes/0.10.5.md | 1 + 4 files changed, 92 insertions(+), 7 deletions(-) diff --git a/anndata/_core/sparse_dataset.py b/anndata/_core/sparse_dataset.py index c2f746984..19f6f05d0 100644 --- a/anndata/_core/sparse_dataset.py +++ b/anndata/_core/sparse_dataset.py @@ -15,6 +15,7 @@ import collections.abc as cabc import warnings from abc import ABC +from functools import cached_property from itertools import accumulate, chain from math import floor from pathlib import Path @@ -41,6 +42,8 @@ if TYPE_CHECKING: from collections.abc import Iterable, Sequence + from .._types import GroupStorageType + class BackedFormat(NamedTuple): format: str @@ -138,7 +141,7 @@ def _offsets( def _get_contiguous_compressed_slice( self, s: slice ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: - new_indptr = self.indptr[s.start : s.stop + 1] + new_indptr = self.indptr[s.start : s.stop + 1].copy() start = new_indptr[0] stop = new_indptr[-1] @@ -325,13 +328,26 @@ def _get_group_format(group) -> str: class BaseCompressedSparseDataset(ABC): """Analogous to :class:`h5py.Dataset ` or `zarr.Array`, but for sparse matrices.""" - def __init__(self, group: h5py.Group | ZarrGroup): + _group: GroupStorageType + + def __init__(self, group: GroupStorageType): type(self)._check_group_format(group) - self.group = group + self._group = group shape: tuple[int, int] """Shape of the matrix.""" + @property + def group(self): + """The group underlying the backed matrix.""" + return self._group + + @group.setter + def group(self, val): + raise AttributeError( + f"Do not reset group on a {type(self)} with {val}. Instead use `sparse_dataset` to make a new class." + ) + @property def backend(self) -> Literal["zarr", "hdf5"]: if isinstance(self.group, ZarrGroup): @@ -489,12 +505,17 @@ def append(self, sparse_matrix: ss.spmatrix): indices.resize((orig_data_size + sparse_matrix.indices.shape[0],)) indices[orig_data_size:] = sparse_matrix.indices + @cached_property + def indptr(self) -> np.ndarray: + arr = self.group["indptr"][...] + return arr + def _to_backed(self) -> BackedSparseMatrix: format_class = get_backed_class(self.format) mtx = format_class(self.shape, dtype=self.dtype) mtx.data = self.group["data"] mtx.indices = self.group["indices"] - mtx.indptr = self.group["indptr"][:] + mtx.indptr = self.indptr return mtx def to_memory(self) -> ss.spmatrix: @@ -502,7 +523,7 @@ def to_memory(self) -> ss.spmatrix: mtx = format_class(self.shape, dtype=self.dtype) mtx.data = self.group["data"][...] mtx.indices = self.group["indices"][...] - mtx.indptr = self.group["indptr"][...] + mtx.indptr = self.indptr return mtx @@ -530,7 +551,7 @@ class CSCDataset(BaseCompressedSparseDataset): format = "csc" -def sparse_dataset(group: ZarrGroup | H5Group) -> CSRDataset | CSCDataset: +def sparse_dataset(group: GroupStorageType) -> CSRDataset | CSCDataset: """Generates a backed mode-compatible sparse dataset class. Parameters diff --git a/anndata/tests/helpers.py b/anndata/tests/helpers.py index 316fc991e..428b3c21b 100644 --- a/anndata/tests/helpers.py +++ b/anndata/tests/helpers.py @@ -12,6 +12,7 @@ import numpy as np import pandas as pd import pytest +import zarr from pandas.api.types import is_numeric_dtype from scipy import sparse @@ -743,3 +744,22 @@ def shares_memory_sparse(x, y): marks=pytest.mark.gpu, ), ] + + +class AccessTrackingStore(zarr.DirectoryStore): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._access_count = {} + + def __getitem__(self, key): + for tracked in self._access_count: + if tracked in key: + self._access_count[tracked] += 1 + return super().__getitem__(key) + + def get_access_count(self, key): + return self._access_count[key] + + def set_key_trackers(self, keys_to_track): + for k in keys_to_track: + self._access_count[k] = 0 diff --git a/anndata/tests/test_backed_sparse.py b/anndata/tests/test_backed_sparse.py index 128b9a773..414bcdfa2 100644 --- a/anndata/tests/test_backed_sparse.py +++ b/anndata/tests/test_backed_sparse.py @@ -12,7 +12,7 @@ from anndata._core.anndata import AnnData from anndata._core.sparse_dataset import sparse_dataset from anndata.experimental import read_dispatched -from anndata.tests.helpers import assert_equal, subset_func +from anndata.tests.helpers import AccessTrackingStore, assert_equal, subset_func if TYPE_CHECKING: from pathlib import Path @@ -199,6 +199,34 @@ def test_dataset_append_disk( assert_equal(fromdisk, frommem) +@pytest.mark.parametrize( + ["sparse_format"], + [ + pytest.param(sparse.csr_matrix), + pytest.param(sparse.csc_matrix), + ], +) +def test_indptr_cache( + tmp_path: Path, + sparse_format: Callable[[ArrayLike], sparse.spmatrix], +): + path = tmp_path / "test.zarr" # diskfmt is either h5ad or zarr + a = sparse_format(sparse.random(10, 10)) + f = zarr.open_group(path, "a") + ad._io.specs.write_elem(f, "X", a) + store = AccessTrackingStore(path) + store.set_key_trackers(["X/indptr"]) + f = zarr.open_group(store, "a") + a_disk = sparse_dataset(f["X"]) + a_disk[:1] + a_disk[3:5] + a_disk[6:7] + a_disk[8:9] + assert ( + store.get_access_count("X/indptr") == 2 + ) # one each for .zarray and actual access + + @pytest.mark.parametrize( ["sparse_format", "a_shape", "b_shape"], [ @@ -233,6 +261,21 @@ def test_wrong_shape( a_disk.append(b_disk) +def test_reset_group(tmp_path: Path): + path = tmp_path / "test.zarr" # diskfmt is either h5ad or zarr + base = sparse.random(100, 100, format="csr") + + if diskfmt == "zarr": + f = zarr.open_group(path, "a") + else: + f = h5py.File(path, "a") + + ad._io.specs.write_elem(f, "base", base) + disk_mtx = sparse_dataset(f["base"]) + with pytest.raises(AttributeError): + disk_mtx.group = f + + def test_wrong_formats(tmp_path: Path, diskfmt: Literal["h5ad", "zarr"]): path = ( tmp_path / f"test.{diskfmt.replace('ad', '')}" diff --git a/docs/release-notes/0.10.5.md b/docs/release-notes/0.10.5.md index 0aed44ae4..5310ad681 100644 --- a/docs/release-notes/0.10.5.md +++ b/docs/release-notes/0.10.5.md @@ -11,4 +11,5 @@ ```{rubric} Performance ``` +* `BaseCompressedSparseDataset`'s `indptr` is cached {pr}`1266` {user}`ilan-gold` * Improved performance when indexing backed sparse matrices with boolean masks along their major axis {pr}`1233` {user}`ilan-gold`