Skip to content

Commit

Permalink
Added option to use raw data
Browse files (browse the repository at this point in the history)
  • Loading branch information
earmingol committed Dec 21, 2023
1 parent 283548b commit 0b04799
Show file tree
Hide file tree
Showing 6 changed files with 60 additions and 33 deletions.
10 changes: 8 additions & 2 deletions sccellfie/expression/aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import pandas as pd


def agg_expression_cells(adata, groupby, layer=None, gene_symbols=None, agg_func='mean'):
def agg_expression_cells(adata, groupby, layer=None, gene_symbols=None, agg_func='mean', use_raw=False):
"""
Aggregates gene expression data for specified cell groups in an `AnnData` object.
Expand All @@ -28,6 +28,9 @@ def agg_expression_cells(adata, groupby, layer=None, gene_symbols=None, agg_func
'25p' (25th percentile), and '75p' (75th percentile). The function
must be one of the keys in the `AGG_FUNC` dictionary.
use_raw : bool (default=False)
Whether to use the data in adata.raw.X (True) or in adata.X (False). Only consulted when `layer` is None; if a layer is given, that layer is used instead.
Returns
-------
agg_expression : pd.DataFrame
Expand Down Expand Up @@ -55,7 +58,10 @@ def agg_expression_cells(adata, groupby, layer=None, gene_symbols=None, agg_func
if layer is not None:
X = adata.layers[layer]
else:
X = adata.X
if use_raw:
X = adata.raw.X.toarray()
else:
X = adata.X.toarray()

# Filter for specific genes if provided
if gene_symbols is not None:
Expand Down
12 changes: 7 additions & 5 deletions sccellfie/expression/tests/test_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,34 +6,36 @@
from sccellfie.tests.toy_inputs import create_random_adata, create_controlled_adata


def test_agg_expression_cells_all_groups_present():
@pytest.mark.parametrize("use_raw", [False, True])
def test_agg_expression_cells_all_groups_present(use_raw):
adata = create_random_adata()
groupby = 'group'
# Create two groups for simplicity
adata.obs[groupby] = ['group1' if i < adata.n_obs // 2 else 'group2' for i in range(adata.n_obs)]

# Test aggregation across all groups
agg_result = agg_expression_cells(adata, groupby, agg_func='mean')
agg_result = agg_expression_cells(adata, groupby, agg_func='mean', use_raw=use_raw)

# Check if all groups are present in the result
expected_groups = set(adata.obs[groupby].unique())
result_groups = set(agg_result.index)
assert expected_groups == result_groups, "Not all groups are present in the aggregation result"


def test_agg_expression_cells_specific_gene_present():
@pytest.mark.parametrize("use_raw", [False, True])
def test_agg_expression_cells_specific_gene_present(use_raw):
adata = create_random_adata()
groupby = 'group'
adata.obs[groupby] = ['group1' if i < adata.n_obs // 2 else 'group2' for i in range(adata.n_obs)]
gene_symbols = ['gene1', 'gene10']

# Test aggregation with one gene
agg_result = agg_expression_cells(adata, groupby, gene_symbols=gene_symbols[0], agg_func='mean')
agg_result = agg_expression_cells(adata, groupby, gene_symbols=gene_symbols[0], agg_func='mean', use_raw=use_raw)
assert agg_result.shape == (len(adata.obs[groupby].unique()), 1), "Shape mismatch"
assert gene_symbols[0] in agg_result.columns, "Missing gene in result"

# Test aggregation with specific genes
agg_result = agg_expression_cells(adata, groupby, gene_symbols=gene_symbols, agg_func='mean')
agg_result = agg_expression_cells(adata, groupby, gene_symbols=gene_symbols, agg_func='mean', use_raw=use_raw)
assert agg_result.shape == (len(adata.obs[groupby].unique()), len(gene_symbols)), "Shape mismatch"
assert all(gene in agg_result.columns for gene in gene_symbols), "Missing genes in result"

Expand Down
36 changes: 24 additions & 12 deletions sccellfie/expression/thresholds.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,15 @@
import pandas as pd


def get_local_percentile_threshold(adata, percentile=0.75, lower_bound=0.25, upper_bound=None, exclude_zeros=False):
def get_local_percentile_threshold(adata, percentile=0.75, lower_bound=0.25, upper_bound=None, exclude_zeros=False, use_raw=False):
if use_raw:
X = adata.raw.X.toarray()
else:
X = adata.X.toarray()
if exclude_zeros:
X = np.ma.masked_equal(adata.X.toarray(), 0)
X = np.ma.masked_equal(X, 0)
thresholds = np.quantile(X, q=percentile, axis=0).data
else:
X = adata.X.toarray()
thresholds = np.quantile(X, q=percentile, axis=0)
if isinstance(percentile, list):
columns = ['threshold-{}'.format(p) for p in percentile]
Expand All @@ -33,12 +36,15 @@ def get_local_percentile_threshold(adata, percentile=0.75, lower_bound=0.25, upp
return thresholds


def get_global_percentile_threshold(adata, percentile=0.75, lower_bound=0.25, upper_bound=None, exclude_zeros=False):
def get_global_percentile_threshold(adata, percentile=0.75, lower_bound=0.25, upper_bound=None, exclude_zeros=False, use_raw=False):
if use_raw:
X = adata.raw.X.toarray()
else:
X = adata.X.toarray()
if exclude_zeros:
X = np.ma.masked_equal(adata.X.toarray(), 0)
X = np.ma.masked_equal(X, 0)
thresholds = np.quantile(X, q=percentile)
else:
X = adata.X.toarray()
thresholds = np.quantile(X, q=percentile)
if isinstance(percentile, list):
columns = ['threshold-{}'.format(p) for p in percentile]
Expand All @@ -64,12 +70,15 @@ def get_global_percentile_threshold(adata, percentile=0.75, lower_bound=0.25, up
return thresholds


def get_local_mean_threshold(adata, lower_bound=0.25, upper_bound=None, exclude_zeros=False):
def get_local_mean_threshold(adata, lower_bound=0.25, upper_bound=None, exclude_zeros=False, use_raw=False):
if use_raw:
X = adata.raw.X.toarray()
else:
X = adata.X.toarray()
if exclude_zeros:
X = np.ma.masked_equal(adata.X.toarray(), 0)
X = np.ma.masked_equal(X, 0)
thresholds = np.ma.mean(X, axis=0).data
else:
X = adata.X.toarray()
thresholds = np.nanmean(X, axis=0)
columns = ['threshold-mean']
thresholds = pd.DataFrame(thresholds.T, index=adata.var_names, columns=columns)
Expand All @@ -92,12 +101,15 @@ def get_local_mean_threshold(adata, lower_bound=0.25, upper_bound=None, exclude_
return thresholds


def get_global_mean_threshold(adata, lower_bound=0.25, upper_bound=None, exclude_zeros=False):
def get_global_mean_threshold(adata, lower_bound=0.25, upper_bound=None, exclude_zeros=False, use_raw=False):
if use_raw:
X = adata.raw.X.toarray()
else:
X = adata.X.toarray()
if exclude_zeros:
X = np.ma.masked_equal(adata.X.toarray(), 0)
X = np.ma.masked_equal(X, 0)
thresholds = np.ma.mean(X)
else:
X = adata.X.toarray()
thresholds = np.nanmean(X)
columns = ['threshold-mean']
thresholds = pd.DataFrame(thresholds.T, index=adata.var_names, columns=columns)
Expand Down
8 changes: 5 additions & 3 deletions sccellfie/gene_score.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@ def gene_score(gene_expression, gene_threshold):
return 5*np.log(1 + gene_expression/(gene_threshold+0.01)) # Added small value to threshold to avoid division by zero


def compute_gene_scores(adata, thresholds):
def compute_gene_scores(adata, thresholds, use_raw=False):
genes = [g for g in thresholds.index if g in adata.var_names]

X = adata[:, genes].X.toarray()
if use_raw:
X = adata[:, genes].raw.X.toarray()
else:
X = adata[:, genes].X.toarray()
_thresholds = thresholds.loc[genes, thresholds.columns[:1]] # Use only first column, to avoid issues

gene_scores = gene_score(X, _thresholds.values.T)
Expand Down
22 changes: 11 additions & 11 deletions sccellfie/tests/test_gene_score.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@

import numpy as np
import pandas as pd
import scanpy as sc

from cobra.core.gene import GPR

from sccellfie.gene_score import gene_score, compute_gene_scores, compute_gpr_gene_score
from sccellfie.tests.toy_inputs import create_controlled_adata


def test_gene_score():
Expand All @@ -24,23 +24,23 @@ def test_gene_score():
np.testing.assert_allclose(actual_scores, expected_scores, rtol=1e-5)


def test_compute_gene_scores():
@pytest.mark.parametrize("use_raw", [False, True])
def test_compute_gene_scores(use_raw):
# Create a small, controlled AnnData object
gene_expression_data = np.array([
[1, 2], # Cell1
[3, 4], # Cell2
])
adata = sc.AnnData(X=gene_expression_data)
adata.var_names = ['gene1', 'gene2']
adata = create_controlled_adata()
if use_raw:
X = adata.raw.X.toarray()
else:
X = adata.X.toarray()

# Define known thresholds
thresholds = pd.DataFrame({'gene_threshold': [0.1, 0.2]}, index=['gene1', 'gene2'])
thresholds = pd.DataFrame({'gene_threshold': [0.5, 3, 5]}, index=['gene1', 'gene2', 'gene3'])

# Expected gene scores based on the defined formula
expected_scores = 5 * np.log(1 + gene_expression_data / (thresholds.values.T + 0.01))
expected_scores = 5 * np.log(1 + X / (thresholds.values.T + 0.01))

# Compute gene scores using the function
compute_gene_scores(adata, thresholds)
compute_gene_scores(adata, thresholds, use_raw=use_raw)

# Retrieve the computed gene scores from adata
computed_scores = adata.layers['gene_scores']
Expand Down
5 changes: 5 additions & 0 deletions sccellfie/tests/toy_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import pandas as pd

import scanpy as sc
from scipy import sparse


def create_random_adata(n_obs=100, n_vars=50, layers=None):
Expand All @@ -10,6 +11,8 @@ def create_random_adata(n_obs=100, n_vars=50, layers=None):
obs = pd.DataFrame(index=[f'cell{i}' for i in range(1, n_obs+1)])
var = pd.DataFrame(index=[f'gene{i}' for i in range(1, n_vars+1)])
adata = sc.AnnData(X=X, obs=obs, var=var)
adata.X = sparse.csr_matrix(adata.X)
adata.raw = adata.copy()

if layers:
if isinstance(layers, str):
Expand All @@ -31,6 +34,8 @@ def create_controlled_adata():
adata.var_names = ['gene1', 'gene2', 'gene3']
adata.obs_names = ['cell1', 'cell2', 'cell3', 'cell4']
adata.obs['group'] = ['A', 'A', 'B', 'B']
adata.X = sparse.csr_matrix(adata.X)
adata.raw = adata.copy()
return adata


Expand Down

0 comments on commit 0b04799

Please sign in to comment.