Skip to content

Commit

Permalink
add ax keyword, #285
Browse files Browse the repository at this point in the history
  • Loading branch information
Zhuoqing Fang authored and Zhuoqing Fang committed Nov 13, 2024
1 parent 190b263 commit fd916cc
Showing 1 changed file with 87 additions and 20 deletions.
107 changes: 87 additions & 20 deletions gseapy/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,15 @@ def __init__(
xticklabels: bool = True,
yticklabels: bool = True,
ofname: Optional[str] = None,
ax: Optional[plt.Axes] = None,
**kwargs,
):
self.title = "" if title is None else title
self.figsize = figsize
self.xticklabels = xticklabels
self.yticklabels = yticklabels
self.ofname = ofname
self.ax = ax

# scale dataframe
df = df.astype(float)
Expand Down Expand Up @@ -116,9 +118,31 @@ def _auto_ticks(self, ax, labels, axis):
return tickevery

def get_ax(self):
if hasattr(sys, "ps1") and (self.ofname is None):
"""
Return a matplotlib axes object.
If an axes is already set, return it. Otherwise, create a new
figure and axes instance and set it as the current axes.
Parameters
----------
None
Returns
-------
ax : matplotlib axes
The current axes object.
"""
if (self.ax is not None) and isinstance(self.ax, plt.Axes):
self.fig = self.ax.figure
return self.ax
elif hasattr(sys, "ps1") and (self.ofname is None):
# Working in an interactive environment, create a figure
# and show it.
fig = plt.figure(figsize=self.figsize)
else:
# Working non-interactively, create a figure but don't show
# it.
fig = Figure(figsize=self.figsize)
canvas = FigureCanvas(fig)
ax = fig.add_subplot(111)
Expand Down Expand Up @@ -184,6 +208,7 @@ def heatmap(
xticklabels: bool = True,
yticklabels: bool = True,
ofname: Optional[str] = None,
ax: Optional[plt.Axes] = None,
**kwargs,
):
"""Visualize the dataframe.
Expand All @@ -195,11 +220,25 @@ def heatmap(
:param cmap: matplotlib colormap. e.g. "RdBu_r".
:param xticklabels: bool, whether to show xticklabels.
:param xticklabels: bool, whether to show xticklabels.
:param ofname: output file name. If None, don't save figure
:param ofname: output file name. If None, don't save figure.
:param ax: matplotlib axes. Default: None.
:return: ax if ofname is None.
"""

ht = Heatmap(df, z_score, title, figsize, cmap, xticklabels, yticklabels, ofname)
ht = Heatmap(
df=df,
z_score=z_score,
title=title,
figsize=figsize,
cmap=cmap,
xticklabels=xticklabels,
yticklabels=yticklabels,
ofname=ofname,
ax=ax,
**kwargs,
)
ax = ht.draw()
if ofname is None:
return ax
Expand Down Expand Up @@ -237,7 +276,8 @@ def __init__(
:param pheno_pos: phenotype label, positive correlated.
:param pheno_neg: phenotype label, negative correlated.
:param figsize: matplotlib figsize.
:param ofname: output file name. If None, don't save figure
:param ofname: output file name. If None, don't save figure.
:param ax: matplotlib axes. Default: None.
"""
# dataFrame of ranked matrix scores
self.color = "#88C544" if color is None else color
Expand Down Expand Up @@ -284,9 +324,10 @@ def __init__(
# If working on command line, don't show figure
self.fig = Figure(figsize=self.figsize, facecolor="white")
self._canvas = FigureCanvas(self.fig)
else:
elif isinstance(ax, plt.Axes):
self.fig = ax.figure

else:
raise ValueError("ax must be matplotlib axes or None")
self.fig.suptitle(self.term, fontsize=16, wrap=True, fontweight="bold")

def axes_rank(self, rect):
Expand Down Expand Up @@ -598,9 +639,11 @@ def __init__(
thresh: float = 0.05,
n_terms: int = 10,
title: str = "",
ax: Optional[plt.Axes] = None,
figsize: Tuple[float, float] = (6, 5.5),
cmap: str = "viridis_r",
ofname: Optional[str] = None,
marker: str = "o",
**kwargs,
):
"""Visualize GSEApy Results with categorical scatterplot
Expand All @@ -625,14 +668,13 @@ def __init__(
("Adjusted P-value", "P-value", "NOM p-val", "FDR q-val")
:param n_terms: Number of enriched terms to show.
:param dot_scale: float, scale the dot size to get proper visualization.
:param figsize: tuple, matplotlib figure size.
:param ax: Matplotlib axes. Default: None.
:param figsize: tuple, matplotlib figure size, only used when `ax` is None.
:param cmap: Matplotlib colormap for mapping the `column` semantic.
:param ofname: Output file name. If None, don't save figure
:param marker: The matplotlib.markers. See https://matplotlib.org/stable/api/markers_api.html
"""
self.marker = "o"
if "marker" in kwargs:
self.marker = kwargs["marker"]
self.marker = marker
self.y = y
self.x = x
self.x_order = x_order
Expand All @@ -642,6 +684,7 @@ def __init__(
self.figsize = figsize
self.cmap = cmap
self.ofname = ofname
self.ax = ax
self.scale = dot_scale
self.title = title
self.n_terms = n_terms
Expand Down Expand Up @@ -773,16 +816,32 @@ def get_y_order(

def get_ax(self):
"""
setup figure axes
Return a matplotlib axes object.
If an axes is already set, return it. Otherwise, create a new
figure and axes instance and set it as the current axes.
Parameters
----------
None
Returns
-------
ax : matplotlib axes
The current axes object.
"""
# create fig
if hasattr(sys, "ps1") and (self.ofname is None):
# working inside python console, show figure
if (self.ax is not None) and isinstance(self.ax, plt.Axes):
self.fig = self.ax.figure
return self.ax
elif hasattr(sys, "ps1") and (self.ofname is None):
# Working in an interactive environment, create a figure
# and show it.
fig = plt.figure(figsize=self.figsize)
else:
# If working on commandline, don't show figure
# Working non-interactively, create a figure but don't show
# it.
fig = Figure(figsize=self.figsize)
_canvas = FigureCanvas(fig)
canvas = FigureCanvas(fig)
ax = fig.add_subplot(111)
self.fig = fig
return ax
Expand Down Expand Up @@ -1119,6 +1178,7 @@ def dotplot(
cutoff: float = 0.05,
top_term: int = 10,
size: float = 5,
ax: Optional[plt.Axes] = None,
figsize: Tuple[float, float] = (4, 6),
cmap: str = "viridis_r",
ofname: Optional[str] = None,
Expand Down Expand Up @@ -1149,13 +1209,14 @@ def dotplot(
("Adjusted P-value", "P-value", "NOM p-val", "FDR q-val")
:param top_term: Number of enriched terms to show (based on values in the `column` (colormap)).
:param size: float, scale the dot size to get proper visualization.
:param figsize: tuple, matplotlib figure size.
:param ax: Matplotlib axes.
:param figsize: tuple, matplotlib figure size, only used when `ax` is None.
:param cmap: Matplotlib colormap for mapping the `column` semantic.
:param ofname: Output file name. If None, don't save figure
:param marker: The matplotlib.markers. See https://matplotlib.org/stable/api/markers_api.html
:param show_ring bool: Whether to draw outer ring.
:return: matplotlib.Axes. return None if given ofname.
:return: matplotlib.Axes if ofname is None.
Only terms with `column` <= `cut-off` are plotted.
"""
if "group" in kwargs:
Expand All @@ -1173,6 +1234,7 @@ def dotplot(
thresh=cutoff,
n_terms=int(top_term),
dot_scale=size,
ax=ax,
figsize=figsize,
cmap=cmap,
ofname=ofname,
Expand Down Expand Up @@ -1242,6 +1304,7 @@ def barplot(
title: str = "",
cutoff: float = 0.05,
top_term: int = 10,
ax: Optional[plt.Axes] = None,
figsize: Tuple[float, float] = (4, 6),
color: Union[str, List[str], Dict[str, str]] = "salmon",
ofname: Optional[str] = None,
Expand All @@ -1257,7 +1320,8 @@ def barplot(
:param cutoff: terms with `column` value < cut-off are shown. Work only for
("Adjusted P-value", "P-value", "NOM p-val", "FDR q-val")
:param top_term: number of top enriched terms grouped by `hue` are shown.
:param figsize: tuple, matplotlib figsize.
:param ax: Matplotlib axes. If None, create a new figure.
:param figsize: tuple, matplotlib figsize. only used when ax is None.
:param color: color or list or dict of matplotlib.colors. Must be reconigzed by matplotlib.
if dict input, dict keys must be found in the `group`
:param ofname: output file name. If None, don't save figure
Expand All @@ -1276,6 +1340,7 @@ def barplot(
figsize=figsize,
cmap="viridis", # placeholder only
ofname=ofname,
ax=ax,
)
if isinstance(color, str):
color = [color]
Expand Down Expand Up @@ -1338,8 +1403,10 @@ def __init__(
# If working on command line, don't show figure
self.fig = Figure(figsize=self.figsize, facecolor="white")
self._canvas = FigureCanvas(self.fig)
else:
elif isinstance(ax, plt.Axes):
self.fig = ax.figure
else:
raise ValueError("ax must be matplotlib axes or None")
# self.fig.suptitle(self.term, fontsize=16, wrap=True, fontweight="bold")

def axes_hits(
Expand Down

0 comments on commit fd916cc

Please sign in to comment.