Skip to content

Commit

Permalink
Backport PR #1266: (fix): cache indptr for backed sparse matrices (#…
Browse files Browse the repository at this point in the history
…1296)

Co-authored-by: Ilan Gold <[email protected]>
  • Loading branch information
meeseeksmachine and ilan-gold authored Jan 11, 2024
1 parent 770e97b commit 1590770
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 7 deletions.
33 changes: 27 additions & 6 deletions anndata/_core/sparse_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import collections.abc as cabc
import warnings
from abc import ABC
from functools import cached_property
from itertools import accumulate, chain
from math import floor
from pathlib import Path
Expand All @@ -41,6 +42,8 @@
if TYPE_CHECKING:
from collections.abc import Iterable, Sequence

from .._types import GroupStorageType


class BackedFormat(NamedTuple):
format: str
Expand Down Expand Up @@ -138,7 +141,7 @@ def _offsets(
def _get_contiguous_compressed_slice(
self, s: slice
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
new_indptr = self.indptr[s.start : s.stop + 1]
new_indptr = self.indptr[s.start : s.stop + 1].copy()

start = new_indptr[0]
stop = new_indptr[-1]
Expand Down Expand Up @@ -325,13 +328,26 @@ def _get_group_format(group) -> str:
class BaseCompressedSparseDataset(ABC):
"""Analogous to :class:`h5py.Dataset <h5py:Dataset>` or `zarr.Array`, but for sparse matrices."""

# Backing storage group; written once in ``__init__`` and exposed read-only
# through the ``group`` property.
_group: GroupStorageType

def __init__(self, group: GroupStorageType):
    """Wrap ``group`` after passing it through ``_check_group_format``.

    Parameters
    ----------
    group
        Storage group that backs this sparse dataset.
    """
    type(self)._check_group_format(group)
    # Assign the private attribute directly: assigning to ``self.group``
    # goes through a setter that always raises ``AttributeError``.
    self._group = group

shape: tuple[int, int]
"""Shape of the matrix."""

@property
def group(self):
    """Return the storage group held in ``_group`` (read-only access)."""
    underlying = self._group
    return underlying

@group.setter
def group(self, val):
    # Rebinding the backing group is refused unconditionally.
    message = (
        f"Do not reset group on a {type(self)} with {val}. "
        "Instead use `sparse_dataset` to make a new class."
    )
    raise AttributeError(message)

@property
def backend(self) -> Literal["zarr", "hdf5"]:
if isinstance(self.group, ZarrGroup):
Expand Down Expand Up @@ -489,20 +505,25 @@ def append(self, sparse_matrix: ss.spmatrix):
indices.resize((orig_data_size + sparse_matrix.indices.shape[0],))
indices[orig_data_size:] = sparse_matrix.indices

@cached_property
def indptr(self) -> np.ndarray:
    """Return the ``indptr`` array read from the backing group.

    The value is ``self.group["indptr"][...]``; ``cached_property`` stores
    it on the instance, so later accesses reuse the same array.
    """
    return self.group["indptr"][...]

def _to_backed(self) -> BackedSparseMatrix:
    """Build a backed sparse matrix over this dataset's group.

    ``data`` and ``indices`` are assigned straight from the group without
    indexing them; ``indptr`` is the cached array from ``self.indptr``.
    """
    format_class = get_backed_class(self.format)
    mtx = format_class(self.shape, dtype=self.dtype)
    mtx.data = self.group["data"]
    mtx.indices = self.group["indices"]
    # Use only the cached array: also indexing ``self.group["indptr"]`` here
    # would read it from storage again on every call and then be overwritten.
    mtx.indptr = self.indptr
    return mtx

def to_memory(self) -> ss.spmatrix:
    """Load the full matrix from the group into an in-memory sparse matrix.

    ``data`` and ``indices`` are read with ``[...]``; ``indptr`` is a copy
    of the cached ``self.indptr``.
    """
    format_class = get_memory_class(self.format)
    mtx = format_class(self.shape, dtype=self.dtype)
    mtx.data = self.group["data"][...]
    mtx.indices = self.group["indices"][...]
    # Hand out a copy: the returned matrix belongs to the caller, and an
    # in-place edit of its ``indptr`` must not alter the cached array that
    # later reads of this dataset rely on.
    mtx.indptr = self.indptr.copy()
    return mtx


Expand Down Expand Up @@ -530,7 +551,7 @@ class CSCDataset(BaseCompressedSparseDataset):
format = "csc"


def sparse_dataset(group: ZarrGroup | H5Group) -> CSRDataset | CSCDataset:
def sparse_dataset(group: GroupStorageType) -> CSRDataset | CSCDataset:
"""Generates a backed mode-compatible sparse dataset class.
Parameters
Expand Down
20 changes: 20 additions & 0 deletions anndata/tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import numpy as np
import pandas as pd
import pytest
import zarr
from pandas.api.types import is_numeric_dtype
from scipy import sparse

Expand Down Expand Up @@ -743,3 +744,22 @@ def shares_memory_sparse(x, y):
marks=pytest.mark.gpu,
),
]


class AccessTrackingStore(zarr.DirectoryStore):
    """``zarr.DirectoryStore`` that counts ``__getitem__`` calls per tracked name.

    Names are registered with ``set_key_trackers``. A lookup increments the
    counter of every registered name for which ``name in key`` is true.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Maps each tracked name to the number of matching lookups so far.
        self._access_count = {}

    def __getitem__(self, key):
        matching = [name for name in self._access_count if name in key]
        for name in matching:
            self._access_count[name] += 1
        return super().__getitem__(key)

    def get_access_count(self, key):
        """Return the counter for ``key``; ``KeyError`` if it is not tracked."""
        return self._access_count[key]

    def set_key_trackers(self, keys_to_track):
        """Start (or restart) counting for each name, beginning at zero."""
        self._access_count.update(dict.fromkeys(keys_to_track, 0))
45 changes: 44 additions & 1 deletion anndata/tests/test_backed_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from anndata._core.anndata import AnnData
from anndata._core.sparse_dataset import sparse_dataset
from anndata.experimental import read_dispatched
from anndata.tests.helpers import assert_equal, subset_func
from anndata.tests.helpers import AccessTrackingStore, assert_equal, subset_func

if TYPE_CHECKING:
from pathlib import Path
Expand Down Expand Up @@ -199,6 +199,34 @@ def test_dataset_append_disk(
assert_equal(fromdisk, frommem)


@pytest.mark.parametrize(
    ["sparse_format"],
    [
        pytest.param(sparse.csr_matrix),
        pytest.param(sparse.csc_matrix),
    ],
)
def test_indptr_cache(
    tmp_path: Path,
    sparse_format: Callable[[ArrayLike], sparse.spmatrix],
):
    """Slicing a backed matrix several times must not re-read ``X/indptr``."""
    path = tmp_path / "test.zarr"
    matrix = sparse_format(sparse.random(10, 10))
    writer = zarr.open_group(path, "a")
    ad._io.specs.write_elem(writer, "X", matrix)

    # Reopen through the counting store so reads of ``X/indptr`` are recorded.
    store = AccessTrackingStore(path)
    store.set_key_trackers(["X/indptr"])
    reader = zarr.open_group(store, "a")
    a_disk = sparse_dataset(reader["X"])
    for rows in (slice(None, 1), slice(3, 5), slice(6, 7), slice(8, 9)):
        a_disk[rows]

    # one each for .zarray and actual access
    assert store.get_access_count("X/indptr") == 2


@pytest.mark.parametrize(
["sparse_format", "a_shape", "b_shape"],
[
Expand Down Expand Up @@ -233,6 +261,21 @@ def test_wrong_shape(
a_disk.append(b_disk)


def test_reset_group(tmp_path: Path, diskfmt: Literal["h5ad", "zarr"]):
    """Assigning to ``group`` on a backed sparse dataset raises ``AttributeError``."""
    # ``diskfmt`` is declared as a parameter so pytest injects the fixture
    # value that the branch below compares against; the file suffix follows
    # the chosen format ("h5ad" -> "h5", "zarr" -> "zarr").
    path = tmp_path / f"test.{diskfmt.replace('ad', '')}"
    base = sparse.random(100, 100, format="csr")

    if diskfmt == "zarr":
        f = zarr.open_group(path, "a")
    else:
        f = h5py.File(path, "a")

    ad._io.specs.write_elem(f, "base", base)
    disk_mtx = sparse_dataset(f["base"])
    with pytest.raises(AttributeError):
        disk_mtx.group = f


def test_wrong_formats(tmp_path: Path, diskfmt: Literal["h5ad", "zarr"]):
path = (
tmp_path / f"test.{diskfmt.replace('ad', '')}"
Expand Down
1 change: 1 addition & 0 deletions docs/release-notes/0.10.5.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,5 @@
```{rubric} Performance
```

* `BaseCompressedSparseDataset`'s `indptr` is cached {pr}`1266` {user}`ilan-gold`
* Improved performance when indexing backed sparse matrices with boolean masks along their major axis {pr}`1233` {user}`ilan-gold`

0 comments on commit 1590770

Please sign in to comment.