Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(fix): empty boolean mask on backed sparse matrix #1321

Merged
merged 12 commits into from
Jan 25, 2024
8 changes: 6 additions & 2 deletions anndata/_core/sparse_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,8 @@ def _get_sliceXslice(self, row: slice, col: slice) -> ss.csr_matrix:

def _get_arrayXslice(self, row: Sequence[int], col: slice) -> ss.csr_matrix:
idxs = np.asarray(row)
if len(idxs) == 0:
return ss.csr_matrix((0, self.shape[1]))
if idxs.dtype == bool:
idxs = np.where(idxs)
return ss.csr_matrix(
Expand Down Expand Up @@ -214,6 +216,8 @@ def _get_sliceXslice(self, row: slice, col: slice) -> ss.csc_matrix:

def _get_sliceXarray(self, row: slice, col: Sequence[int]) -> ss.csc_matrix:
idxs = np.asarray(col)
if len(idxs) == 0:
return ss.csc_matrix((self.shape[0], 0))
if idxs.dtype == bool:
idxs = np.where(idxs)
return ss.csc_matrix(
Expand Down Expand Up @@ -409,11 +413,11 @@ def __getitem__(self, index: Index | tuple[()]) -> float | ss.spmatrix:
mtx = self._to_backed()

# Handle masked indexing along major axis
if self.format == "csr" and np.array(row).dtype == bool:
if self.format == "csr" and np.array(row).dtype == bool and row.sum() != 0:
ilan-gold marked this conversation as resolved.
Show resolved Hide resolved
sub = ss.csr_matrix(
subset_by_major_axis_mask(mtx, row), shape=(row.sum(), mtx.shape[1])
)[:, col]
elif self.format == "csc" and np.array(col).dtype == bool:
elif self.format == "csc" and np.array(col).dtype == bool and col.sum() != 0:
sub = ss.csc_matrix(
subset_by_major_axis_mask(mtx, col), shape=(mtx.shape[0], col.sum())
)[row, :]
Expand Down
26 changes: 25 additions & 1 deletion anndata/tests/test_backed_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from anndata.tests.helpers import AccessTrackingStore, assert_equal, subset_func

if TYPE_CHECKING:
from collections.abc import Iterable
from pathlib import Path

from numpy.typing import ArrayLike
Expand All @@ -27,6 +28,10 @@ def diskfmt(request):
return request.param


M = 50
N = 50


@pytest.fixture(scope="function")
def ondisk_equivalent_adata(
tmp_path: Path, diskfmt: Literal["h5ad", "zarr"]
Expand All @@ -37,7 +42,7 @@ def ondisk_equivalent_adata(

write = lambda x, pth, **kwargs: getattr(x, f"write_{diskfmt}")(pth, **kwargs)

csr_mem = ad.AnnData(X=sparse.random(50, 50, format="csr", density=0.1))
csr_mem = ad.AnnData(X=sparse.random(M, N, format="csr", density=0.1))
csc_mem = ad.AnnData(X=csr_mem.X.tocsc())
dense_mem = ad.AnnData(X=csr_mem.X.toarray())

Expand Down Expand Up @@ -77,6 +82,25 @@ def callback(func, elem_name, elem, iospec):
return csr_mem, csr_disk, csc_disk, dense_disk


@pytest.mark.parametrize(
"empty_mask", [[], np.zeros(M, dtype=bool)], ids=["empty_list", "empty_bool_mask"]
)
def test_empty_backed_indexing(
ilan-gold marked this conversation as resolved.
Show resolved Hide resolved
ondisk_equivalent_adata: tuple[AnnData, AnnData, AnnData, AnnData],
empty_mask: Iterable[bool],
ilan-gold marked this conversation as resolved.
Show resolved Hide resolved
):
csr_mem, csr_disk, csc_disk, _ = ondisk_equivalent_adata

assert_equal(csr_mem.X[empty_mask], csr_disk.X[empty_mask])
assert_equal(csr_mem.X[:, empty_mask], csc_disk.X[:, empty_mask])

# The following do not work because of https://github.com/scipy/scipy/issues/19919
# Our implementation returns a (0,0) sized matrix but scipy does (1,0).

# assert_equal(csr_mem.X[empty_mask, empty_mask], csr_disk.X[empty_mask, empty_mask])
# assert_equal(csr_mem.X[empty_mask, empty_mask], csc_disk.X[empty_mask, empty_mask])


def test_backed_indexing(
ondisk_equivalent_adata: tuple[AnnData, AnnData, AnnData, AnnData],
subset_func,
Expand Down
Loading