Skip to content

Commit

Permalink
moving to pairplots for correlation shiny
Browse files Browse the repository at this point in the history
  • Loading branch information
arik-shurygin committed Oct 10, 2024
1 parent 155c27b commit de4d16d
Showing 1 changed file with 49 additions and 82 deletions.
131 changes: 49 additions & 82 deletions shiny_visualizers/shiny_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,24 +474,28 @@ def load_checkpoint_inference_correlation_pairs(
posteriors: dict[str, list] = flatten_list_parameters(posteriors)
# drop any final_timestep parameters in case they snuck in
posteriors = drop_keys_with_substring(posteriors, "final_timestep")
# pick first key, get the samples for that key, get the shape of that np.ndarray
number_of_samples = posteriors[list(posteriors.keys())[0]].shape[1]
# if we are dealing with many samples per chain, narrow down to 100 samples per chain
if number_of_samples > 250:
selected_indices = np.random.choice(
number_of_samples, size=100, replace=False
)
posteriors = {
key: matrix[:, selected_indices]
for key, matrix in posteriors.items()
}
number_of_samples = posteriors[list(posteriors.keys())[0]].shape[1]
# Flatten matrices including chains and create Correlation DataFrame
posteriors = {
key: np.array(matrix).flatten() for key, matrix in posteriors.items()
}
# pick first key, get the samples for that key, get the shape of that np.ndarray
number_of_samples = posteriors[list(posteriors.keys())[0]].shape[0]
print(number_of_samples)
# if number_of_samples > 1000:
# selected_indices = np.random.choice(
# number_of_samples, size=100, replace=False
# )
# posteriors = {
# key: posteriors[key][selected_indices] for key in posteriors.keys()
# }
columns = posteriors.keys()
num_cols = len(list(columns))
label_size = max(2, min(10, 200 / num_cols))
# Compute the correlation matrix, reverse it so diagonal starts @ top left
correlation_matrix = pd.DataFrame(posteriors) # .corr()[::-1]
samples_df = pd.DataFrame(posteriors)
# correlation_matrix = samples_df.corr()
cmap = LinearSegmentedColormap.from_list("", ["red", "grey", "blue"])

def reg_coef(x, y, label=None, color=None, **kwargs):
Expand All @@ -502,100 +506,63 @@ def reg_coef(x, y, label=None, color=None, **kwargs):
xy=(0.5, 0.5),
xycoords="axes fraction",
ha="center",
# vary size and color by the magnitude of correlation
color=cmap(r),
size=label_size * abs(r) + label_size,
)
# ax.texts[0].set_size(16)
# ax.set_axis_off()
ax.set_axis_off()

def reg_plot_custom(x, y, label=None, color=None, **kwargs):
ax = plt.gca()
r, p = pearsonr(x, y)
ax = sns.regplot(
x=x,
y=y,
ax=ax,
fit_reg=True,
scatter_kws={"alpha": 0.2, "s": 0.5},
line_kws={"color": cmap(r), "linewidth": 1},
)

# Create the plot
# fig, ax = plt.subplots(figsize=(num_cols + 1, num_cols + 1))
g = sns.PairGrid(
data=correlation_matrix, vars=columns, height=5, diag_sharey=False
data=samples_df,
vars=columns,
height=5,
diag_sharey=False,
layout_pad=0.01,
)
# g.figure.set_size_inches((num_cols + 1, num_cols + 1))
# g.figure.tight_layout()
g.map_upper(reg_coef)
g = g.map_lower(
sns.regplot, fit_reg=False # ,scatter_kws={"edgecolor": "white"}
reg_plot_custom, # fit_reg=True # scatter_kws={"edgecolor": "white"}
)
g = g.map_diag(sns.histplot, kde=True)
label_size = max(2, min(10, 75 / num_cols))
print(label_size)
for ax in g.axes.flatten():
plt.setp(ax.get_xticklabels(), rotation=45)
plt.setp(ax.get_xticklabels(), rotation=45, size=label_size)
plt.setp(ax.get_yticklabels(), rotation=45, size=label_size)
# extract the existing xaxis label
xlabel = ax.get_xlabel()
# set the xaxis label with rotation
ax.set_xlabel(xlabel, size=label_size)
ax.set_xlabel(xlabel, size=label_size, rotation=90, labelpad=4.0)

ylabel = ax.get_ylabel()
ax.set_ylabel(ylabel, size=label_size)
ax.set_ylabel(ylabel, size=label_size, rotation=0, labelpad=15.0)
ax.label_outer(remove_inner_ticks=True)
# Adjust layout to make sure everything fits
g.figure.tight_layout()
px = 1 / plt.rcParams["figure.dpi"]
g.figure.set_size_inches((1600 * px, 1600 * px))
g.figure.tight_layout(pad=0.01, h_pad=0.01, w_pad=0.01)

# Adjust spacing if necessary (values are fractions of figure size)
g.figure.subplots_adjust(left=0.05, right=0.95, top=0.95, bottom=0.05)
return g.figure

# # Create subplots: Each subplot is one pair of 'columns'
# fig = make_subplots(
# rows=len(columns),
# cols=len(columns),
# shared_xaxes=True,
# shared_yaxes=True,
# vertical_spacing=0.03,
# horizontal_spacing=0.03,
# g.figure.subplots_adjust(
# left=0.08, right=0.92, top=0.92, bottom=0.08, wspace=0.1, hspace=0.1
# )
# print(correlation_matrix)

# for i, col1 in enumerate(columns):
# for j, col2 in enumerate(columns):
# if i == j:
# # Diagonal - we place histogram or kde here
# fig.add_trace(
# go.Histogram(x=posteriors[col1], nbinsx=20),
# row=i + 1,
# col=j + 1,
# )
# elif i < j:
# # Upper triangle - show Pearson correlation coefficient
# corr_value = correlation_matrix[col1][col2]
# fig.add_annotation(
# xref="x domain",
# yref="y domain",
# x=0.5,
# y=0.5,
# xanchor="center",
# yanchor="middle",
# text=f"ρ = {corr_value:.2f}",
# showarrow=False,
# font=dict(size=10),
# row=i + 1,
# col=j + 1,
# )
# else:
# # Lower triangle - scatter plots here
# fig.add_trace(
# go.Scatter(
# x=posteriors[col2],
# y=posteriors[col1],
# mode="markers",
# marker=dict(size=3),
# ),
# row=i + 1,
# col=j + 1,
# )

# # Update layout and axes properties if needed
# fig.update_layout(
# height=overview_subplot_size,
# width=overview_subplot_size,
# title_text="Pairplot with Plotly",
# )
# # fig.update_traces(
# # diagonal_visible=False
# # ) # Hide diagonal subplots for clarity
# return fig
g.figure.subplots_adjust(wspace=0.1, hspace=0.1)
g.figure.savefig("testing_corr2.png")
return g.figure


def _generate_row_wise_legends(fig, num_cols):
Expand Down

0 comments on commit de4d16d

Please sign in to comment.