Skip to content

Commit

Permalink
(chore): add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ilan-gold committed Dec 14, 2023
1 parent 6622ba5 commit 9b48a9d
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 1 deletion.
20 changes: 20 additions & 0 deletions anndata/tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import numpy as np
import pandas as pd
import pytest
import zarr
from pandas.api.types import is_numeric_dtype
from scipy import sparse

Expand Down Expand Up @@ -743,3 +744,22 @@ def shares_memory_sparse(x, y):
marks=pytest.mark.gpu,
),
]


class AccessTrackingStore(zarr.DirectoryStore):
    """Directory store that counts reads of selected keys.

    Register the keys of interest with :meth:`set_key_trackers`. Every
    ``__getitem__`` call whose key *contains* a registered tracker (substring
    match) increments that tracker's counter. Read the counter back with
    :meth:`get_access_count`.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Maps each tracked substring to the number of matching reads so far.
        self._access_count = {}

    def __getitem__(self, key):
        """Return the value stored under ``key``, recording the access."""
        # Counting happens before the lookup, so a read that ends up raising
        # is still recorded as an access attempt.
        for tracked in self._access_count:
            if tracked in key:
                self._access_count[tracked] += 1
        return super().__getitem__(key)

    def get_access_count(self, key) -> int:
        """Return how many reads matched the tracker ``key``.

        Raises
        ------
        KeyError
            If ``key`` was never registered via :meth:`set_key_trackers`.
        """
        return self._access_count[key]

    def set_key_trackers(self, keys_to_track) -> None:
        """Start (or restart) counting reads for each key in ``keys_to_track``.

        ``keys_to_track`` is an iterable of strings. A single string is
        accepted as shorthand for a one-element list; without that guard it
        would be iterated character by character and register one tracker
        per letter instead of one for the whole key.
        """
        if isinstance(keys_to_track, str):
            keys_to_track = [keys_to_track]
        for k in keys_to_track:
            self._access_count[k] = 0
45 changes: 44 additions & 1 deletion anndata/tests/test_backed_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from anndata._core.anndata import AnnData
from anndata._core.sparse_dataset import sparse_dataset
from anndata.experimental import read_dispatched
from anndata.tests.helpers import assert_equal, subset_func
from anndata.tests.helpers import AccessTrackingStore, assert_equal, subset_func

if TYPE_CHECKING:
from pathlib import Path
Expand Down Expand Up @@ -164,6 +164,34 @@ def test_dataset_append_disk(
assert_equal(fromdisk, frommem)


@pytest.mark.parametrize(
    ["sparse_format"],
    [
        pytest.param(sparse.csr_matrix),
        pytest.param(sparse.csc_matrix),
    ],
)
def test_indptr_cache(
    tmp_path: Path,
    sparse_format: Callable[[ArrayLike], sparse.spmatrix],
):
    """Check that repeated slicing of a backed sparse matrix does not re-read ``indptr``.

    The matrix is written through an ``AccessTrackingStore`` and then sliced
    several times; the number of store reads whose key contains ``"indptr"``
    must stay at the expected constant.
    """
    path = tmp_path / "test.zarr"  # this test only exercises the zarr backend
    a = sparse_format(sparse.random(10, 10))
    store = AccessTrackingStore(path)
    # Pass the tracked key in a list: set_key_trackers expects an iterable of keys.
    store.set_key_trackers(["indptr"])
    f = zarr.open_group(store, "a")
    ad._io.specs.write_elem(f, "a", a)
    a_disk = sparse_dataset(f["a"])
    # Several independent slices; the results are discarded, only the reads matter.
    a_disk[:1]
    a_disk[3:5]
    a_disk[6:7]
    a_disk[8:9]
    a_disk[...]
    assert (
        store.get_access_count("indptr") == 3
    )  # one each for .zarray, .zattrs, and actual access


@pytest.mark.parametrize(
["sparse_format", "a_shape", "b_shape"],
[
Expand Down Expand Up @@ -198,6 +226,21 @@ def test_wrong_shape(
a_disk.append(b_disk)


def test_reset_group(tmp_path: Path, diskfmt: Literal["h5ad", "zarr"]):
    """Check that assigning to ``group`` on a backed sparse dataset raises ``TypeError``.

    ``diskfmt`` must be requested as a parameter: the body branches on it, and
    without the parameter the comparison would not see the parametrized value,
    so only one backend would ever be exercised.
    """
    # "h5ad" -> "test.h5", "zarr" -> "test.zarr"
    path = tmp_path / f"test.{diskfmt.replace('ad', '')}"
    base = sparse.random(100, 100, format="csr")

    if diskfmt == "zarr":
        f = zarr.open_group(path, "a")
    else:
        f = h5py.File(path, "a")

    ad._io.specs.write_elem(f, "base", base)
    disk_mtx = sparse_dataset(f["base"])
    with pytest.raises(TypeError):
        disk_mtx.group = f


def test_wrong_formats(tmp_path: Path, diskfmt: Literal["h5ad", "zarr"]):
path = (
tmp_path / f"test.{diskfmt.replace('ad', '')}"
Expand Down

0 comments on commit 9b48a9d

Please sign in to comment.