diff --git a/anndata/tests/helpers.py b/anndata/tests/helpers.py index 316fc991e..428b3c21b 100644 --- a/anndata/tests/helpers.py +++ b/anndata/tests/helpers.py @@ -12,6 +12,7 @@ import numpy as np import pandas as pd import pytest +import zarr from pandas.api.types import is_numeric_dtype from scipy import sparse @@ -743,3 +744,22 @@ def shares_memory_sparse(x, y): marks=pytest.mark.gpu, ), ] + + +class AccessTrackingStore(zarr.DirectoryStore): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._access_count = {} + + def __getitem__(self, key): + for tracked in self._access_count: + if tracked in key: + self._access_count[tracked] += 1 + return super().__getitem__(key) + + def get_access_count(self, key): + return self._access_count[key] + + def set_key_trackers(self, keys_to_track): + for k in keys_to_track: + self._access_count[k] = 0 diff --git a/anndata/tests/test_backed_sparse.py b/anndata/tests/test_backed_sparse.py index 0643cbcaf..f5e8be3a7 100644 --- a/anndata/tests/test_backed_sparse.py +++ b/anndata/tests/test_backed_sparse.py @@ -13,7 +13,7 @@ from anndata._core.anndata import AnnData from anndata._core.sparse_dataset import sparse_dataset from anndata.experimental import read_dispatched -from anndata.tests.helpers import assert_equal, subset_func +from anndata.tests.helpers import AccessTrackingStore, assert_equal, subset_func if TYPE_CHECKING: from pathlib import Path @@ -164,6 +164,34 @@ def test_dataset_append_disk( assert_equal(fromdisk, frommem) +@pytest.mark.parametrize( + ["sparse_format"], + [ + pytest.param(sparse.csr_matrix), + pytest.param(sparse.csc_matrix), + ], +) +def test_indptr_cache( + tmp_path: Path, + sparse_format: Callable[[ArrayLike], sparse.spmatrix], +): + path = tmp_path / "test.zarr" # diskfmt is either h5ad or zarr + a = sparse_format(sparse.random(10, 10)) + store = AccessTrackingStore(path) + store.set_key_trackers("indptr") + f = zarr.open_group(store, "a") + ad._io.specs.write_elem(f, "a", a) + a_disk = sparse_dataset(f["a"]) + a_disk[:1] + a_disk[3:5] + a_disk[6:7] + a_disk[8:9] + a_disk[...] + assert ( + store.get_access_count("indptr") == 3 + ) # one each for .zarray, .zattrs, and actual access + + @pytest.mark.parametrize( ["sparse_format", "a_shape", "b_shape"], [ @@ -198,6 +226,21 @@ def test_wrong_shape( a_disk.append(b_disk) +def test_reset_group(tmp_path: Path): + path = tmp_path / "test.zarr" # diskfmt is either h5ad or zarr + base = sparse.random(100, 100, format="csr") + + if diskfmt == "zarr": + f = zarr.open_group(path, "a") + else: + f = h5py.File(path, "a") + + ad._io.specs.write_elem(f, "base", base) + disk_mtx = sparse_dataset(f["base"]) + with pytest.raises(TypeError): + disk_mtx.group = f + + def test_wrong_formats(tmp_path: Path, diskfmt: Literal["h5ad", "zarr"]): path = ( tmp_path / f"test.{diskfmt.replace('ad', '')}"