Skip to content

Commit

Permalink
Merge branch 'master' of github.com:facundoq/tmeasures
Browse files Browse the repository at this point in the history
  • Loading branch information
facundoq committed Dec 18, 2024
2 parents b83dc2a + 21db11d commit 8d141e9
Show file tree
Hide file tree
Showing 9 changed files with 432 additions and 229 deletions.
547 changes: 318 additions & 229 deletions docs/examples/ResNet Invariance with TinyImageNet.ipynb

Large diffs are not rendered by default.

Binary file added docs/examples/ResNet_brightness_invariance.pkl
Binary file not shown.
Binary file added docs/examples/ResNet_rotation_invariance.pkl
Binary file not shown.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/examples/_layer1_BasicBlock_0_conv2_.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/examples/_layer1_BasicBlock_0_conv2_rgb_.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
"tqdm",
"opencv-python",
"scikit-image",
"statsmodels",
],

package_data={
Expand Down
47 changes: 47 additions & 0 deletions tmeasures/visualization/images.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import numpy as np
import matplotlib.pyplot as plt

def plot_images_rgb(images,row_context=None,labels=None,cols=None,cmap=None,label_color="blue",label_fontsize=None,dpi=200):
N,C,H,W = images.shape
if cols is None:
cols = np.floor(np.sqrt(N)).astype(int)
rows = (N // cols) + (1 if N % cols >0 else 0)
if label_fontsize is None:
label_fontsize = np.sqrt(cols)*2.5
cols_extra=0 if row_context is None else 1
C_context=cols + cols_extra

f, subplots = plt.subplots(rows,C_context,dpi=dpi,figsize=(C_context,rows))
if cmap is None:
cmap = "gray"
for i in range(rows):
if not row_context is None:
subplots[i,0].imshow(row_context[i])
for j in range(cols):
ax = subplots[i,j+cols_extra]
ax.set_axis_off()
index = i*cols+j
if index <N:
filter_weights = images[index,]
ax.imshow(filter_weights,cmap=cmap,interpolation='nearest')
if not labels is None:
ax.text(0.5, 0.5,labels[index], horizontalalignment='center',verticalalignment='center', transform=ax.transAxes,fontsize=label_fontsize,c=label_color)

def plot_images_multichannel(images,vmin,vmax,colorbar_space=1,labels=None):
N,C,H,W = images.shape
invariance_fontsize = np.sqrt(C)*2
f, subplots = plt.subplots(N,C,dpi=150,figsize=(C+colorbar_space,N))
for i in range(N):
for j in range(C):
filter_weights = images[i,j,:,:]
ax = subplots[i,j]
ax.set_axis_off()
im = ax.imshow(filter_weights,cmap="PuOr",vmin=vmin,vmax=vmax,interpolation='nearest')
if not labels is None:
ax.text(0.5, 0.5,labels[i], horizontalalignment='center',verticalalignment='center', transform=ax.transAxes,fontsize=invariance_fontsize)
right=colorbar_space/(C+colorbar_space)

f.subplots_adjust(right=1-right)
gap = 0.2
cbar_ax = f.add_axes([1-right*(1-gap), 0.15, right*(1-2*gap), 0.7])
f.colorbar(im, cax=cbar_ax)
66 changes: 66 additions & 0 deletions tmeasures/visualization/weights.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
from sklearn.decomposition import NMF
import torch
import numpy as np

from tmeasures.visualization.images import plot_images_multichannel, plot_images_rgb

def reorder_conv2d_weights(activation:torch.nn.Module,invariance:np.array):
with torch.no_grad():
weight = dict(activation.named_parameters())["weight"]
indices = invariance.argsort().copy()
weight[:] = weight[indices,:,:,:]
invariance[:] = invariance[indices]


def sort_weights_invariance(weights:np.array,invariance:np.array,top_k:int=None):
indices = invariance.argsort()
if not top_k is None:
indices = np.concatenate( [indices[:top_k],indices[-top_k:]])
weights = weights[indices,:,:,:]
invariance = invariance[indices]
return weights, invariance


def weights_reduce_nmf(weights:np.array,n_components:int):
Fo,Fi,H,W = weights.shape
model = NMF(n_components=n_components, init='random', random_state=0,max_iter=1000)
nmf_weights = np.zeros((Fo,n_components,H,W))
for i in range(Fo):
input_weights = weights[i,]
flattened_weights = np.abs(input_weights.reshape(Fi,-1))
model.fit(flattened_weights)
nmf_weights[i] = model.components_.reshape(n_components,H,W)
return nmf_weights

def weight_inputs_filter_importance(weights:np.array,max_inputs:int):
Fo,Fi,H,W = weights.shape
input_importance = weights.mean(axis=(2,3))
for i in range(Fo):
indices = np.argsort(input_importance[i,:])[::-1]
weights[i,:,:,:] = weights[i,indices,:,:]
if not max_inputs is None:
weights[:,:max_inputs,]
return weights

def plot_conv2d_filters(conv2d:torch.nn.Module,invariance:np.array,sort=True, top_k=None,max_inputs=10,nmf_components=None):
weights = dict(conv2d.named_parameters())["weight"].detach().numpy()
mi,ma=weights.min(),weights.max()

if sort or not top_k is None :
weights, invariance = sort_weights_invariance(weights,invariance,top_k)
if not nmf_components is None:
weights = weights_reduce_nmf(weights,nmf_components)
if not max_inputs is None:
weights = weight_inputs_filter_importance(weights,max_inputs)
largest = max(abs(mi),abs(ma))
vmin,vmax = -largest,largest
# print(weights.shape)
labels = [f"{i:.02}" for i in invariance]
plot_images_multichannel(weights,vmin,vmax,labels=labels)

def plot_conv2d_filters_rgb(conv2d:torch.nn.Module,invariance:np.array):
weights = dict(conv2d.named_parameters())["weight"].detach().numpy()
weights, invariance = sort_weights_invariance(weights,invariance)
weights = weights_reduce_nmf(weights,3)
labels = [f"{i:.02}" for i in invariance]
plot_images_rgb(weights,labels=labels)

0 comments on commit 8d141e9

Please sign in to comment.