diff --git a/gseapy/gsea.py b/gseapy/gsea.py index 7fea1e0..4e5375b 100644 --- a/gseapy/gsea.py +++ b/gseapy/gsea.py @@ -65,6 +65,9 @@ def __init__( self.seed = seed self.ranking = None self._noplot = no_plot + # some preprocessing + assert self.permutation_type in ["phenotype", "gene_set"] + assert self.min_size <= self.max_size # phenotype labels parsing self.load_classes(classes) @@ -188,6 +191,9 @@ def calc_metric( ser = df_mean[pos] - df_mean[neg] elif method == "log2_ratio_of_classes": ser = np.log2(df_mean[pos] / df_mean[neg]) + if ser.isna().sum() > 0: + self._logger.warning("Invalid value encountered in log2, and dumped.") + ser = ser.dropna() else: logging.error("Please provide correct method name!!!") raise LookupError("Input method: %s is not supported" % method) @@ -198,16 +204,28 @@ def calc_metric( # descending order return ser_ind[::-1], ser[::-1] + def _check_classes(self, counter: Counter) -> List[str]: + """ + check each cls group length + """ + metrics = ["signal_to_noise", "s2n", "abs_signal_to_noise", "abs_s2n", "t_test"] + s = [] + for c, v in sorted(counter.items(), key=lambda item: item[1]): + if v < 3: + if self.permutation_type == "phenotype": + self._logger.warning( + f"Number of {c}: {v}, it must be >= 3 for permutation type: phenotype !" + ) + self._logger.warning("Permutation type change to gene_set.") + self.permutation_type == "gene_set" + s.append(c) + return s + def load_classes(self, classes: Union[str, List[str], Dict[str, Any]]): """Parse group (classes)""" if isinstance(classes, dict): # check number of samples - class_values = Counter(classes.values()) - s = [] - for c, v in sorted(class_values.items(), key=lambda item: item[1]): - if v < 3: - raise Exception(f"Number of {c}: {v}, it must be >= 3!") - s.append(c) + s = self._check_classes(Counter(classes.values())) self.pheno_pos = s[0] self.pheno_neg = s[1] # n_pos = class_values[pos] @@ -215,6 +233,7 @@ def load_classes(self, classes: Union[str, List[str], Dict[str, Any]]): self.groups = classes else: pos, neg, cls_vector = gsea_cls_parser(classes) + s = self._check_classes(Counter(cls_vector)) self.pheno_pos = pos self.pheno_neg = neg self.groups = cls_vector @@ -225,7 +244,7 @@ def run(self): m = self.method.lower() if m in ["signal_to_noise", "s2n"]: method = Metric.Signal2Noise - elif m in ["s2n", "abs_signal_to_noise", "abs_s2n"]: + elif m in ["abs_signal_to_noise", "abs_s2n"]: method = Metric.AbsSignal2Noise elif m == "t_test": method = Metric.Ttest @@ -238,9 +257,6 @@ def run(self): else: raise Exception("Sorry, input method %s is not supported" % m) - assert self.permutation_type in ["phenotype", "gene_set"] - assert self.min_size <= self.max_size - # Start Analysis self._logger.info("Parsing data files for GSEA.............................") # select correct expression genes and values. diff --git a/gseapy/parser.py b/gseapy/parser.py index 849325f..f2b4102 100644 --- a/gseapy/parser.py +++ b/gseapy/parser.py @@ -3,7 +3,6 @@ import logging import os import xml.etree.ElementTree as ET -from collections import Counter from collections.abc import Iterable from typing import Dict, List, Optional, Tuple, Union @@ -42,10 +41,6 @@ def gsea_cls_parser(cls: str) -> Tuple[str]: if len(sample_name) != 2: raise Exception("Input groups have to be 2!") - for c, v in Counter(classes).items(): - if v < 3: - raise Exception(f"Number of {c}: {v}, it must be >= 3!") - return sample_name[0], sample_name[1], classes