Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace spmatrix with _CSMatrix as appropriate #3431

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
3 changes: 0 additions & 3 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,9 +227,6 @@ def setup(app: Sphinx):
("py:class", "scanpy._utils.Empty"),
("py:class", "numpy.random.mtrand.RandomState"),
("py:class", "scanpy.neighbors._types.KnnTransformerLike"),
# Will work once scipy 1.8 is released
("py:class", "scipy.sparse.base.spmatrix"),
("py:class", "scipy.sparse.csr.csr_matrix"),
]

# Options for plot examples
Expand Down
69 changes: 40 additions & 29 deletions src/scanpy/_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,17 +49,23 @@
else:
from anndata._core.sparse_dataset import SparseDataset

_CSMatrix = sparse.csr_matrix | sparse.csc_matrix

if TYPE_CHECKING:
from collections.abc import Callable, Iterable, KeysView, Mapping
from pathlib import Path
from typing import Any, TypeVar

from anndata import AnnData
from igraph import Graph
from numpy.typing import ArrayLike, DTypeLike, NDArray

from .._compat import _LegacyRandom
from ..neighbors import NeighborsParams, RPForestDict

_MemoryArray = NDArray | _CSMatrix
_SupportedArray = _MemoryArray | DaskArray


SeedLike = int | np.integer | Sequence[int] | np.random.SeedSequence
RNGLike = np.random.Generator | np.random.BitGenerator
Expand Down Expand Up @@ -195,7 +201,7 @@
try:
obj = getattr(obj, name)
except AttributeError:
raise RuntimeError(f"{parts[:i]}, {parts[i + 1:]}, {obj} {name}")
raise RuntimeError(f"{parts[:i]}, {parts[i + 1 :]}, {obj} {name}")

Check warning on line 204 in src/scanpy/_utils/__init__.py

View check run for this annotation

Codecov / codecov/patch

src/scanpy/_utils/__init__.py#L204

Added line #L204 was not covered by tests
return obj


Expand Down Expand Up @@ -280,7 +286,7 @@
# --------------------------------------------------------------------------------


def get_igraph_from_adjacency(adjacency, directed=None):
def get_igraph_from_adjacency(adjacency: _CSMatrix, *, directed: bool = False) -> Graph:
"""Get igraph graph from adjacency matrix."""
import igraph as ig

Expand Down Expand Up @@ -358,8 +364,7 @@
for cat in cats:
if cat in settings.categories_to_ignore:
logg.info(
f"Ignoring category {cat!r} "
"as it’s in `settings.categories_to_ignore`."
f"Ignoring category {cat!r} as it’s in `settings.categories_to_ignore`."
)
asso_names: list[str] = []
asso_matrix: list[list[float]] = []
Expand Down Expand Up @@ -564,21 +569,16 @@
# --------------------------------------------------------------------------------


if TYPE_CHECKING:
_SparseMatrix = sparse.csr_matrix | sparse.csc_matrix
_MemoryArray = NDArray | _SparseMatrix
_SupportedArray = _MemoryArray | DaskArray


@singledispatch
def elem_mul(x: _SupportedArray, y: _SupportedArray) -> _SupportedArray:
    """Element-wise multiply ``x`` and ``y``, preserving the container type of ``x``.

    Concrete array types are handled by registered overloads; this base
    implementation only reports that ``type(x)`` is unsupported.
    """
    msg = f"elem_mul is not implemented for {type(x)}"
    raise NotImplementedError(msg)


@elem_mul.register(np.ndarray)
@elem_mul.register(sparse.csc_matrix)
@elem_mul.register(sparse.csr_matrix)
def _elem_mul_in_mem(x: _MemoryArray, y: _MemoryArray) -> _MemoryArray:
    """Element-wise product for in-memory dense arrays and CSR/CSC matrices."""
    if isinstance(x, _CSMatrix):
        # `.multiply` returns a coo_matrix, so cast back to the input type.
        return type(x)(x.multiply(y))
    return x * y
Expand Down Expand Up @@ -629,14 +629,14 @@
@axis_mul_or_truediv.register(sparse.csr_matrix)
@axis_mul_or_truediv.register(sparse.csc_matrix)
def _(
X: sparse.csr_matrix | sparse.csc_matrix,
X: _CSMatrix,
scaling_array,
axis: Literal[0, 1],
op: Callable[[Any, Any], Any],
*,
allow_divide_by_zero: bool = True,
out: sparse.csr_matrix | sparse.csc_matrix | None = None,
) -> sparse.csr_matrix | sparse.csc_matrix:
out: _CSMatrix | None = None,
) -> _CSMatrix:
check_op(op)
if out is not None and X.data is not out.data:
raise ValueError(
Expand Down Expand Up @@ -770,13 +770,22 @@
) -> np.matrix: ...


@overload
def axis_sum(
    X: np.ndarray,
    *,
    axis: tuple[Literal[0, 1], ...] | Literal[0, 1] | None = None,
    dtype: DTypeLike | None = None,
) -> np.ndarray: ...


@singledispatch
def axis_sum(
    X: np.ndarray | sparse.spmatrix,
    *,
    axis: tuple[Literal[0, 1], ...] | Literal[0, 1] | None = None,
    dtype: DTypeLike | None = None,
) -> np.ndarray | np.matrix:
    """Sum ``X`` along ``axis`` (over all elements when ``axis`` is None).

    Delegates to :func:`numpy.sum`; per the return annotation, sparse input
    may yield :class:`numpy.matrix` while dense input yields an ndarray.
    """
    return np.sum(X, axis=axis, dtype=dtype)


Expand Down Expand Up @@ -832,7 +841,8 @@


@check_nonnegative_integers.register(np.ndarray)
@check_nonnegative_integers.register(sparse.spmatrix)
@check_nonnegative_integers.register(sparse.csr_matrix)
@check_nonnegative_integers.register(sparse.csc_matrix)
def _check_nonnegative_integers_in_mem(X: _MemoryArray) -> bool:
from numbers import Integral

Expand Down Expand Up @@ -1128,23 +1138,24 @@
return key in self._neighbors_dict


def _choose_graph(adata, obsp, neighbors_key):
"""Choose connectivities from neighbbors or another obsp column"""
def _choose_graph(
adata: AnnData, obsp: str | None, neighbors_key: str | None
) -> _CSMatrix:
"""Choose connectivities from neighbbors or another obsp entry."""
if obsp is not None and neighbors_key is not None:
raise ValueError(
"You can't specify both obsp, neighbors_key. " "Please select only one."
"You can't specify both obsp, neighbors_key. Please select only one."
)

if obsp is not None:
return adata.obsp[obsp]
else:
neighbors = NeighborsView(adata, neighbors_key)
if "connectivities" not in neighbors:
raise ValueError(
"You need to run `pp.neighbors` first "
"to compute a neighborhood graph."
)
return neighbors["connectivities"]

neighbors = NeighborsView(adata, neighbors_key)
if "connectivities" not in neighbors:
raise ValueError(

Check warning on line 1155 in src/scanpy/_utils/__init__.py

View check run for this annotation

Codecov / codecov/patch

src/scanpy/_utils/__init__.py#L1155

Added line #L1155 was not covered by tests
"You need to run `pp.neighbors` first to compute a neighborhood graph."
)
return neighbors["connectivities"]


def _resolve_axis(
Expand Down
18 changes: 10 additions & 8 deletions src/scanpy/_utils/compute/is_constant.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from __future__ import annotations

from collections.abc import Callable
from functools import partial, singledispatch, wraps
from numbers import Integral
from typing import TYPE_CHECKING, TypeVar, overload
from typing import TYPE_CHECKING, overload

import numba
import numpy as np
Expand All @@ -12,11 +11,16 @@
from ..._compat import DaskArray, njit

if TYPE_CHECKING:
from typing import Literal
from collections.abc import Callable
from typing import Literal, TypeVar

from numpy.typing import NDArray

C = TypeVar("C", bound=Callable)
from ..._utils import _CSMatrix

_Array = NDArray | DaskArray | _CSMatrix

C = TypeVar("C", bound=Callable)


def _check_axis_supported(wrapped: C) -> C:
Expand All @@ -33,11 +37,9 @@ def func(a, axis=None):


@overload
def is_constant(a: NDArray, axis: None = None) -> bool: ...


def is_constant(a: _Array, axis: None = None) -> bool: ...
@overload
def is_constant(a: NDArray, axis: Literal[0, 1]) -> NDArray[np.bool_]: ...
def is_constant(a: _Array, axis: Literal[0, 1]) -> NDArray[np.bool_]: ...


@_check_axis_supported
Expand Down
4 changes: 2 additions & 2 deletions src/scanpy/external/tl/_harmony_timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,9 @@ def harmony_timeseries(

**X_harmony** - :class:`~numpy.ndarray` (:attr:`~anndata.AnnData.obsm`, dtype `float`)
force directed layout
**harmony_aff** - :class:`~scipy.sparse.spmatrix` (:attr:`~anndata.AnnData.obsp`, dtype `float`)
**harmony_aff** - :class:`~scipy.sparse.csr_matrix` (:attr:`~anndata.AnnData.obsp`, dtype `float`)
affinity matrix
**harmony_aff_aug** - :class:`~scipy.sparse.spmatrix` (:attr:`~anndata.AnnData.obsp`, dtype `float`)
**harmony_aff_aug** - :class:`~scipy.sparse.csr_matrix` (:attr:`~anndata.AnnData.obsp`, dtype `float`)
augmented affinity matrix
**harmony_timepoint_var** - `str` (:attr:`~anndata.AnnData.uns`)
The name of the variable passed as `tp`
Expand Down
2 changes: 1 addition & 1 deletion src/scanpy/external/tl/_palantir.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def palantir(
Array of Diffusion components.
- palantir_EigenValues - :class:`~numpy.ndarray` (:attr:`~anndata.AnnData.uns`, dtype `float`)
Array of corresponding eigen values.
- palantir_diff_op - :class:`~scipy.sparse.spmatrix` (:attr:`~anndata.AnnData.obsp`, dtype `float`)
- palantir_diff_op - :class:`~scipy.sparse.csr_matrix` (:attr:`~anndata.AnnData.obsp`, dtype `float`)
The diffusion operator matrix.

**Multi scale space results**,
Expand Down
5 changes: 3 additions & 2 deletions src/scanpy/get/_aggregated.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,9 +333,10 @@ def aggregate_df(data, by, func, *, mask=None, dof=1):


@_aggregate.register(np.ndarray)
@_aggregate.register(sparse.spmatrix)
@_aggregate.register(sparse.csr_matrix)
@_aggregate.register(sparse.csc_matrix)
def aggregate_array(
data,
data: Array,
by: pd.Categorical,
func: AggType | Iterable[AggType],
*,
Expand Down
13 changes: 6 additions & 7 deletions src/scanpy/get/get.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,18 @@
from anndata import AnnData
from numpy.typing import NDArray
from packaging.version import Version
from scipy.sparse import spmatrix

from .._utils import _CSMatrix

if TYPE_CHECKING:
from collections.abc import Collection, Iterable
from typing import Any, Literal

from anndata._core.sparse_dataset import BaseCompressedSparseDataset
from anndata._core.views import ArrayView
from scipy.sparse import csc_matrix, csr_matrix

from .._compat import DaskArray

CSMatrix = csr_matrix | csc_matrix

# --------------------------------------------------------------------------------
# Plotting data helpers
Expand Down Expand Up @@ -333,7 +332,7 @@ def obs_df(
val = adata.obsm[k]
if isinstance(val, np.ndarray):
df[added_k] = np.ravel(val[:, idx])
elif isinstance(val, spmatrix):
elif isinstance(val, _CSMatrix):
df[added_k] = np.ravel(val[:, idx].toarray())
elif isinstance(val, pd.DataFrame):
df[added_k] = val.loc[:, idx]
Expand Down Expand Up @@ -403,7 +402,7 @@ def var_df(
val = adata.varm[k]
if isinstance(val, np.ndarray):
df[added_k] = np.ravel(val[:, idx])
elif isinstance(val, spmatrix):
elif isinstance(val, _CSMatrix):
df[added_k] = np.ravel(val[:, idx].toarray())
elif isinstance(val, pd.DataFrame):
df[added_k] = val.loc[:, idx]
Expand All @@ -419,7 +418,7 @@ def _get_obs_rep(
obsp: str | None = None,
) -> (
np.ndarray
| spmatrix
| _CSMatrix
| pd.DataFrame
| ArrayView
| BaseCompressedSparseDataset
Expand Down Expand Up @@ -494,7 +493,7 @@ def _set_obs_rep(


def _check_mask(
data: AnnData | np.ndarray | CSMatrix | DaskArray,
data: AnnData | np.ndarray | _CSMatrix | DaskArray,
mask: str | M,
dim: Literal["obs", "var"],
*,
Expand Down
13 changes: 9 additions & 4 deletions src/scanpy/metrics/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,24 @@


@singledispatch
def _resolve_vals(
    val: NDArray | sparse.spmatrix | DaskArray,
) -> NDArray | sparse.csr_matrix | DaskArray:
    """Normalize ``val`` into a container the metrics code can operate on.

    Registered overloads below handle supported types; the base
    implementation rejects everything else.
    """
    msg = f"Unsupported type {type(val)}"
    raise TypeError(msg)


@_resolve_vals.register(np.ndarray)
@_resolve_vals.register(sparse.csr_matrix)
@_resolve_vals.register(DaskArray)
def _(
    val: np.ndarray | sparse.csr_matrix | DaskArray,
) -> np.ndarray | sparse.csr_matrix | DaskArray:
    # Already in a supported container: pass through unchanged.
    return val


@_resolve_vals.register(sparse.spmatrix)
def _(val: sparse.spmatrix) -> sparse.csr_matrix:
    # Any other scipy sparse format is converted to CSR.
    return sparse.csr_matrix(val)


Expand Down
9 changes: 7 additions & 2 deletions src/scanpy/metrics/_gearys_c.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,16 @@

if TYPE_CHECKING:
from anndata import AnnData
from numpy.typing import NDArray

from .._compat import DaskArray


@singledispatch
def gearys_c(
adata: AnnData,
*,
vals: np.ndarray | sparse.spmatrix | None = None,
vals: NDArray | sparse.spmatrix | DaskArray | None = None,
use_graph: str | None = None,
layer: str | None = None,
obsm: str | None = None,
Expand Down Expand Up @@ -289,7 +292,9 @@ def _gearys_c_mtx_csr( # noqa: PLR0917


@gearys_c.register(sparse.csr_matrix)
def _gearys_c(g: sparse.csr_matrix, vals: np.ndarray | sparse.spmatrix) -> np.ndarray:
def _gearys_c(
g: sparse.csr_matrix, vals: NDArray | sparse.spmatrix | DaskArray
) -> np.ndarray:
assert g.shape[0] == g.shape[1], "`g` should be a square adjacency matrix"
vals = _resolve_vals(vals)
g_data = g.data.astype(np.float64, copy=False)
Expand Down
9 changes: 7 additions & 2 deletions src/scanpy/metrics/_morans_i.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,16 @@

if TYPE_CHECKING:
from anndata import AnnData
from numpy.typing import NDArray

from .._compat import DaskArray


@singledispatch
def morans_i(
adata: AnnData,
*,
vals: np.ndarray | sparse.spmatrix | None = None,
vals: NDArray | sparse.spmatrix | DaskArray | None = None,
use_graph: str | None = None,
layer: str | None = None,
obsm: str | None = None,
Expand Down Expand Up @@ -225,7 +228,9 @@ def _morans_i_mtx_csr( # noqa: PLR0917


@morans_i.register(sparse.csr_matrix)
def _morans_i(g: sparse.csr_matrix, vals: np.ndarray | sparse.spmatrix) -> np.ndarray:
def _morans_i(
g: sparse.csr_matrix, vals: NDArray | sparse.spmatrix | DaskArray
) -> np.ndarray:
assert g.shape[0] == g.shape[1], "`g` should be a square adjacency matrix"
vals = _resolve_vals(vals)
g_data = g.data.astype(np.float64, copy=False)
Expand Down
Loading
Loading