Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
zqfang committed Nov 21, 2023
1 parent ed2516c commit 4d2540f
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 45 deletions.
37 changes: 26 additions & 11 deletions gseapy/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,8 @@ def _check_data(self, exprs: pd.DataFrame) -> pd.DataFrame:
"""
check NAs, duplicates.
exprs: dataframe, the first column must be gene identifiers
return: dataframe, index is gene ids
"""
## if gene names contain NA, drop them
if exprs.iloc[:, 0].isnull().any():
Expand All @@ -212,30 +214,43 @@ def _check_data(self, exprs: pd.DataFrame) -> pd.DataFrame:
exprs.dropna(how="all", inplace=True) # drop rows with all NAs
exprs = exprs.fillna(0)
## check duplicated IDs
# gene_id = exprs.columns[0]
# if exprs.duplicated(subset=gene_id).sum() > 0:
# self._logger.info("Found duplicated gene names, make unique")
# mask = exprs.duplicated(subset=gene_id, keep=False) #
# dups = exprs.loc[mask, gene_id].groupby(gene_id).cumcount().map(lambda c: "_" + str(c) if c else "")
# exprs.loc[mask, gene_id] = exprs.loc[mask, gene_id] + dups
# check whether contains infinity values

# set gene name as index
exprs.set_index(keys=exprs.columns[0], inplace=True)
# select numeric columns
df = exprs.select_dtypes(include=[np.number])
# microarray data may contain multiple probes of the same gene, average them
if exprs.index.duplicated().sum() > 0:
if df.index.duplicated().sum() > 0:
self._logger.warning(
"Found duplicated gene names, values averaged by gene names!"
)
df = df.groupby(level=0).mean()

# check whether contains infinity values
if np.isinf(df).values.sum() > 0:
self._logger.warning("Input gene rankings contains inf values!")
df = df.apply()
col_min_max = {
np.inf: df[np.isfinite(df)].max(), # column-wise max
-np.inf: df[np.isfinite(df)].min(), # column-wise min
}
df = df.replace({col: col_min_max for col in df.columns})
return df

def make_unique(self, rank_metric: pd.DataFrame, col_idx: int) -> pd.DataFrame:
    """
    Make the gene id column unique by suffixing repeated ids.

    The first occurrence of an id is left unchanged; every later occurrence
    gets ``_1``, ``_2``, ... appended, following row order.

    rank_metric: dataframe holding the gene ids; it is modified in place.
    col_idx: positional index of the gene id column.
    return: the same dataframe, with unique values in the id column.
    """
    id_col = rank_metric.columns[col_idx]
    # keep=False flags every member of a duplicated group, the first included,
    # so one mask serves both as the "any duplicates?" check and the selector.
    mask = rank_metric.duplicated(subset=id_col, keep=False)
    if not mask.any():
        return rank_metric

    self._logger.info("Input gene rankings contains duplicated IDs")
    ids = rank_metric.loc[mask, id_col]
    # Group the ids by their own values. Passing the column label to
    # Series.groupby would be looked up as an index level and raise KeyError.
    occurrence = ids.groupby(ids).cumcount()
    suffix = occurrence.map(lambda c: "_" + str(c) if c else "")
    rank_metric.loc[mask, id_col] = ids + suffix
    return rank_metric

def load_gmt_only(
self, gmt: Union[List[str], str, Dict[str, str]]
) -> Dict[str, List[str]]:
Expand Down
35 changes: 1 addition & 34 deletions gseapy/gsea.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,23 +361,6 @@ def __init__(
self._noplot = no_plot
self.permutation_type = "gene_set"

def make_unique(self, rank_metric: pd.DataFrame, col_idx: int) -> pd.DataFrame:
    """
    Make the gene id column unique by suffixing repeated ids.

    The first occurrence of an id is left unchanged; every later occurrence
    gets ``_1``, ``_2``, ... appended, following row order.

    rank_metric: dataframe holding the gene ids; it is modified in place.
    col_idx: positional index of the gene id column.
    return: the same dataframe, with unique values in the id column.
    """
    id_col = rank_metric.columns[col_idx]
    # keep=False flags every member of a duplicated group, the first included,
    # so one mask serves both as the "any duplicates?" check and the selector.
    mask = rank_metric.duplicated(subset=id_col, keep=False)
    if not mask.any():
        return rank_metric

    self._logger.info("Input gene rankings contains duplicated IDs")
    ids = rank_metric.loc[mask, id_col]
    # Group the ids by their own values. Passing the column label to
    # Series.groupby would be looked up as an index level and raise KeyError.
    occurrence = ids.groupby(ids).cumcount()
    suffix = occurrence.map(lambda c: "_" + str(c) if c else "")
    rank_metric.loc[mask, id_col] = ids + suffix
    return rank_metric

def _load_ranking(self, rank_metric: pd.DataFrame) -> pd.Series:
"""Parse ranking
rank_metric: two column dataframe. first column is gene ids
Expand Down Expand Up @@ -439,23 +422,7 @@ def load_ranking(self):
# make unique
rank_metric = self.make_unique(rank_metric, col_idx=0)
# set index
rank_metric.set_index(keys=rank_metric.columns[0], inplace=True)
if rank_metric.isnull().any().sum() > 0:
self._logger.warning("Input rankings contains NA values!")
# fill na
rank_metric.dropna(how="all", inplace=True)
rank_metric.fillna(0, inplace=True)

# check whether contains infinity values
if np.isinf(rank_metric).values.sum() > 0:
self._logger.warning("Input gene rankings contains inf values!")
col_min_max = {
np.inf: rank_metric[np.isfinite(rank_metric)].max(), # column-wise max
-np.inf: rank_metric[np.isfinite(rank_metric)].min(), # column-wise min
}
rank_metric = rank_metric.replace(
{col: col_min_max for col in rank_metric.columns}
)
rank_metric = self._check_data(rank_metric)
# check ties in prerank stats
dups = rank_metric.apply(lambda df: df.duplicated().sum() / df.size)
if (dups > 0).sum() > 0:
Expand Down

0 comments on commit 4d2540f

Please sign in to comment.