Skip to content

Commit

Permalink
Backport PR #3170 on branch 1.10.x (Refactor score_genes) (#3171)
Browse files Browse the repository at this point in the history
  • Loading branch information
flying-sheep authored Jul 26, 2024
1 parent 7d8d8d1 commit d0aad7b
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 36 deletions.
113 changes: 79 additions & 34 deletions src/scanpy/tools/_score_genes.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from ..get import _get_obs_rep

if TYPE_CHECKING:
from collections.abc import Sequence
from collections.abc import Callable, Generator, Sequence
from typing import Literal

from anndata import AnnData
Expand All @@ -24,6 +24,12 @@

from .._utils import AnyRandom

try:
_StrIdx = pd.Index[str]
except TypeError: # Sphinx
_StrIdx = pd.Index
_GetSubset = Callable[[_StrIdx], np.ndarray | csr_matrix | csc_matrix]


def _sparse_nanmean(
X: csr_matrix | csc_matrix, axis: Literal[0, 1]
Expand Down Expand Up @@ -128,6 +134,65 @@ def score_genes(
if random_state is not None:
np.random.seed(random_state)

gene_list, gene_pool, get_subset = _check_score_genes_args(
adata, gene_list, gene_pool, use_raw=use_raw
)
del use_raw, random_state

# Trying here to match the Seurat approach in scoring cells.
# Basically we need to compare genes against random genes in a matched
# interval of expression.

control_genes = pd.Index([], dtype="string")
for r_genes in _score_genes_bins(
gene_list,
gene_pool,
ctrl_as_ref=ctrl_as_ref,
ctrl_size=ctrl_size,
n_bins=n_bins,
get_subset=get_subset,
):
control_genes = control_genes.union(r_genes)

if len(control_genes) == 0:
msg = "No control genes found in any cut."
if ctrl_as_ref:
msg += " Try setting `ctrl_as_ref=False`."
raise RuntimeError(msg)

means_list, means_control = (
_nan_means(get_subset(genes), axis=1, dtype="float64")
for genes in (gene_list, control_genes)
)
score = means_list - means_control

adata.obs[score_name] = pd.Series(
np.array(score).ravel(), index=adata.obs_names, dtype="float64"
)

logg.info(
" finished",
time=start,
deep=(
"added\n"
f" {score_name!r}, score of gene set (adata.obs).\n"
f" {len(control_genes)} total control genes are used."
),
)
return adata if copy else None


def _check_score_genes_args(
adata: AnnData,
gene_list: pd.Index[str] | Sequence[str],
gene_pool: pd.Index[str] | Sequence[str] | None,
*,
use_raw: bool,
) -> tuple[pd.Index[str], pd.Index[str], _GetSubset]:
"""Restrict `gene_list` and `gene_pool` to present genes in `adata`.
Also returns a function to get subset of `adata.X` based on a set of genes passed.
"""
var_names = adata.raw.var_names if use_raw else adata.var_names
gene_list = pd.Index([gene_list] if isinstance(gene_list, str) else gene_list)
genes_to_ignore = gene_list.difference(var_names, sort=False) # first get missing
Expand All @@ -144,17 +209,25 @@ def score_genes(
if len(gene_pool) == 0:
raise ValueError("No valid genes were passed for reference set.")

# Trying here to match the Seurat approach in scoring cells.
# Basically we need to compare genes against random genes in a matched
# interval of expression.

def get_subset(genes: pd.Index[str]):
x = _get_obs_rep(adata, use_raw=use_raw)
if len(genes) == len(var_names):
return x
idx = var_names.get_indexer(genes)
return x[:, idx]

return gene_list, gene_pool, get_subset


def _score_genes_bins(
gene_list: pd.Index[str],
gene_pool: pd.Index[str],
*,
ctrl_as_ref: bool,
ctrl_size: int,
n_bins: int,
get_subset: _GetSubset,
) -> Generator[pd.Index[str], None, None]:
# average expression of genes
obs_avg = pd.Series(_nan_means(get_subset(gene_pool), axis=0), index=gene_pool)
# Sometimes (and I don’t know how) missing data may be there, with NaNs for missing entries
Expand All @@ -165,7 +238,6 @@ def get_subset(genes: pd.Index[str]):
keep_ctrl_in_obs_cut = False if ctrl_as_ref else obs_cut.index.isin(gene_list)

# now pick `ctrl_size` genes from every cut
control_genes = pd.Index([], dtype="string")
for cut in np.unique(obs_cut.loc[gene_list]):
r_genes: pd.Index[str] = obs_cut[(obs_cut == cut) & ~keep_ctrl_in_obs_cut].index
if len(r_genes) == 0:
Expand All @@ -178,34 +250,7 @@ def get_subset(genes: pd.Index[str]):
r_genes = r_genes.to_series().sample(ctrl_size).index
if ctrl_as_ref: # otherwise `r_genes` is already filtered
r_genes = r_genes.difference(gene_list)
control_genes = control_genes.union(r_genes)

if len(control_genes) == 0:
msg = "No control genes found in any cut."
if ctrl_as_ref:
msg += " Try setting `ctrl_as_ref=False`."
raise RuntimeError(msg)

means_list, means_control = (
_nan_means(get_subset(genes), axis=1, dtype="float64")
for genes in (gene_list, control_genes)
)
score = means_list - means_control

adata.obs[score_name] = pd.Series(
np.array(score).ravel(), index=adata.obs_names, dtype="float64"
)

logg.info(
" finished",
time=start,
deep=(
"added\n"
f" {score_name!r}, score of gene set (adata.obs).\n"
f" {len(control_genes)} total control genes are used."
),
)
return adata if copy else None
yield r_genes


def _nan_means(
Expand Down
7 changes: 5 additions & 2 deletions tests/test_score_genes.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,15 +256,18 @@ def test_no_control_gene():
sc.tl.score_genes(adata, adata.var_names[:1], ctrl_size=1)


@pytest.mark.parametrize("ctrl_as_ref", [True, False])
@pytest.mark.parametrize(
"ctrl_as_ref", [True, False], ids=["ctrl_as_ref", "no_ctrl_as_ref"]
)
def test_gene_list_is_control(ctrl_as_ref: bool):
np.random.seed(0)
adata = sc.datasets.blobs(n_variables=10, n_observations=100, n_centers=20)
adata.var_names = "g" + adata.var_names
with (
pytest.raises(RuntimeError, match=r"No control genes found in any cut")
if ctrl_as_ref
else nullcontext()
):
sc.tl.score_genes(
adata, gene_list="3", ctrl_size=1, n_bins=5, ctrl_as_ref=ctrl_as_ref
adata, gene_list="g3", ctrl_size=1, n_bins=5, ctrl_as_ref=ctrl_as_ref
)

0 comments on commit d0aad7b

Please sign in to comment.