From fd916cc3e80439b59d60d6e568dd0393703dfafc Mon Sep 17 00:00:00 2001 From: Zhuoqing Fang Date: Wed, 13 Nov 2024 12:30:48 -0800 Subject: [PATCH] add ax keyword, #285 --- gseapy/plot.py | 107 ++++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 87 insertions(+), 20 deletions(-) diff --git a/gseapy/plot.py b/gseapy/plot.py index a35ecd6..0588da5 100644 --- a/gseapy/plot.py +++ b/gseapy/plot.py @@ -69,6 +69,7 @@ def __init__( xticklabels: bool = True, yticklabels: bool = True, ofname: Optional[str] = None, + ax: Optional[plt.Axes] = None, **kwargs, ): self.title = "" if title is None else title @@ -76,6 +77,7 @@ def __init__( self.xticklabels = xticklabels self.yticklabels = yticklabels self.ofname = ofname + self.ax = ax # scale dataframe df = df.astype(float) @@ -116,9 +118,31 @@ def _auto_ticks(self, ax, labels, axis): return tickevery def get_ax(self): - if hasattr(sys, "ps1") and (self.ofname is None): + """ + Return a matplotlib axes object. + + If an axes is already set, return it. Otherwise, create a new + figure and axes instance and set it as the current axes. + + Parameters + ---------- + None + + Returns + ------- + ax : matplotlib axes + The current axes object. + """ + if (self.ax is not None) and isinstance(self.ax, plt.Axes): + self.fig = self.ax.figure + return self.ax + elif hasattr(sys, "ps1") and (self.ofname is None): + # Working in an interactive environment, create a figure + # and show it. fig = plt.figure(figsize=self.figsize) else: + # Working non-interactively, create a figure but don't show + # it. fig = Figure(figsize=self.figsize) canvas = FigureCanvas(fig) ax = fig.add_subplot(111) @@ -184,6 +208,7 @@ def heatmap( xticklabels: bool = True, yticklabels: bool = True, ofname: Optional[str] = None, + ax: Optional[plt.Axes] = None, **kwargs, ): """Visualize the dataframe. @@ -195,11 +220,25 @@ def heatmap( :param cmap: matplotlib colormap. e.g. "RdBu_r". :param xticklabels: bool, whether to show xticklabels. :param xticklabels: bool, whether to show xticklabels. - :param ofname: output file name. If None, don't save figure + :param ofname: output file name. If None, don't save figure. + :param ax: matplotlib axes. Default: None. + + :return: ax if ofname is None. """ - ht = Heatmap(df, z_score, title, figsize, cmap, xticklabels, yticklabels, ofname) + ht = Heatmap( + df=df, + z_score=z_score, + title=title, + figsize=figsize, + cmap=cmap, + xticklabels=xticklabels, + yticklabels=yticklabels, + ofname=ofname, + ax=ax, + **kwargs, + ) ax = ht.draw() if ofname is None: return ax @@ -237,7 +276,8 @@ def __init__( :param pheno_pos: phenotype label, positive correlated. :param pheno_neg: phenotype label, negative correlated. :param figsize: matplotlib figsize. - :param ofname: output file name. If None, don't save figure + :param ofname: output file name. If None, don't save figure. + :param ax: matplotlib axes. Default: None. """ # dataFrame of ranked matrix scores self.color = "#88C544" if color is None else color @@ -284,9 +324,10 @@ def __init__( # If working on command line, don't show figure self.fig = Figure(figsize=self.figsize, facecolor="white") self._canvas = FigureCanvas(self.fig) - else: + elif isinstance(ax, plt.Axes): self.fig = ax.figure - + else: + raise ValueError("ax must be matplotlib axes or None") self.fig.suptitle(self.term, fontsize=16, wrap=True, fontweight="bold") def axes_rank(self, rect): @@ -598,9 +639,11 @@ def __init__( thresh: float = 0.05, n_terms: int = 10, title: str = "", + ax: Optional[plt.Axes] = None, figsize: Tuple[float, float] = (6, 5.5), cmap: str = "viridis_r", ofname: Optional[str] = None, + marker: str = "o", **kwargs, ): """Visualize GSEApy Results with categorical scatterplot @@ -625,14 +668,13 @@ def __init__( ("Adjusted P-value", "P-value", "NOM p-val", "FDR q-val") :param n_terms: Number of enriched terms to show. :param dot_scale: float, scale the dot size to get proper visualization. - :param figsize: tuple, matplotlib figure size. + :param ax: Matplotlib axes. Default: None. + :param figsize: tuple, matplotlib figure size, only used when `ax` is None. :param cmap: Matplotlib colormap for mapping the `column` semantic. :param ofname: Output file name. If None, don't save figure :param marker: The matplotlib.markers. See https://matplotlib.org/stable/api/markers_api.html """ - self.marker = "o" - if "marker" in kwargs: - self.marker = kwargs["marker"] + self.marker = marker self.y = y self.x = x self.x_order = x_order @@ -642,6 +684,7 @@ def __init__( self.figsize = figsize self.cmap = cmap self.ofname = ofname + self.ax = ax self.scale = dot_scale self.title = title self.n_terms = n_terms @@ -773,16 +816,32 @@ def get_y_order( def get_ax(self): """ - setup figure axes + Return a matplotlib axes object. + + If an axes is already set, return it. Otherwise, create a new + figure and axes instance and set it as the current axes. + + Parameters + ---------- + None + + Returns + ------- + ax : matplotlib axes + The current axes object. """ - # create fig - if hasattr(sys, "ps1") and (self.ofname is None): - # working inside python console, show figure + if (self.ax is not None) and isinstance(self.ax, plt.Axes): + self.fig = self.ax.figure + return self.ax + elif hasattr(sys, "ps1") and (self.ofname is None): + # Working in an interactive environment, create a figure + # and show it. fig = plt.figure(figsize=self.figsize) else: - # If working on commandline, don't show figure + # Working non-interactively, create a figure but don't show + # it. fig = Figure(figsize=self.figsize) - _canvas = FigureCanvas(fig) + canvas = FigureCanvas(fig) ax = fig.add_subplot(111) self.fig = fig return ax @@ -1119,6 +1178,7 @@ def dotplot( cutoff: float = 0.05, top_term: int = 10, size: float = 5, + ax: Optional[plt.Axes] = None, figsize: Tuple[float, float] = (4, 6), cmap: str = "viridis_r", ofname: Optional[str] = None, @@ -1149,13 +1209,14 @@ def dotplot( ("Adjusted P-value", "P-value", "NOM p-val", "FDR q-val") :param top_term: Number of enriched terms to show (based on values in the `column` (colormap)). :param size: float, scale the dot size to get proper visualization. - :param figsize: tuple, matplotlib figure size. + :param ax: Matplotlib axes. + :param figsize: tuple, matplotlib figure size, only used when `ax` is None. :param cmap: Matplotlib colormap for mapping the `column` semantic. :param ofname: Output file name. If None, don't save figure :param marker: The matplotlib.markers. See https://matplotlib.org/stable/api/markers_api.html :param show_ring bool: Whether to draw outer ring. - :return: matplotlib.Axes. return None if given ofname. + :return: matplotlib.Axes if ofname is None. Only terms with `column` <= `cut-off` are plotted. """ if "group" in kwargs: @@ -1173,6 +1234,7 @@ def dotplot( thresh=cutoff, n_terms=int(top_term), dot_scale=size, + ax=ax, figsize=figsize, cmap=cmap, ofname=ofname, @@ -1242,6 +1304,7 @@ def barplot( title: str = "", cutoff: float = 0.05, top_term: int = 10, + ax: Optional[plt.Axes] = None, figsize: Tuple[float, float] = (4, 6), color: Union[str, List[str], Dict[str, str]] = "salmon", ofname: Optional[str] = None, @@ -1257,7 +1320,8 @@ def barplot( :param cutoff: terms with `column` value < cut-off are shown. Work only for ("Adjusted P-value", "P-value", "NOM p-val", "FDR q-val") :param top_term: number of top enriched terms grouped by `hue` are shown. - :param figsize: tuple, matplotlib figsize. + :param ax: Matplotlib axes. If None, create a new figure. + :param figsize: tuple, matplotlib figsize. only used when ax is None. :param color: color or list or dict of matplotlib.colors. Must be reconigzed by matplotlib. if dict input, dict keys must be found in the `group` :param ofname: output file name. If None, don't save figure @@ -1276,6 +1340,7 @@ def barplot( figsize=figsize, cmap="viridis", # placeholder only ofname=ofname, + ax=ax, ) if isinstance(color, str): color = [color] @@ -1338,8 +1403,10 @@ def __init__( # If working on command line, don't show figure self.fig = Figure(figsize=self.figsize, facecolor="white") self._canvas = FigureCanvas(self.fig) - else: + elif isinstance(ax, plt.Axes): self.fig = ax.figure + else: + raise ValueError("ax must be matplotlib axes or None") # self.fig.suptitle(self.term, fontsize=16, wrap=True, fontweight="bold") def axes_hits(