From c1eb6027dffb5d5c555246bccb0ad9c4ef5950aa Mon Sep 17 00:00:00 2001 From: Zhuoqing Fang Date: Wed, 13 Nov 2024 12:31:17 -0800 Subject: [PATCH] minor --- gseapy/base.py | 2 +- gseapy/msigdb.py | 25 ++++++++++++++++++++++++- 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/gseapy/base.py b/gseapy/base.py index 43232f9..e44d82f 100644 --- a/gseapy/base.py +++ b/gseapy/base.py @@ -732,7 +732,7 @@ def enrichment_score( # Test whether each element of a 1-D array is also present in a second array # It's more intuitive here than original enrichment_score source code. # use .astype to covert bool to integer - tag_indicator = np.in1d(gene_list, gene_set, assume_unique=True).astype( + tag_indicator = np.isin(gene_list, gene_set, assume_unique=True).astype( int ) # notice that the sign is 0 (no tag) or 1 (tag) diff --git a/gseapy/msigdb.py b/gseapy/msigdb.py index 4927b8c..0af6781 100644 --- a/gseapy/msigdb.py +++ b/gseapy/msigdb.py @@ -10,11 +10,20 @@ def __init__(self, dbver: str = "2023.1.Hs"): dbver: MSIGDB version number. default: 2023.1.Hs """ self.url = "https://data.broadinstitute.org/gsea-msigdb/msigdb/release/" - self._pattern = re.compile("(\w.+)\.(v\d.+)\.(entrez|symbols)\.gmt") + self._pattern = re.compile(r"(\w.+)\.(v\d.+)\.(entrez|symbols)\.gmt") self._db_version = self._get_db_version() + if self._db_version is None: + raise Exception("Failed to fetch available MSIGDB versions") self.categoires = self.list_category(dbver) def _get_db_version(self): + """ + Get all available MSIGDB versions + + Return: + A pd.DataFrame of all available MSIGDB versions. + If failed to fetch, return None. + """ resp = requests.get(self.url) if resp.ok: d = pd.read_html(resp.text)[0] @@ -48,6 +57,13 @@ def get_gmt( def list_dbver(self): # self._db_version.columns = ["dbver", "date"] + """ + Return a pd.DataFrame of all available MSIGDB versions + + Return: + A pd.DataFrame of all available MSIGDB versions. + If failed to fetch, return None. 
+ """ return self._db_version def list_category(self, dbver: str = "2023.1.Hs"): @@ -66,6 +82,13 @@ def list_category(self, dbver: str = "2023.1.Hs"): return None def list_gmt(self, db: str): + """ + List all gmt files in MSIGDB database. + + :param db: MSIGDB version number, e.g. 2023.1.Hs + :return: a pandas DataFrame object + """ + url = self.url + db resp = requests.get(url) if resp.ok: