Skip to content

Commit

Permalink
making it fail
Browse files Browse the repository at this point in the history
  • Loading branch information
jkobject committed Oct 23, 2024
1 parent 55fe2ca commit f37a7f9
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 43 deletions.
1 change: 1 addition & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ virtualenv: ## Create a virtual environment.
@uv venv
@source .venv/bin/activate
@make install
@echo "!!! Please run 'source .venv/bin/activate' to enable the environment !!!"

.PHONY: release
release: ## Create a new tag for release.
Expand Down
22 changes: 13 additions & 9 deletions bengrn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from ctxcore.rnkdb import FeatherRankingDatabase as RankingDatabase
from grnndata import GRNAnnData, from_adata_and_longform, from_scope_loomfile, utils
from scipy.sparse import csc_matrix, csr_matrix
from scipy.sparse import issparse
from sklearn.linear_model import LogisticRegression, RidgeClassifier
from sklearn.metrics import PrecisionRecallDisplay, auc, precision_recall_curve
from sklearn.model_selection import train_test_split
Expand Down Expand Up @@ -203,9 +204,10 @@ def scprint_benchmark(self, elems=["Central", "Regulators", "Targets"]):
except KeyError:
pass
istrue = metrics.get("TF_enr", False)
istrue = istrue or (
res.res2d.loc[res.res2d.Term == "0__TFs", "FDR q-val"].iloc[0] < 0.1
)
if len(res.res2d.loc[res.res2d.Term == "0__TFs"]) > 0:
istrue = istrue or (
res.res2d.loc[res.res2d.Term == "0__TFs", "FDR q-val"].iloc[0] < 0.1
)
metrics.update({"TF_enr": istrue})
if self.doplot:
print("_________________________________________")
Expand Down Expand Up @@ -608,6 +610,7 @@ def get_perturb_gt(
GRNAnnData: The Gene Regulatory Network data as a GRNAnnData object.
"""
if not os.path.exists(filename_bh):
os.makedirs(os.path.dirname(filename_bh), exist_ok=True)
urllib.request.urlretrieve(url_bh, filename_bh)
pert = pd.read_csv(filename_bh)
pert = pert.set_index("Unnamed: 0").T
Expand Down Expand Up @@ -711,7 +714,7 @@ def compute_genie3(
Returns:
GRNAnnData: The Gene Regulatory Network data computed using the GENIE3 algorithm.
"""
mat = np.asarray(adata.X.todense())
mat = np.asarray(adata.X.toarray() if issparse(adata.X) else adata.X)
names = adata.var_names[mat.sum(0) > 0].tolist()
var = adata.var[mat.sum(0) > 0]
mat = mat[:, mat.sum(0) > 0]
Expand Down Expand Up @@ -759,6 +762,9 @@ def get_GT_db(
net = dc.get_dorothea(organism=organism)
elif name == "omnipath":
if not os.path.exists(FILEDIR + "/../data/omnipath.parquet"):
os.makedirs(
os.path.dirname(FILEDIR + "/../data/omnipath.parquet"), exist_ok=True
)
from omnipath.interactions import AllInteractions
from omnipath.requests import Annotations

Expand Down Expand Up @@ -992,7 +998,6 @@ def load_genes(organisms: Union[str, list] = "NCBITaxon:9606"): # "NCBITaxon:10
genesdf = bt.Gene.filter(
organism_id=bt.Organism.filter(ontology_id=organism).first().id
).df()
genesdf = genesdf[~genesdf["public_source_id"].isna()]
genesdf = genesdf.drop_duplicates(subset="ensembl_gene_id")
genesdf = genesdf.set_index("ensembl_gene_id").sort_index()
# mitochondrial genes
Expand All @@ -1004,8 +1009,7 @@ def load_genes(organisms: Union[str, list] = "NCBITaxon:9606"): # "NCBITaxon:10
genesdf["organism"] = organism
organismdf.append(genesdf)
organismdf = pd.concat(organismdf)
organismdf.drop(
columns=["source_id", "run_id", "created_by_id", "updated_at", "stable_id"],
inplace=True,
)
for col in ["source_id", "run_id", "created_by_id", "updated_at", "stable_id"]:
if col in organismdf.columns:
organismdf.drop(columns=[col], inplace=True)
return organismdf
14 changes: 11 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "benGRN"
version = "1.2.2"
version = "1.2.4"
description = "benchmarking gene regulatory networks"
authors = [
{name = "jeremie kalfon", email = "[email protected]"}
Expand All @@ -20,13 +20,15 @@ dependencies = [
"seaborn>=0.11.0",
"decoupler>=1.2.0",
"pandas>=2.0.0",
"grnndata>=0.1.0",
"grnndata>=1.1.4",
"omnipath>=1.0.0",
"dask-expr>=1.0.0",
"gseapy>=0.10.0",
"bionty>=0.49.0",
"rich>=13.5.0",
"gdown>=4.7.1",
"setuptools>=58.0.0",
"numba>=0.56.0",
]

[project.optional-dependencies]
Expand Down Expand Up @@ -56,4 +58,10 @@ ignore = ["E501", "E203", "E266", "E265", "F401", "F403"]

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
build-backend = "hatchling.build"

[tool.hatch.build.targets.wheel]
include = [
"/bengrn",
"/data",
]
58 changes: 27 additions & 31 deletions tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,34 +26,30 @@ def test_base():
adata = adata[:, np.argsort(-adata.X.sum(axis=0)).tolist()[0][:1000]]
random_mask = np.random.choice([0, 1], size=random_matrix.shape, p=[0.8, 0.2])
sparse_random_matrix = csr_matrix(random_matrix * random_mask)
try:
grn = GRNAnnData(adata.copy(), grn=sparse_random_matrix)
grn.var.index = grn.var.symbol.astype(str)
_ = BenGRN(grn, doplot=False).scprint_benchmark()

# Test get_sroy_gt function
sroy_gt = get_sroy_gt(get="liu")
assert isinstance(
sroy_gt, GRNAnnData
), "get_sroy_gt should return a GRNAnnData object"

# Test get_perturb_gt function
perturb_gt = get_perturb_gt()
assert isinstance(
perturb_gt, GRNAnnData
), "get_perturb_gt should return a GRNAnnData object"

# Test compute_genie3 function
genie3_result = compute_genie3(adata[:, :100], ntrees=10, nthreads=1)
assert isinstance(
genie3_result, GRNAnnData
), "compute_genie3 should return a GRNAnnData object"

# Test train_classifier function
random_matrix = np.random.rand(4, 10000).reshape(100, 100, 4)
subgrn = grn[:, :100]
subgrn.varp["GRN"] = random_matrix
classifier, metrics, clf = train_classifier(subgrn)

except Exception as e:
pytest.fail(f"An exception occurred: {str(e)}")
grn = GRNAnnData(adata.copy(), grn=sparse_random_matrix)
grn.var.index = grn.var.symbol.astype(str)
_ = BenGRN(grn, doplot=False).scprint_benchmark()

# Test get_sroy_gt function
sroy_gt = get_sroy_gt(get="liu")
assert isinstance(
sroy_gt, GRNAnnData
), "get_sroy_gt should return a GRNAnnData object"

# Test get_perturb_gt function
perturb_gt = get_perturb_gt()
assert isinstance(
perturb_gt, GRNAnnData
), "get_perturb_gt should return a GRNAnnData object"

# Test compute_genie3 function
genie3_result = compute_genie3(adata[:, :100], ntrees=10, nthreads=1)
assert isinstance(
genie3_result, GRNAnnData
), "compute_genie3 should return a GRNAnnData object"

# Test train_classifier function
random_matrix = np.random.rand(4, 10000).reshape(100, 100, 4)
subgrn = grn[:, :100]
subgrn.varp["GRN"] = random_matrix
classifier, metrics, clf = train_classifier(subgrn)

0 comments on commit f37a7f9

Please sign in to comment.