Skip to content

Commit

Permalink
more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jkobject committed Oct 9, 2024
1 parent 797e2aa commit e0e9f70
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 10 deletions.
10 changes: 5 additions & 5 deletions bengrn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,13 @@

import json
import logging
import os
import os.path
import tarfile
import urllib.request
from typing import Optional, Union

import gdown
import gseapy as gp
import matplotlib.pyplot as plt
import numpy as np
Expand All @@ -29,9 +32,6 @@
from sklearn.linear_model import LogisticRegression, RidgeClassifier
from sklearn.metrics import PrecisionRecallDisplay, auc, precision_recall_curve
from sklearn.model_selection import train_test_split
import gdown
import os
import tarfile

from .tools import GENIE3

Expand Down Expand Up @@ -409,13 +409,13 @@ def get_scenicplus(


def get_sroy_gt(
get: str = "main", join: str = "outer", species: str = "human", gt: str = "full"
get: str = "mine", join: str = "outer", species: str = "human", gt: str = "full"
) -> GRNAnnData:
"""
This function retrieves the ground truth data from the McCall et al.'s paper.
Args:
get (str): The specific dataset to retrieve. Options include "main", "liu", and "chen".
get (str): The specific dataset to retrieve. Options include "mine", "liu", and "chen".
join (str, optional): The type of join to be performed when concatenating the data. Default is "outer".
species (str, optional): The species of the dataset. Default is "human".
gt (str, optional): The type of ground truth data to retrieve. Options include "full", "chip", and "ko". Default is "full".
Expand Down
2 changes: 1 addition & 1 deletion bengrn/tools/genie3.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def compute_feature_importances(estimator):
for e in estimator.estimators_
]
importances = array(importances)
return sum(importances, axis=0) / len(estimator)
return importances.sum(0) / len(estimator.estimators_)


def get_link_list(
Expand Down
37 changes: 36 additions & 1 deletion tests/test_base.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,22 @@
import os

import numpy as np
import pandas as pd
import pytest
import scanpy as sc
from grnndata import GRNAnnData
from scipy.sparse import csr_matrix

from bengrn.base import NAME, BenGRN
from bengrn.base import (
NAME,
BenGRN,
compute_epr,
compute_genie3,
get_GT_db,
get_perturb_gt,
get_sroy_gt,
train_classifier,
)


def test_base():
Expand All @@ -20,5 +30,30 @@ def test_base():
grn = GRNAnnData(adata.copy(), grn=sparse_random_matrix)
grn.var.index = grn.var.symbol.astype(str)
_ = BenGRN(grn, doplot=False).scprint_benchmark()

# Test get_sroy_gt function
sroy_gt = get_sroy_gt(get="liu")
assert isinstance(
sroy_gt, GRNAnnData
), "get_sroy_gt should return a GRNAnnData object"

# Test get_perturb_gt function
perturb_gt = get_perturb_gt()
assert isinstance(
perturb_gt, GRNAnnData
), "get_perturb_gt should return a GRNAnnData object"

# Test compute_genie3 function
genie3_result = compute_genie3(adata[:, :100], ntrees=10, nthreads=1)
assert isinstance(
genie3_result, GRNAnnData
), "compute_genie3 should return a GRNAnnData object"

# Test train_classifier function
random_matrix = np.random.rand(4, 10000).reshape(100, 100, 4)
subgrn = grn[:, :100]
subgrn.varp["GRN"] = random_matrix
classifier, metrics, clf = train_classifier(subgrn)

except Exception as e:
pytest.fail(f"An exception occurred: {str(e)}")
51 changes: 48 additions & 3 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit e0e9f70

Please sign in to comment.