-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'master' of github.com:facundoq/tmeasures
- Loading branch information
Showing
9 changed files
with
432 additions
and
229 deletions.
There are no files selected for viewing
547 changes: 318 additions & 229 deletions
547
docs/examples/ResNet Invariance with TinyImageNet.ipynb
Large diffs are not rendered by default.
Oops, something went wrong.
Binary file not shown.
Binary file not shown.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -32,6 +32,7 @@ | |
"tqdm", | ||
"opencv-python", | ||
"scikit-image", | ||
"statsmodels", | ||
], | ||
|
||
package_data={ | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
import numpy as np | ||
import matplotlib.pyplot as plt | ||
|
||
def plot_images_rgb(images,row_context=None,labels=None,cols=None,cmap=None,label_color="blue",label_fontsize=None,dpi=200): | ||
N,C,H,W = images.shape | ||
if cols is None: | ||
cols = np.floor(np.sqrt(N)).astype(int) | ||
rows = (N // cols) + (1 if N % cols >0 else 0) | ||
if label_fontsize is None: | ||
label_fontsize = np.sqrt(cols)*2.5 | ||
cols_extra=0 if row_context is None else 1 | ||
C_context=cols + cols_extra | ||
|
||
f, subplots = plt.subplots(rows,C_context,dpi=dpi,figsize=(C_context,rows)) | ||
if cmap is None: | ||
cmap = "gray" | ||
for i in range(rows): | ||
if not row_context is None: | ||
subplots[i,0].imshow(row_context[i]) | ||
for j in range(cols): | ||
ax = subplots[i,j+cols_extra] | ||
ax.set_axis_off() | ||
index = i*cols+j | ||
if index <N: | ||
filter_weights = images[index,] | ||
ax.imshow(filter_weights,cmap=cmap,interpolation='nearest') | ||
if not labels is None: | ||
ax.text(0.5, 0.5,labels[index], horizontalalignment='center',verticalalignment='center', transform=ax.transAxes,fontsize=label_fontsize,c=label_color) | ||
|
||
def plot_images_multichannel(images,vmin,vmax,colorbar_space=1,labels=None): | ||
N,C,H,W = images.shape | ||
invariance_fontsize = np.sqrt(C)*2 | ||
f, subplots = plt.subplots(N,C,dpi=150,figsize=(C+colorbar_space,N)) | ||
for i in range(N): | ||
for j in range(C): | ||
filter_weights = images[i,j,:,:] | ||
ax = subplots[i,j] | ||
ax.set_axis_off() | ||
im = ax.imshow(filter_weights,cmap="PuOr",vmin=vmin,vmax=vmax,interpolation='nearest') | ||
if not labels is None: | ||
ax.text(0.5, 0.5,labels[i], horizontalalignment='center',verticalalignment='center', transform=ax.transAxes,fontsize=invariance_fontsize) | ||
right=colorbar_space/(C+colorbar_space) | ||
|
||
f.subplots_adjust(right=1-right) | ||
gap = 0.2 | ||
cbar_ax = f.add_axes([1-right*(1-gap), 0.15, right*(1-2*gap), 0.7]) | ||
f.colorbar(im, cax=cbar_ax) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
from sklearn.decomposition import NMF | ||
import torch | ||
import numpy as np | ||
|
||
from tmeasures.visualization.images import plot_images_multichannel, plot_images_rgb | ||
|
||
def reorder_conv2d_weights(activation:torch.nn.Module,invariance:np.array): | ||
with torch.no_grad(): | ||
weight = dict(activation.named_parameters())["weight"] | ||
indices = invariance.argsort().copy() | ||
weight[:] = weight[indices,:,:,:] | ||
invariance[:] = invariance[indices] | ||
|
||
|
||
def sort_weights_invariance(weights:np.array,invariance:np.array,top_k:int=None): | ||
indices = invariance.argsort() | ||
if not top_k is None: | ||
indices = np.concatenate( [indices[:top_k],indices[-top_k:]]) | ||
weights = weights[indices,:,:,:] | ||
invariance = invariance[indices] | ||
return weights, invariance | ||
|
||
|
||
def weights_reduce_nmf(weights:np.array,n_components:int): | ||
Fo,Fi,H,W = weights.shape | ||
model = NMF(n_components=n_components, init='random', random_state=0,max_iter=1000) | ||
nmf_weights = np.zeros((Fo,n_components,H,W)) | ||
for i in range(Fo): | ||
input_weights = weights[i,] | ||
flattened_weights = np.abs(input_weights.reshape(Fi,-1)) | ||
model.fit(flattened_weights) | ||
nmf_weights[i] = model.components_.reshape(n_components,H,W) | ||
return nmf_weights | ||
|
||
def weight_inputs_filter_importance(weights:np.array,max_inputs:int): | ||
Fo,Fi,H,W = weights.shape | ||
input_importance = weights.mean(axis=(2,3)) | ||
for i in range(Fo): | ||
indices = np.argsort(input_importance[i,:])[::-1] | ||
weights[i,:,:,:] = weights[i,indices,:,:] | ||
if not max_inputs is None: | ||
weights[:,:max_inputs,] | ||
return weights | ||
|
||
def plot_conv2d_filters(conv2d:torch.nn.Module,invariance:np.array,sort=True, top_k=None,max_inputs=10,nmf_components=None): | ||
weights = dict(conv2d.named_parameters())["weight"].detach().numpy() | ||
mi,ma=weights.min(),weights.max() | ||
|
||
if sort or not top_k is None : | ||
weights, invariance = sort_weights_invariance(weights,invariance,top_k) | ||
if not nmf_components is None: | ||
weights = weights_reduce_nmf(weights,nmf_components) | ||
if not max_inputs is None: | ||
weights = weight_inputs_filter_importance(weights,max_inputs) | ||
largest = max(abs(mi),abs(ma)) | ||
vmin,vmax = -largest,largest | ||
# print(weights.shape) | ||
labels = [f"{i:.02}" for i in invariance] | ||
plot_images_multichannel(weights,vmin,vmax,labels=labels) | ||
|
||
def plot_conv2d_filters_rgb(conv2d:torch.nn.Module,invariance:np.array): | ||
weights = dict(conv2d.named_parameters())["weight"].detach().numpy() | ||
weights, invariance = sort_weights_invariance(weights,invariance) | ||
weights = weights_reduce_nmf(weights,3) | ||
labels = [f"{i:.02}" for i in invariance] | ||
plot_images_rgb(weights,labels=labels) |