Skip to content

Commit

Permalink
minor
Browse files Browse the repository at this point in the history
  • Loading branch information
zqfang committed Oct 22, 2023
1 parent 8a518b4 commit 1a29be9
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 41 deletions.
32 changes: 19 additions & 13 deletions gseapy/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,9 +164,8 @@ def _load_ranking(self, rnk: Union[pd.DataFrame, pd.Series, str]) -> pd.Series:
if rank_metric.select_dtypes(np.number).shape[1] > 1:
return rank_metric
# sort ranking values from high to low
rank_metric.sort_values(
by=rank_metric.columns[1], ascending=self.ascending, inplace=True
)
rnk_cols = rank_metric.columns
rank_metric.sort_values(by=rnk_cols[1], ascending=self.ascending, inplace=True)
# drop na values
if rank_metric.isnull().any(axis=1).sum() > 0:
self._logger.warning(
Expand All @@ -177,16 +176,23 @@ def _load_ranking(self, rnk: Union[pd.DataFrame, pd.Series, str]) -> pd.Series:
self._logger.debug("NAs list:\n" + NAs.to_string())
rank_metric.dropna(how="any", inplace=True)
# drop duplicate IDs, keep the first
if rank_metric.duplicated(subset=rank_metric.columns[0]).sum() > 0:
self._logger.warning(
"Input gene rankings contains duplicated IDs, Only use the duplicated ID with highest value!"
if rank_metric.duplicated(subset=rnk_cols[0]).sum() > 0:
self._logger.info("Input gene rankings contains duplicated IDs")
mask = rank_metric.duplicated(subset=rnk_cols[0]).duplicated(keep=False)
dups = (
rank_metric[mask]
.groupby(rnk_cols[0])
.cumcount()
.map(lambda c: "_" + str(c) if c else "")
)
# print out duplicated IDs.
dups = rank_metric[rank_metric.duplicated(subset=rank_metric.columns[0])]
self._logger.debug("Dups list:\n" + dups.to_string())
rank_metric.drop_duplicates(
subset=rank_metric.columns[0], inplace=True, keep="first"
rank_metric.loc[mask, rnk_cols[0]] = (
rank_metric.loc[mask, rnk_cols[0]] + dups
)
# dups = rank_metric[rank_metric.duplicated(subset=rnk_cols[0])]
# self._logger.debug("Dups list:\n" + dups.to_string())
# rank_metric.drop_duplicates(
# subset=rank_metric.columns[0], inplace=True, keep="first"
# )
# reset ranking index, because the values have been sorted and duplicates dropped.
rank_metric.reset_index(drop=True, inplace=True)
rank_metric.columns = ["gene_name", "prerank"]
Expand Down Expand Up @@ -609,7 +615,7 @@ def to_df(
res_df["NES"].abs().sort_values(ascending=False).index
).reset_index(drop=True)
res_df.drop(dc, axis=1, inplace=True)

if self._outdir is not None:
out = os.path.join(
self.outdir,
Expand Down Expand Up @@ -751,7 +757,7 @@ def plot(
ofname: savefig
"""
# if hasattr(self, "results"):
if self.module == "ssgsea":
if self.module in ["ssgsea", "gsva"]:
raise NotImplementedError("not for ssgsea")
keys = list(self._results.keys())
if len(keys) > 1:
Expand Down
26 changes: 12 additions & 14 deletions gseapy/gsea.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,13 +97,6 @@ def load_data(self, cls_vec: List[str]) -> Tuple[pd.DataFrame, Dict]:
else:
raise Exception("Error parsing gene expression DataFrame!")

# drop duplicated gene names
if exprs.iloc[:, 0].duplicated().sum() > 0:
self._logger.warning(
"Dropping duplicated gene names, only keep the first values"
)
# drop duplicate gene_names.
exprs.drop_duplicates(subset=exprs.columns[0], inplace=True)
if exprs.isnull().any().sum() > 0:
self._logger.warning("Input data contains NA, filled NA with 0")
exprs.dropna(how="all", inplace=True) # drop rows with all NAs
Expand All @@ -113,6 +106,12 @@ def load_data(self, cls_vec: List[str]) -> Tuple[pd.DataFrame, Dict]:
# select numeric columns
df = exprs.select_dtypes(include=[np.number])

if exprs.index.duplicated().sum() > 0:
self._logger.warning(
"Found duplicated gene names, values averaged by gene names!"
)
exprs = exprs.groupby(level=0).mean()

# in case the description column is numeric
if len(cls_vec) == (df.shape[1] - 1):
df = df.iloc[:, 1:]
Expand Down Expand Up @@ -581,16 +580,16 @@ def load_data(self) -> pd.DataFrame:
rank_metric = rank_metric.select_dtypes(include=[np.number])
else:
raise Exception("Error parsing gene ranking values!")
if rank_metric.index.duplicated().sum() > 0:
self._logger.warning(
"Dropping duplicated gene names, values averaged by gene names!"
)
rank_metric = rank_metric.loc[rank_metric.index.dropna()]
rank_metric = rank_metric.groupby(level=0).mean()

if rank_metric.isnull().any().sum() > 0:
self._logger.warning("Input data contains NA, filled NA with 0")
rank_metric = rank_metric.fillna(0)

if rank_metric.index.duplicated().sum() > 0:
self._logger.warning(
"Found duplicated gene names, values averaged by gene names!"
)
rank_metric = rank_metric.groupby(level=0).mean()
return rank_metric

def norm_samples(self, dat: pd.DataFrame) -> pd.DataFrame:
Expand Down Expand Up @@ -679,7 +678,6 @@ def runSamplesPermu(
return



class Replot(GSEAbase):
"""To reproduce GSEA desktop output results."""

Expand Down
23 changes: 9 additions & 14 deletions gseapy/gsva.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,10 @@ def __init__(
self.ranking = None
self.permutation_num = 0
self._noplot = True
if kcdf == "Gaussian":
if kcdf in ["Gaussian", "gaussian"]:
self.kernel = True
self.rnaseq = False
elif kcdf == "Poisson":
elif kcdf in ["Poisson", "poisson"]:
self.kernel = True
self.rnaseq = True
else:
Expand Down Expand Up @@ -106,15 +106,16 @@ def load_data(self) -> pd.DataFrame:
rank_metric = rank_metric.select_dtypes(include=[np.number])
else:
raise Exception("Error parsing gene ranking values!")

if rank_metric.isnull().any().sum() > 0:
self._logger.warning("Input data contains NA, filled NA with 0")
rank_metric = rank_metric.fillna(0)

if rank_metric.index.duplicated().sum() > 0:
self._logger.warning(
"Dropping duplicated gene names, values averaged by gene names!"
"Found duplicated gene names, values averaged by gene names!"
)
rank_metric = rank_metric.loc[rank_metric.index.dropna()]
rank_metric = rank_metric.groupby(level=0).mean()
if rank_metric.isnull().any().sum() > 0:
self._logger.warning("Input data contains NA, filled NA with 0")
rank_metric = rank_metric.fillna(0)

return rank_metric

Expand All @@ -125,13 +126,7 @@ def run(self):
# load data
df = self.load_data()
if self.rnaseq:
self._logger.debug(
"Poisson kernel selected. round input values to intergers!"
)
df = df.astype(int)
self._logger.debug(
"Poisson kernel selected. convert negative values to 0 !"
)
self._logger.info("Poisson kernel selected. Clip negative values to 0 !")
df = df.clip(lower=0)

self.data = df
Expand Down

0 comments on commit 1a29be9

Please sign in to comment.