From 4ae07429fae1e1b2d2ce2b870cadf881f2f15236 Mon Sep 17 00:00:00 2001 From: zqfang Date: Tue, 24 Oct 2023 17:06:33 -0700 Subject: [PATCH] refactor data parsing --- gseapy/gsea.py | 70 +++++++++++++++++++++++++++++++++----------------- 1 file changed, 46 insertions(+), 24 deletions(-) diff --git a/gseapy/gsea.py b/gseapy/gsea.py index 1513a7a..5f3bbff 100644 --- a/gseapy/gsea.py +++ b/gseapy/gsea.py @@ -6,7 +6,7 @@ import os import xml.etree.ElementTree as ET from collections import Counter -from typing import Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np import pandas as pd @@ -51,7 +51,7 @@ def __init__( verbose=verbose, ) self.data = data - self.classes = classes + # self.classes = classes self.permutation_type = permutation_type self.method = method self.min_size = min_size @@ -65,10 +65,12 @@ def __init__( self.seed = seed self.ranking = None self._noplot = no_plot - self.pheno_pos = "pos" - self.pheno_neg = "neg" + # phenotype labels parsing + self.load_classes(classes) - def load_data(self, cls_vec: List[str]) -> Tuple[pd.DataFrame, Dict]: + def load_data( + self, groups: Union[List[str], Dict[str, Any]] + ) -> Tuple[pd.DataFrame, Dict]: """pre-processed the data frame.new filtering methods will be implement here.""" # read data in if isinstance(self.data, pd.DataFrame): @@ -88,6 +90,14 @@ def load_data(self, cls_vec: List[str]) -> Tuple[pd.DataFrame, Dict]: else: raise Exception("Error parsing gene expression DataFrame!") + exprs = self._check_data(exprs) + exprs, cls_dict = self._filter_data(exprs) + return exprs, cls_dict + + def _check_data(self, exprs: pd.DataFrame) -> pd.DataFrame: + """ + check NAs, duplicates. + """ if exprs.isnull().any().sum() > 0: self._logger.warning("Input data contains NA, filled NA with 0") exprs.dropna(how="all", inplace=True) # drop rows with all NAs @@ -102,12 +112,29 @@ def load_data(self, cls_vec: List[str]) -> Tuple[pd.DataFrame, Dict]: "Found duplicated gene names, values averaged by gene names!" ) df = df.groupby(level=0).mean() + return df + def _map_classes(self, sample_names: List[str]) -> Dict[str, Any]: + """ + update + """ + cls_dict = self.groups + if isinstance(self.groups, dict): + # update groups + self.groups = [cls_dict[c] for c in sample_names] + else: + cls_dict = {k: v for k, v in zip(sample_names, self.groups)} + return cls_dict + + def _filter_data(self, df: pd.DataFrame) -> pd.DataFrame: + """ + filter data rows with std == 0 + """ # in case the description column is numeric - if len(cls_vec) == (df.shape[1] - 1): + if len(self.groups) == (df.shape[1] - 1): df = df.iloc[:, 1:] + cls_dict = self._map_classes(df.columns) # drop gene which std == 0 in all samples - cls_dict = {k: v for k, v in zip(df.columns, cls_vec)} # compatible to py3.7 major, minor, _ = [int(i) for i in pd.__version__.split(".")] if (major == 1 and minor < 5) or (major < 1): @@ -120,13 +147,13 @@ def load_data(self, cls_vec: List[str]) -> Tuple[pd.DataFrame, Dict]: return df, cls_dict - def calculate_metric( + def calc_metric( self, df: pd.DataFrame, method: str, pos: str, neg: str, - classes: Dict[str, List[str]], + classes: Dict[str, str], ascending: bool, ) -> Tuple[List[int], pd.Series]: """The main function to rank an expression table. works for 2d array. @@ -210,13 +237,11 @@ def calculate_metric( # descending order return ser_ind[::-1], ser[::-1] - def load_classes( - self, - ): + def load_classes(self, classes: Union[str, List[str], Dict[str, Any]]): """Parse group (classes)""" - if isinstance(self.classes, dict): + if isinstance(classes, dict): # check number of samples - class_values = Counter(self.classes.values()) + class_values = Counter(classes.values()) s = [] for c, v in sorted(class_values.items(), key=lambda item: item[1]): if v < 3: @@ -226,12 +251,12 @@ def load_classes( self.pheno_neg = s[1] # n_pos = class_values[pos] # n_neg = class_values[neg] - return + self.groups = classes else: - pos, neg, cls_vector = gsea_cls_parser(self.classes) + pos, neg, cls_vector = gsea_cls_parser(classes) self.pheno_pos = pos self.pheno_neg = neg - return cls_vector + self.groups = cls_vector # @profile def run(self): @@ -257,11 +282,8 @@ def run(self): # Start Analysis self._logger.info("Parsing data files for GSEA.............................") - # phenotype labels parsing - cls_vector = self.load_classes() # select correct expression genes and values. - dat, cls_dict = self.load_data(cls_vector) - self.cls_dict = cls_dict + dat, cls_dict = self.load_data(self.groups) # data frame must have length > 1 assert len(dat) > 1 # filtering out gene sets and build gene sets dictionary @@ -275,7 +297,7 @@ def run(self): # compute ES, NES, pval, FDR, RES if self.permutation_type == "gene_set": # ranking metrics calculation. - idx, dat2 = self.calculate_metric( + idx, dat2 = self.calc_metric( df=dat, method=self.method, pos=self.pheno_pos, @@ -301,7 +323,7 @@ def run(self): gsum.indices = indices # only accept [[]] else: # phenotype permutation group = list( - map(lambda x: True if x == self.pheno_pos else False, cls_vector) + map(lambda x: True if x == self.pheno_pos else False, self.groups) ) gsum = gsea_rs( dat.index.values.tolist(), @@ -324,7 +346,7 @@ def run(self): self.ranking = pd.Series(gsum.rankings[0], index=dat.index[gsum.indices[0]]) # reorder datarame for heatmap # self._heatmat(df=dat.loc[dat2.index], classes=cls_vector) - self._heatmat(df=dat.iloc[gsum.indices[0]], classes=cls_vector) + self._heatmat(df=dat.iloc[gsum.indices[0]], classes=self.groups) # write output and plotting self.to_df(gsum.summaries, gmt, self.ranking) self._logger.info("Congratulations. GSEApy ran successfully.................\n")