diff --git a/gseapy/gsea.py b/gseapy/gsea.py index 3610d55..b910eba 100644 --- a/gseapy/gsea.py +++ b/gseapy/gsea.py @@ -308,15 +308,15 @@ def load_classes(self, classes: Union[str, List[str], Dict[str, Any]]): self.pheno_neg = neg self.groups = cls_vector - if self._outdir is not None: - self.to_cls(outdir=self.outdir) - def to_cls(self, outdir: str): """Save group information to cls file""" with open(os.path.join(outdir, "group.cls"), "w") as f: f.write(f"{len(self.groups)} 2 1\n") f.write(f"# {self.pheno_pos} {self.pheno_neg}\n") - f.write(" ".join(self.groups) + "\n") + if isinstance(self.groups, dict): + f.write(" ".join(list(self.groups.values())) + "\n") + else: + f.write(" ".join(self.groups) + "\n") # @profile def run(self): @@ -411,6 +411,7 @@ def run(self): # write output and plotting self.to_df(gsum.summaries, gmt, self.ranking) if self._outdir is not None: + self.to_cls(outdir=self.outdir) self.ranking.to_csv( os.path.join(self.outdir, "gsea_data.rnk"), sep="\t", header=False )