Skip to content

Commit

Permalink
Merge branch 'main' into ig/backed_sparse_indexing_performance
Browse files Browse the repository at this point in the history
  • Loading branch information
ivirshup committed Jan 10, 2024
2 parents 8d9a3b3 + 522d7ea commit 27f9f2c
Show file tree
Hide file tree
Showing 20 changed files with 210 additions and 149 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# Caches for compiled and downloaded files
__pycache__/
/*cache/
/node_modules/
/data/

# Distribution / packaging
Expand Down
15 changes: 7 additions & 8 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
repos:
- repo: https://github.com/psf/black
rev: 23.11.0
hooks:
- id: black
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: "v0.1.7"
rev: v0.1.11
hooks:
- id: ruff
types_or: [python, pyi, jupyter]
args: ["--fix"]
- id: ruff-format
types_or: [python, pyi, jupyter]
- repo: https://github.com/pre-commit/mirrors-prettier
rev: v4.0.0-alpha.4
rev: v4.0.0-alpha.8
hooks:
- id: prettier
exclude_types:
- markdown
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
hooks:
Expand All @@ -26,7 +26,6 @@ repos:
- id: detect-private-key
- id: no-commit-to-branch
args: ["--branch=main"]

- repo: https://github.com/codespell-project/codespell
rev: v2.2.6
hooks:
Expand Down
2 changes: 1 addition & 1 deletion anndata/_core/anndata.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ class StorageType(Enum):
DaskArray = DaskArray
CupyArray = CupyArray
CupySparseMatrix = CupySparseMatrix
BackedSparseMAtrix = BaseCompressedSparseDataset
BackedSparseMatrix = BaseCompressedSparseDataset

@classmethod
def classes(cls):
Expand Down
12 changes: 9 additions & 3 deletions anndata/_core/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -1108,12 +1108,18 @@ def concat(
... X=sparse.csr_matrix(np.array([[0, 1], [2, 3]])),
... obs=pd.DataFrame({"group": ["a", "b"]}, index=["s1", "s2"]),
... var=pd.DataFrame(index=["var1", "var2"]),
... varm={"ones": np.ones((2, 5)), "rand": np.random.randn(2, 3), "zeros": np.zeros((2, 5))},
... varm={
... "ones": np.ones((2, 5)),
... "rand": np.random.randn(2, 3),
... "zeros": np.zeros((2, 5)),
... },
... uns={"a": 1, "b": 2, "c": {"c.a": 3, "c.b": 4}},
... )
>>> b = ad.AnnData(
... X=sparse.csr_matrix(np.array([[4, 5, 6], [7, 8, 9]])),
... obs=pd.DataFrame({"group": ["b", "c"], "measure": [1.2, 4.3]}, index=["s3", "s4"]),
... obs=pd.DataFrame(
... {"group": ["b", "c"], "measure": [1.2, 4.3]}, index=["s3", "s4"]
... ),
... var=pd.DataFrame(index=["var1", "var2", "var3"]),
... varm={"ones": np.ones((3, 5)), "rand": np.random.randn(3, 5)},
... uns={"a": 1, "b": 3, "c": {"c.b": 4}},
Expand Down Expand Up @@ -1147,7 +1153,7 @@ def concat(
>>> (inner.obs_names, inner.var_names) # doctest: +NORMALIZE_WHITESPACE
(Index(['s1', 's2', 's3', 's4'], dtype='object'),
Index(['var1', 'var2'], dtype='object'))
>>> outer = ad.concat([a, b], join="outer") # Joining on union of variables
>>> outer = ad.concat([a, b], join="outer") # Joining on union of variables
>>> outer
AnnData object with n_obs × n_vars = 4 × 3
obs: 'group', 'measure'
Expand Down
12 changes: 6 additions & 6 deletions anndata/_core/sparse_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,12 +547,12 @@ def sparse_dataset(group: ZarrGroup | H5Group) -> CSRDataset | CSCDataset:
>>> import zarr
>>> from anndata.experimental import sparse_dataset
>>> group = zarr.open_group('./my_test_store.zarr')
>>> group['data'] = [10, 20, 30, 40, 50, 60, 70, 80]
>>> group['indices'] = [0, 1, 1, 3, 2, 3, 4, 5]
>>> group['indptr'] = [0, 2, 4, 7, 8]
>>> group.attrs['shape'] = (4, 6)
>>> group.attrs['encoding-type'] = 'csr_matrix'
>>> group = zarr.open_group("./my_test_store.zarr")
>>> group["data"] = [10, 20, 30, 40, 50, 60, 70, 80]
>>> group["indices"] = [0, 1, 1, 3, 2, 3, 4, 5]
>>> group["indptr"] = [0, 2, 4, 7, 8]
>>> group.attrs["shape"] = (4, 6)
>>> group.attrs["encoding-type"] = "csr_matrix"
>>> sparse_dataset(group)
CSRDataset: backend zarr, shape (4, 6), data_dtype int64
"""
Expand Down
19 changes: 13 additions & 6 deletions anndata/_io/h5ad.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from types import MappingProxyType
from typing import (
TYPE_CHECKING,
Any,
Callable,
Literal,
TypeVar,
Expand Down Expand Up @@ -112,7 +113,13 @@ def write_h5ad(

@report_write_key_on_error
@write_spec(IOSpec("array", "0.2.0"))
def write_sparse_as_dense(f, key, value, dataset_kwargs=MappingProxyType({})):
def write_sparse_as_dense(
f: h5py.Group,
key: str,
value: sparse.spmatrix | BaseCompressedSparseDataset,
*,
dataset_kwargs: Mapping[str, Any] = MappingProxyType({}),
):
real_key = None # Flag for if temporary key was used
if key in f:
if isinstance(value, BaseCompressedSparseDataset) and (
Expand Down Expand Up @@ -269,7 +276,7 @@ def callback(func, elem_name: str, elem, iospec):
def _read_raw(
f: h5py.File | AnnDataFileManager,
as_sparse: Collection[str] = (),
rdasp: Callable[[h5py.Dataset], sparse.spmatrix] = None,
rdasp: Callable[[h5py.Dataset], sparse.spmatrix] | None = None,
*,
attrs: Collection[str] = ("X", "var", "varm"),
) -> dict:
Expand All @@ -286,7 +293,7 @@ def _read_raw(


@report_read_key_on_error
def read_dataframe_legacy(dataset) -> pd.DataFrame:
def read_dataframe_legacy(dataset: h5py.Dataset) -> pd.DataFrame:
"""Read pre-anndata 0.7 dataframes."""
warn(
f"'{dataset.name}' was written with a very old version of AnnData. "
Expand All @@ -305,7 +312,7 @@ def read_dataframe_legacy(dataset) -> pd.DataFrame:
return df


def read_dataframe(group) -> pd.DataFrame:
def read_dataframe(group: h5py.Group | h5py.Dataset) -> pd.DataFrame:
"""Backwards compat function"""
if not isinstance(group, h5py.Group):
return read_dataframe_legacy(group)
Expand Down Expand Up @@ -352,7 +359,7 @@ def read_dense_as_sparse(
raise ValueError(f"Cannot read dense array as type: {sparse_format}")


def read_dense_as_csr(dataset, axis_chunk=6000):
def read_dense_as_csr(dataset: h5py.Dataset, axis_chunk: int = 6000):
sub_matrices = []
for idx in idx_chunks_along_axis(dataset.shape, 0, axis_chunk):
dense_chunk = dataset[idx]
Expand All @@ -361,7 +368,7 @@ def read_dense_as_csr(dataset, axis_chunk=6000):
return sparse.vstack(sub_matrices, format="csr")


def read_dense_as_csc(dataset, axis_chunk=6000):
def read_dense_as_csc(dataset: h5py.Dataset, axis_chunk: int = 6000):
sub_matrices = []
for idx in idx_chunks_along_axis(dataset.shape, 1, axis_chunk):
sub_matrix = sparse.csc_matrix(dataset[idx])
Expand Down
81 changes: 36 additions & 45 deletions anndata/_io/specs/registry.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from collections.abc import Callable, Iterable, Mapping
from collections.abc import Mapping
from dataclasses import dataclass
from functools import singledispatch, wraps
from types import MappingProxyType
Expand All @@ -10,12 +10,13 @@
from anndata.compat import _read_attr

if TYPE_CHECKING:
from collections.abc import Callable, Generator, Iterable

from anndata._types import GroupStorageType, StorageType


# TODO: This probably should be replaced by a hashable Mapping due to conversion b/w "_" and "-"
# TODO: Should filetype be included in the IOSpec if it changes the encoding? Or does the intent that these things be "the same" overrule that?


@dataclass(frozen=True)
class IOSpec:
encoding_type: str
Expand All @@ -25,7 +26,9 @@ class IOSpec:
# TODO: Should this subclass from LookupError?
class IORegistryError(Exception):
@classmethod
def _from_write_parts(cls, dest_type, typ, modifiers) -> IORegistryError:
def _from_write_parts(
cls, dest_type: type, typ: type, modifiers: frozenset[str]
) -> IORegistryError:
msg = f"No method registered for writing {typ} into {dest_type}"
if modifiers:
msg += f" with {modifiers}"
Expand All @@ -36,7 +39,7 @@ def _from_read_parts(
cls,
method: str,
registry: Mapping,
src_typ: StorageType,
src_typ: type[StorageType],
spec: IOSpec,
) -> IORegistryError:
# TODO: Improve error message if type exists, but version does not
Expand All @@ -50,7 +53,7 @@ def _from_read_parts(
def write_spec(spec: IOSpec):
def decorator(func: Callable):
@wraps(func)
def wrapper(g, k, *args, **kwargs):
def wrapper(g: GroupStorageType, k: str, *args, **kwargs):
result = func(g, k, *args, **kwargs)
g[k].attrs.setdefault("encoding-type", spec.encoding_type)
g[k].attrs.setdefault("encoding-version", spec.encoding_version)
Expand Down Expand Up @@ -193,12 +196,12 @@ def proc_spec(spec) -> IOSpec:


@proc_spec.register(IOSpec)
def proc_spec_spec(spec) -> IOSpec:
def proc_spec_spec(spec: IOSpec) -> IOSpec:
return spec


@proc_spec.register(Mapping)
def proc_spec_mapping(spec) -> IOSpec:
def proc_spec_mapping(spec: Mapping[str, str]) -> IOSpec:
return IOSpec(**{k.replace("-", "_"): v for k, v in spec.items()})


Expand All @@ -213,7 +216,9 @@ def get_spec(
)


def _iter_patterns(elem):
def _iter_patterns(
elem,
) -> Generator[tuple[type, type | str] | tuple[type, type, str], None, None]:
"""Iterates over possible patterns for an element in order of precedence."""
from anndata.compat import DaskArray

Expand All @@ -236,40 +241,27 @@ def __init__(self, registry: IORegistry, callback: Callable | None = None) -> No
def read_elem(
self,
elem: StorageType,
modifiers: frozenset(str) = frozenset(),
modifiers: frozenset[str] = frozenset(),
) -> Any:
"""Read an element from a store. See exported function for more details."""
from functools import partial

read_func = self.registry.get_reader(
type(elem), get_spec(elem), frozenset(modifiers)
iospec = get_spec(elem)
read_func = partial(
self.registry.get_reader(type(elem), iospec, modifiers),
_reader=self,
)
read_func = partial(read_func, _reader=self)
if self.callback is not None:
return self.callback(read_func, elem.name, elem, iospec=get_spec(elem))
else:
if self.callback is None:
return read_func(elem)
return self.callback(read_func, elem.name, elem, iospec=iospec)


class Writer:
def __init__(
self,
registry: IORegistry,
callback: Callable[
[
GroupStorageType,
str,
StorageType,
dict,
],
None,
]
| None = None,
):
def __init__(self, registry: IORegistry, callback: Callable | None = None):
self.registry = registry
self.callback = callback

def find_writer(self, dest_type, elem, modifiers):
def find_writer(self, dest_type: type, elem, modifiers: frozenset[str]):
for pattern in _iter_patterns(elem):
if self.registry.has_writer(dest_type, pattern, modifiers):
return self.registry.get_writer(dest_type, pattern, modifiers)
Expand All @@ -281,10 +273,10 @@ def write_elem(
self,
store: GroupStorageType,
k: str,
elem,
elem: Any,
*,
dataset_kwargs=MappingProxyType({}),
modifiers=frozenset(),
dataset_kwargs: Mapping[str, Any] = MappingProxyType({}),
modifiers: frozenset[str] = frozenset(),
):
from functools import partial
from pathlib import PurePosixPath
Expand Down Expand Up @@ -313,17 +305,16 @@ def write_elem(
_writer=self,
)

if self.callback is not None:
return self.callback(
write_func,
store,
k,
elem,
dataset_kwargs=dataset_kwargs,
iospec=self.registry.get_spec(elem),
)
else:
if self.callback is None:
return write_func(store, k, elem, dataset_kwargs=dataset_kwargs)
return self.callback(
write_func,
store,
k,
elem,
dataset_kwargs=dataset_kwargs,
iospec=self.registry.get_spec(elem),
)


def read_elem(elem: StorageType) -> Any:
Expand All @@ -346,7 +337,7 @@ def write_elem(
k: str,
elem: Any,
*,
dataset_kwargs: Mapping = MappingProxyType({}),
dataset_kwargs: Mapping[str, Any] = MappingProxyType({}),
) -> None:
"""
Write an element to a storage group using anndata encoding.
Expand Down
Loading

0 comments on commit 27f9f2c

Please sign in to comment.