Skip to content

Commit

Permalink
fixed #253
Browse files Browse the repository at this point in the history
  • Loading branch information
zqfang committed Mar 19, 2024
1 parent 3a091d9 commit 0d01bb8
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 15 deletions.
36 changes: 26 additions & 10 deletions gseapy/gsea.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ def __init__(
self.seed = seed
self.ranking = None
self._noplot = no_plot
# some preprocessing
assert self.permutation_type in ["phenotype", "gene_set"]
assert self.min_size <= self.max_size
# phenotype labels parsing
self.load_classes(classes)

Expand Down Expand Up @@ -188,6 +191,9 @@ def calc_metric(
ser = df_mean[pos] - df_mean[neg]
elif method == "log2_ratio_of_classes":
ser = np.log2(df_mean[pos] / df_mean[neg])
if ser.isna().sum() > 0:
self._logger.warning("Invalid value encountered in log2, and dumped.")
ser = ser.dropna()
else:
logging.error("Please provide correct method name!!!")
raise LookupError("Input method: %s is not supported" % method)
Expand All @@ -198,23 +204,36 @@ def calc_metric(
# descending order
return ser_ind[::-1], ser[::-1]

def _check_classes(self, counter: Counter) -> List[str]:
"""
check each cls group length
"""
metrics = ["signal_to_noise", "s2n", "abs_signal_to_noise", "abs_s2n", "t_test"]
s = []
for c, v in sorted(counter.items(), key=lambda item: item[1]):
if v < 3:
if self.permutation_type == "phenotype":
self._logger.warning(
f"Number of {c}: {v}, it must be >= 3 for permutation type: phenotype !"
)
self._logger.warning("Permutation type change to gene_set.")
self.permutation_type == "gene_set"
s.append(c)
return s

def load_classes(self, classes: Union[str, List[str], Dict[str, Any]]):
"""Parse group (classes)"""
if isinstance(classes, dict):
# check number of samples
class_values = Counter(classes.values())
s = []
for c, v in sorted(class_values.items(), key=lambda item: item[1]):
if v < 3:
raise Exception(f"Number of {c}: {v}, it must be >= 3!")
s.append(c)
s = self._check_classes(Counter(classes.values()))
self.pheno_pos = s[0]
self.pheno_neg = s[1]
# n_pos = class_values[pos]
# n_neg = class_values[neg]
self.groups = classes
else:
pos, neg, cls_vector = gsea_cls_parser(classes)
s = self._check_classes(Counter(cls_vector))
self.pheno_pos = pos
self.pheno_neg = neg
self.groups = cls_vector
Expand All @@ -225,7 +244,7 @@ def run(self):
m = self.method.lower()
if m in ["signal_to_noise", "s2n"]:
method = Metric.Signal2Noise
elif m in ["s2n", "abs_signal_to_noise", "abs_s2n"]:
elif m in ["abs_signal_to_noise", "abs_s2n"]:
method = Metric.AbsSignal2Noise
elif m == "t_test":
method = Metric.Ttest
Expand All @@ -238,9 +257,6 @@ def run(self):
else:
raise Exception("Sorry, input method %s is not supported" % m)

assert self.permutation_type in ["phenotype", "gene_set"]
assert self.min_size <= self.max_size

# Start Analysis
self._logger.info("Parsing data files for GSEA.............................")
# select correct expression genes and values.
Expand Down
5 changes: 0 additions & 5 deletions gseapy/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import logging
import os
import xml.etree.ElementTree as ET
from collections import Counter
from collections.abc import Iterable
from typing import Dict, List, Optional, Tuple, Union

Expand Down Expand Up @@ -42,10 +41,6 @@ def gsea_cls_parser(cls: str) -> Tuple[str]:
if len(sample_name) != 2:
raise Exception("Input groups have to be 2!")

for c, v in Counter(classes).items():
if v < 3:
raise Exception(f"Number of {c}: {v}, it must be >= 3!")

return sample_name[0], sample_name[1], classes


Expand Down

0 comments on commit 0d01bb8

Please sign in to comment.