Skip to content

Commit

Permalink
Merge pull request #4 from nikhilbhavikatti/dev
Browse files Browse the repository at this point in the history
Merge pull request that contributes several 1D plots for distributions
  • Loading branch information
hageldave authored Jun 2, 2024
2 parents 3b864da + 27ba208 commit b3fd4ec
Show file tree
Hide file tree
Showing 2 changed files with 217 additions and 1 deletion.
14 changes: 13 additions & 1 deletion example.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
import uadapy.data as data
import uadapy.dr.uamds as uamds
import uadapy.plotting.plots2D as plots2D
import uadapy.plotting.plots1D as plots1D
import numpy as np
import matplotlib.pyplot as plt



Expand All @@ -16,8 +18,18 @@ def example_kde():
distr = ua.distribution.distribution(samples)
plots2D.plot_contour(distr)


def example_uamds_1d():
distribs_hi = data.load_iris_normal()
distribs_lo = uamds.uamds(distribs_hi, dims=4)
labels = ['setosa','versicolor','virginica']
titles = ['sepal length','sepal width','petal length','petal width']
colors = ['red','green', 'blue']
fig, axs = plt.subplots(2, 2, figsize=(12, 8))
fig, axs = plots1D.plot_1d_distribution(distribs_lo, 10000, ['violinplot','stripplot'], 444, fig, axs, labels, titles, colors, vert=True, colorblind_safe=False)
fig.tight_layout()
plt.show()

if __name__ == '__main__':
example_uamds()
# example_kde()
# example_uamds_1d()
204 changes: 204 additions & 0 deletions uadapy/plotting/plots1D.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
import numpy as np
import uadapy.distribution as dist
import matplotlib.pyplot as plt
from math import ceil, sqrt
import glasbey as gb
import seaborn as sns

def calculate_dot_size(num_samples, scale_factor):
if num_samples < 100:
dot_size = 3.125
else:
dot_size = scale_factor * (50 /(4 ** np.log10(num_samples)))
return dot_size

def setup_plot(distributions, num_samples, seed, fig=None, axs=None, colors=None, **kwargs):
"""
Set up the plot for samples drawn from given distributions.
Parameters
----------
distributions : list
List of distributions to plot. If a single distribution is passed, it will be converted into a list.
num_samples : int
Number of samples per distribution.
seed : int
Seed for the random number generator for reproducibility.
fig : matplotlib.figure.Figure or None, optional
Figure object to use for plotting. If None, a new figure will be created.
axs : matplotlib.axes.Axes or array of Axes or None, optional
Axes object(s) to use for plotting. If None, new axes will be created.
colors : list or None, optional
List of colors to use for each distribution. If None, Glasbey colors will be used.
**kwargs : additional keyword arguments
Additional optional arguments.
- colorblind_safe : bool, optional
If True, the plot will use colors suitable for colorblind individuals.
Default is False.
Returns
-------
fig : matplotlib.figure.Figure
The figure object containing the plot.
axs : list
List of Axes objects used for plotting.
samples : list
List of samples drawn from the distributions.
palette : list
List of colors to use for each distribution.
num_plots : int
Number of subplots.
num_cols : int
Number of columns in the subplot layout.
"""

samples = []

if isinstance(distributions, dist.distribution):
distributions = [distributions]

# Calculate the layout of subplots
if axs is None:
num_plots = distributions[0].dim
num_rows = ceil(sqrt(num_plots))
num_cols = ceil(num_plots / num_rows)
fig, axs = plt.subplots(num_rows, num_cols)
else:
# Case 1: axs is a 2D array (multiple rows and columns)
if isinstance(axs, np.ndarray):
dim = axs.shape
if (len(axs.shape) == 1):
num_rows, num_cols = 1, dim[0]
else:
num_rows, num_cols = dim
# Case 2: axs is not an array (single subplot)
else:
num_rows, num_cols = 1, 1
num_plots = num_rows + num_cols

# Ensure axs is a 2D array even if there's only one row or column
if num_rows == 1:
axs = [axs]
if num_cols == 1:
axs = [[ax] for ax in axs]

for d in distributions:
samples.append(d.sample(num_samples, seed))

# Generate Glasbey colors
if colors is None:
palette = gb.create_palette(palette_size=len(samples), colorblind_safe=kwargs.get('colorblind_safe', False))
else:
# If colors are provided but fewer than the number of samples, add more colors from Glasbey palette
if len(colors) < len(samples):
additional_colors = gb.create_palette(palette_size=len(samples) - len(colors), colorblind_safe=kwargs.get('colorblind_safe', True))
colors.extend(additional_colors)
palette = colors

return fig, axs, samples, palette, num_plots, num_cols

def plot_1d_distribution(distributions, num_samples, plot_types:list, seed=55, fig=None, axs=None, labels=None, titles=None, colors=None, **kwargs):
"""
Plot box plots, violin plots and dot plots for samples drawn from given distributions.
Parameters
----------
distributions : list
List of distributions to plot.
num_samples : int
Number of samples per distribution.
plot_types : list
List of plot types to plot. Valid values are 'boxplot','violinplot', 'stripplot' and 'swarmplot'.
seed : int
Seed for the random number generator for reproducibility. It defaults to 55 if not provided.
fig : matplotlib.figure.Figure or None, optional
Figure object to use for plotting. If None, a new figure will be created.
axs : matplotlib.axes.Axes or array of Axes or None, optional
Axes object(s) to use for plotting. If None, new axes will be created.
labels : list or None, optional
Labels for each distribution.
titles : list or None, optional
Titles for each subplot.
colors : list or None, optional
List of colors to use for each distribution. If None, Glasbey colors will be used.
**kwargs : additional keyword arguments
Additional optional plotting arguments.
- vert : bool, optional
If True, boxes will be drawn vertically. If False, boxes will be drawn horizontally.
Default is True.
- colorblind_safe : bool, optional
If True, the plot will use colors suitable for colorblind individuals.
Default is False.
- dot_size : float, optional
This parameter determines the size of the dots used in the 'stripplot' and 'swarmplot'.
If not provided, the size is calculated based on the number of samples and the type of plot.
Returns
-------
matplotlib.figure.Figure
The figure object containing the plot.
list
List of Axes objects used for plotting.
"""

fig, axs, samples, palette, num_plots, num_cols = setup_plot(distributions, num_samples, seed, fig, axs, colors, **kwargs)

num_attributes = np.shape(samples)[2]

for i, ax_row in enumerate(axs):
for j, ax in enumerate(ax_row):
index = i * num_cols + j
if index < num_plots and index < num_attributes:
for k, sample in enumerate(samples):
if 'boxplot' in plot_types:
boxprops = dict(facecolor=palette[k % len(palette)], edgecolor='black')
whiskerprops = dict(color='black', linestyle='--')
capprops = dict(color='black')
ax.boxplot(sample[:, index], positions=[k], patch_artist=True, boxprops=boxprops,
showfliers=False, whiskerprops=whiskerprops, capprops=capprops,
showmeans=True, meanline=True, meanprops=dict(color="black", linestyle='-'),
medianprops=dict(linewidth=0), vert=kwargs.get('vert', True))
if 'violinplot' in plot_types:
parts = ax.violinplot(sample[:,index], positions=[k], showmeans=True, vert=kwargs.get('vert', True))
for pc in parts['bodies']:
pc.set_facecolor(palette[k % len(palette)])
pc.set_edgecolor(palette[k % len(palette)])
pc.set_alpha(0.5)
parts['cbars'].remove()
parts['cmaxes'].remove()
parts['cmins'].remove()
parts['cmeans'].set_edgecolor('black')
if 'stripplot' in plot_types or 'swarmplot' in plot_types:
if 'dot_size' in kwargs:
dot_size = kwargs['dot_size']
else:
if 'stripplot' in plot_types:
scale_factor = 1 + 0.5 * np.log10(num_samples/100)
else :
scale_factor = 1
dot_size = calculate_dot_size(len(sample[:,index]), scale_factor)
if 'stripplot' in plot_types:
if kwargs.get('vert',True):
sns.stripplot(x=[k]*len(sample[:,index]), y=sample[:,index], color='black', size=dot_size * 1.5, jitter=0.25, ax=ax)
else:
sns.stripplot(x=sample[:,index], y=[k]*len(sample[:,index]), color='black', size=dot_size * 1.5, jitter=0.25, ax=ax, orient='h')
if 'swarmplot' in plot_types:
if kwargs.get('vert',True):
sns.swarmplot(x=[k]*len(sample[:,index]), y=sample[:,index], color='black', size=dot_size, ax=ax)
else:
sns.swarmplot(x=sample[:,index], y=[k]*len(sample[:,index]), color='black', size=dot_size, ax=ax, orient='h')
if labels:
if kwargs.get('vert', True):
ax.set_xticks(range(len(labels)))
ax.set_xticklabels(labels, rotation=45, ha='right')
else:
ax.set_yticks(range(len(labels)))
ax.set_yticklabels(labels, rotation=45, ha='right')
if titles:
ax.set_title(titles[index] if titles and index < len(titles) else 'Distribution ' + str(index + 1))
ax.yaxis.set_ticks_position('none')
ax.grid(True, linestyle=':', linewidth='0.5', color='gray')
else:
ax.set_visible(False) # Hide unused subplots

return fig, axs

0 comments on commit b3fd4ec

Please sign in to comment.