Skip to content

Commit

Permalink
Significance:
Browse files Browse the repository at this point in the history
[x] Contour plot update related to isolines
[x] Extension of isoline changes to other functions
[x] Implementation of contour bands
[x] Implementation of dot plots
  • Loading branch information
nikhilbhavikatti authored and hageldave committed Jun 25, 2024
1 parent b3fd4ec commit 79f5d02
Show file tree
Hide file tree
Showing 4 changed files with 322 additions and 37 deletions.
10 changes: 7 additions & 3 deletions example.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import uadapy as ua
import uadapy.data as data
import uadapy.dr.uamds as uamds
import uadapy.plotting.plots2D as plots2D
import uadapy.plotting.plots1D as plots1D
import uadapy.plotting.plots2D as plots2D
import uadapy.plotting.plotsND as plotsND
import numpy as np
import matplotlib.pyplot as plt

Expand All @@ -11,7 +12,10 @@
def example_uamds():
distribs_hi = data.load_iris_normal()
distribs_lo = uamds.uamds(distribs_hi, dims=2)
plots2D.plot_contour(distribs_lo)
plots2D.plot_contour(distribs_lo, 10000, 128, None, [5, 25, 55, 75, 95])
#plots2D.plot_contour_bands(distribs_lo, 10000, 128, None, [5, 25, 55, 75, 95])
#plotsND.plot_contour(distribs_lo, 10000, 128, None, [5, 25, 50, 75, 95])
#plotsND.plot_contour_samples(distribs_lo, 10000, 128, None, [5, 25, 50, 75, 95])

def example_kde():
samples = np.random.randn(1000,2)
Expand All @@ -25,7 +29,7 @@ def example_uamds_1d():
titles = ['sepal length','sepal width','petal length','petal width']
colors = ['red','green', 'blue']
fig, axs = plt.subplots(2, 2, figsize=(12, 8))
fig, axs = plots1D.plot_1d_distribution(distribs_lo, 10000, ['violinplot','stripplot'], 444, fig, axs, labels, titles, colors, vert=True, colorblind_safe=False)
fig, axs = plots1D.plot_1d_distribution(distribs_lo, 10000, ['boxplot','violinplot'], 444, fig, axs, labels, titles, colors, vert=True, colorblind_safe=False)
fig.tight_layout()
plt.show()

Expand Down
82 changes: 68 additions & 14 deletions uadapy/plotting/plots1D.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,19 @@
from math import ceil, sqrt
import glasbey as gb
import seaborn as sns
from matplotlib.patches import Ellipse


def calculate_freedman_diaconis_bins(data):
    """
    Estimate a histogram bin count with the Freedman-Diaconis rule.

    The bin width is ``2 * IQR / cbrt(n)`` and the bin count is the data
    range divided by that width, truncated to an integer.

    Parameters
    ----------
    data : array-like
        One-dimensional sample values. ``len(data)`` is used as ``n``.

    Returns
    -------
    int
        Number of bins, always at least 1.

    Notes
    -----
    When the interquartile range is zero (or the width is otherwise not a
    positive finite number) the Freedman-Diaconis width is unusable, so the
    bin count falls back to Sturges' rule, ``ceil(log2(n) + 1)``.
    """
    q25, q75 = np.percentile(data, [25, 75])
    iqr = q75 - q25
    num_values = len(data)
    bin_width = 2 * iqr / np.cbrt(num_values)
    data_range = np.max(data) - np.min(data)

    # A zero/NaN width would make the division below yield inf/nan, and
    # int() of those raises; use a rule that does not depend on the IQR.
    if not np.isfinite(bin_width) or bin_width <= 0:
        return max(1, int(np.ceil(np.log2(num_values) + 1)))

    # Clamp to 1 so callers can safely divide the data range by the result.
    return max(1, int(data_range / bin_width))

def calculate_offsets(count, max_count):
    """
    Return ``count`` evenly spaced offsets centred on zero.

    The offsets span ``[-0.45 * r, 0.45 * r]`` where ``r = count / max_count``,
    so a bin holding ``max_count`` items uses the full +/-0.45 width and
    emptier bins use proportionally less.

    Parameters
    ----------
    count : int
        Number of offsets to generate.
    max_count : int
        Count of the fullest bin, used to scale the spread.

    Returns
    -------
    numpy.ndarray
        Array of ``count`` offsets.
    """
    half_span = 0.45 * (count / max_count)
    return np.linspace(-half_span, half_span, count)

def calculate_dot_size(num_samples, scale_factor):
if num_samples < 100:
Expand Down Expand Up @@ -108,7 +121,7 @@ def plot_1d_distribution(distributions, num_samples, plot_types:list, seed=55, f
num_samples : int
Number of samples per distribution.
plot_types : list
List of plot types to plot. Valid values are 'boxplot','violinplot', 'stripplot' and 'swarmplot'.
List of plot types to plot. Valid values are 'boxplot','violinplot', 'stripplot', 'swarmplot' and 'dotplot'.
seed : int
Seed for the random number generator for reproducibility. It defaults to 55 if not provided.
fig : matplotlib.figure.Figure or None, optional
Expand Down Expand Up @@ -144,11 +157,17 @@ def plot_1d_distribution(distributions, num_samples, plot_types:list, seed=55, f
fig, axs, samples, palette, num_plots, num_cols = setup_plot(distributions, num_samples, seed, fig, axs, colors, **kwargs)

num_attributes = np.shape(samples)[2]
if labels:
ticks = range(len(labels))
else:
ticks = range(len(samples))

for i, ax_row in enumerate(axs):
for j, ax in enumerate(ax_row):
index = i * num_cols + j
if index < num_plots and index < num_attributes:
y_min = 9999
y_max = -9999
for k, sample in enumerate(samples):
if 'boxplot' in plot_types:
boxprops = dict(facecolor=palette[k % len(palette)], edgecolor='black')
Expand All @@ -168,31 +187,66 @@ def plot_1d_distribution(distributions, num_samples, plot_types:list, seed=55, f
parts['cmaxes'].remove()
parts['cmins'].remove()
parts['cmeans'].set_edgecolor('black')
if 'stripplot' in plot_types or 'swarmplot' in plot_types:
if 'stripplot' in plot_types or 'swarmplot' in plot_types or 'dotplot' in plot_types:
if 'dot_size' in kwargs:
dot_size = kwargs['dot_size']
else:
if 'stripplot' in plot_types:
scale_factor = 1 + 0.5 * np.log10(num_samples/100)
else :
scale_factor = 1
dot_size = calculate_dot_size(len(sample[:,index]), scale_factor)
if 'stripplot' in plot_types:
if 'dot_size' not in kwargs:
scale_factor = 1 + np.log10(num_samples/100)
dot_size = calculate_dot_size(len(sample[:,index]), scale_factor)
if kwargs.get('vert',True):
sns.stripplot(x=[k]*len(sample[:,index]), y=sample[:,index], color='black', size=dot_size * 1.5, jitter=0.25, ax=ax)
sns.stripplot(x=[k]*len(sample[:,index]), y=sample[:,index], color=palette[k % len(palette)], size=dot_size, jitter=0.25, ax=ax)
else:
sns.stripplot(x=sample[:,index], y=[k]*len(sample[:,index]), color='black', size=dot_size * 1.5, jitter=0.25, ax=ax, orient='h')
sns.stripplot(x=sample[:,index], y=[k]*len(sample[:,index]), color=palette[k % len(palette)], size=dot_size, jitter=0.25, ax=ax, orient='h')
if 'swarmplot' in plot_types:
if 'dot_size' not in kwargs:
dot_size = calculate_dot_size(len(sample[:,index]), 1)
if kwargs.get('vert',True):
sns.swarmplot(x=[k]*len(sample[:,index]), y=sample[:,index], color='black', size=dot_size, ax=ax)
sns.swarmplot(x=[k]*len(sample[:,index]), y=sample[:,index], color=palette[k % len(palette)], size=dot_size, ax=ax)
else:
sns.swarmplot(x=sample[:,index], y=[k]*len(sample[:,index]), color='black', size=dot_size, ax=ax, orient='h')
sns.swarmplot(x=sample[:,index], y=[k]*len(sample[:,index]), color=palette[k % len(palette)], size=dot_size, ax=ax, orient='h')
if 'dotplot' in plot_types:
if 'dot_size' not in kwargs:
dot_size = calculate_dot_size(len(sample[:,index]), 0.005)
flat_sample = np.ravel(sample[:,index])
ticks = [x + 0.5 for x in range(len(samples))]
if y_min > np.min(flat_sample):
y_min = np.min(flat_sample)
if y_max < np.max(flat_sample):
y_max = np.max(flat_sample)
num_bins = calculate_freedman_diaconis_bins(flat_sample)
bin_width = kwargs.get('bin_width', (np.max(flat_sample) - np.min(flat_sample)) / num_bins)
bins = np.arange(np.min(flat_sample), np.max(flat_sample) + bin_width, bin_width)
binned_data, bin_edges = np.histogram(flat_sample, bins=bins)
max_count = np.max(binned_data)
for bin_idx in range(len(binned_data)):
count = binned_data[bin_idx]
if count > 0:
bin_center = (bin_edges[bin_idx] + bin_edges[bin_idx + 1]) / 2
# Calculate symmetrical offsets
if count == 1:
offsets = [0] # Single dot in the center
else:
offsets = calculate_offsets(count, max_count)
for offset in offsets:
if kwargs.get('vert',True):
ellipse = Ellipse((ticks[k] + offset, bin_center), width=dot_size, height=dot_size, color=palette[k % len(palette)])
else:
ellipse = Ellipse((bin_center, ticks[k] + offset), width=dot_size, height=dot_size, color=palette[k % len(palette)])
ax.add_patch(ellipse)
if 'dotplot' in plot_types:
if kwargs.get('vert',True):
ax.set_xlim(0, len(samples))
ax.set_ylim(y_min - 1, y_max + 1)
else:
ax.set_xlim(y_min - 1, y_max + 1)
ax.set_ylim(0, len(samples))
if labels:
if kwargs.get('vert', True):
ax.set_xticks(range(len(labels)))
ax.set_xticks(ticks)
ax.set_xticklabels(labels, rotation=45, ha='right')
else:
ax.set_yticks(range(len(labels)))
ax.set_yticks(ticks)
ax.set_yticklabels(labels, rotation=45, ha='right')
if titles:
ax.set_title(titles[index] if titles and index < len(titles) else 'Distribution ' + str(index + 1))
Expand Down
144 changes: 142 additions & 2 deletions uadapy/plotting/plots2D.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import matplotlib.pyplot as plt
import numpy as np
import uadapy.distribution as dist
from numpy import ma
from matplotlib import ticker

def plot_samples(distribution, num_samples, **kwargs):
"""
Expand All @@ -24,11 +26,42 @@ def plot_samples(distribution, num_samples, **kwargs):
plt.ylabel(kwargs['ylabel'])
plt.show()

def plot_contour(distributions, num_samples, resolution=128, ranges=None, quantiles:list=None, seed=55, **kwargs):
"""
Plot contour plots for samples drawn from given distributions.
Parameters
----------
distributions : list
List of distributions to plot.
num_samples : int
Number of samples per distribution.
resolution : int, optional
The resolution of the plot. Default is 128.
ranges : list or None, optional
The ranges for the x and y axes. If None, the ranges are calculated based on the distributions.
quantiles : list or None, optional
List of quantiles to use for determining isovalues. If None, the 99.7%, 95%, and 68% quantiles are used.
seed : int
Seed for the random number generator for reproducibility. It defaults to 55 if not provided.
**kwargs : additional keyword arguments
Additional optional plotting arguments.
Returns
-------
None
This function does not return a value. It displays a plot using plt.show().
Raises
------
ValueError
If a quantile is not between 0 and 100 (exclusive), or if a quantile results in an index that is out of bounds.
"""

def plot_contour(distributions, resolution=128, ranges=None, **kwargs):
if isinstance(distributions, dist.distribution):
distributions = [distributions]
contour_colors = generate_spectrum_colors(len(distributions))

if ranges is None:
min_val = np.zeros(distributions[0].mean().shape)+1000
max_val = np.zeros(distributions[0].mean().shape)-1000
Expand All @@ -50,7 +83,114 @@ def plot_contour(distributions, resolution=128, ranges=None, **kwargs):
pdf = d.pdf(coordinates)
pdf = pdf.reshape(xv.shape)
color = contour_colors[i]
plt.contour(xv, yv, pdf, colors = [color])

# Monte Carlo approach for determining isovalues
isovalues = []
samples = d.sample(num_samples, seed)
densities = d.pdf(samples)
densities.sort()
if quantiles is None:
isovalues.append(densities[int((1 - 99.7/100) * num_samples)]) # 99.7% quantile
isovalues.append(densities[int((1 - 95/100) * num_samples)]) # 95% quantile
isovalues.append(densities[int((1 - 68/100) * num_samples)]) # 68% quantile
else:
quantiles.sort(reverse=True)
for quantile in quantiles:
if not 0 < quantile < 100:
raise ValueError(f"Invalid quantile: {quantile}. Quantiles must be between 0 and 100 (exclusive).")
elif int((1 - quantile/100) * num_samples) >= num_samples:
raise ValueError(f"Quantile {quantile} results in an index that is out of bounds.")
isovalues.append(densities[int((1 - quantile/100) * num_samples)])

plt.contour(xv, yv, pdf, levels=isovalues, colors = [color])
plt.show()

def plot_contour_bands(distributions, num_samples, resolution=128, ranges=None, quantiles:list=None, seed=55, **kwargs):
    """
    Plot filled contour bands for the densities of the given distributions.

    For every distribution the density is evaluated on a regular grid and
    drawn with ``plt.contourf``. The band boundaries (isovalues) are found
    with a Monte Carlo approach: samples are drawn, their densities are
    sorted, and the density at the position matching each quantile is used.

    Parameters
    ----------
    distributions : list
        List of distributions to plot. A single distribution is also accepted.
    num_samples : int
        Number of samples per distribution used to estimate the isovalues.
    resolution : int, optional
        Number of grid points per axis. Default is 128.
    ranges : list or None, optional
        The ranges for the x and y axes. If None, the ranges are calculated
        from the means and covariances of the distributions.
    quantiles : list or None, optional
        List of quantiles (in percent) to use for determining isovalues.
        If None, the 99.7%, 95%, and 68% quantiles are used. The passed
        sequence is not modified.
    seed : int
        Seed for the random number generator for reproducibility. It defaults to 55 if not provided.
    **kwargs : additional keyword arguments
        Accepted for interface compatibility; not used by this function.

    Returns
    -------
    None
        This function does not return a value. It displays a plot using plt.show().

    Raises
    ------
    ValueError
        If a quantile is not between 0 and 100 (exclusive), or if a quantile results in an index that is out of bounds.
    """

    if isinstance(distributions, dist.distribution):
        distributions = [distributions]

    # Sequential and perceptually uniform colormaps
    colormaps = [
        'Reds', 'Blues', 'Greens', 'Greys', 'Oranges', 'Purples',
        'YlOrBr', 'YlOrRd', 'OrRd', 'PuRd', 'RdPu', 'BuPu',
        'GnBu', 'PuBu', 'YlGnBu', 'PuBuGn', 'BuGn', 'YlGn',
        'viridis', 'plasma', 'inferno', 'magma', 'cividis'
    ]

    # Work on a sorted copy so the caller's sequence is left untouched.
    # Descending quantiles map to ascending indices into the sorted
    # densities, which keeps the isovalues in non-decreasing order.
    if quantiles is None:
        quantiles = [99.7, 95, 68]
    else:
        quantiles = sorted(quantiles, reverse=True)

    # Validation and index computation do not depend on the distribution.
    isovalue_indices = []
    for quantile in quantiles:
        if not 0 < quantile < 100:
            raise ValueError(f"Invalid quantile: {quantile}. Quantiles must be between 0 and 100 (exclusive).")
        sample_index = int((1 - quantile/100) * num_samples)
        if sample_index >= num_samples:
            raise ValueError(f"Quantile {quantile} results in an index that is out of bounds.")
        isovalue_indices.append(sample_index)

    if ranges is None:
        min_val = np.zeros(distributions[0].mean().shape)+1000
        max_val = np.zeros(distributions[0].mean().shape)-1000
        cov_max = np.zeros(distributions[0].mean().shape)
        for d in distributions:
            min_val = np.min(np.array([d.mean(), min_val]), axis=0)
            max_val = np.max(np.array([d.mean(), max_val]), axis=0)
            cov_max = np.max(np.array([np.diagonal(d.cov()), cov_max]), axis=0)
        cov_max = np.sqrt(cov_max)
        # Pad the extreme means by three (largest) standard deviations per axis.
        ranges = [(lo - 3*spread, hi + 3*spread) for lo, hi, spread in zip(min_val, max_val, cov_max)]
    range_x = ranges[0]
    range_y = ranges[1]

    # The evaluation grid is shared by all distributions; build it once.
    x = np.linspace(range_x[0], range_x[1], resolution)
    y = np.linspace(range_y[0], range_y[1], resolution)
    xv, yv = np.meshgrid(x, y)
    coordinates = np.stack((xv, yv), axis=-1).reshape((-1, 2))

    for i, d in enumerate(distributions):
        pdf = d.pdf(coordinates)
        pdf = pdf.reshape(xv.shape)
        pdf = ma.masked_where(pdf <= 0, pdf)  # Mask non-positive values to avoid log scale issues

        # Monte Carlo approach for determining isovalues
        samples = d.sample(num_samples, seed)
        densities = d.pdf(samples)
        densities.sort()
        isovalues = [densities[sample_index] for sample_index in isovalue_indices]

        # Generate logarithmic levels and create the contour plot with different colormap for each distribution
        plt.contourf(xv, yv, pdf, levels=isovalues, locator=ticker.LogLocator(), cmap=colormaps[i % len(colormaps)])

    plt.show()

# HELPER FUNCTIONS
Expand Down
Loading

0 comments on commit 79f5d02

Please sign in to comment.