-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
217 lines (159 loc) · 8.18 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
import torch
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from lrp import *
# Assuming MNISTDataLoader, LeNet5 are properly defined in their respective modules
from data_loader import MNISTDataLoader
from lenet_model import LeNet5
from model_utils import ModelUtils
import torch
import matplotlib.pyplot as plt
import numpy as np
# from utils import manual_convolution
def visualize_relevance_scores(model, input_data, filename="relevabce_score.png"):
    """Run LRP on ``input_data`` and save a grid of per-filter relevance heatmaps.

    Relevance comes from ``apply_lrp_to_last_conv_layer(model, input_data)``
    (pulled in via ``from lrp import *``). Only the first sample of the batch
    is plotted. The grid is sized to fit every filter (4x4 for 16 filters),
    and all panels share one colour range so the single colorbar is valid for
    each of them.

    Parameters:
    - model: model handed to ``apply_lrp_to_last_conv_layer``.
    - input_data: input batch handed to ``apply_lrp_to_last_conv_layer``.
    - filename (str): path the figure is saved to. The default keeps the
      historical (misspelled) name so existing output paths do not change.

    Raises:
    - ValueError: if the relevance tensor is not 4-D
      (batch_size, num_filters, height, width) or contains no filters.
    """
    relevance_scores = apply_lrp_to_last_conv_layer(model, input_data)
    if relevance_scores.dim() != 4:
        raise ValueError(
            "Expected relevance_scores to have 4 dimensions "
            "(batch_size, num_filters, height, width), "
            f"got shape {tuple(relevance_scores.shape)}"
        )
    num_filters = relevance_scores.size(1)
    if num_filters == 0:
        # Without at least one panel there is no image for the colorbar.
        raise ValueError("relevance_scores contains no feature maps to plot")

    # First sample only, detached and moved to host memory for matplotlib.
    heatmaps = relevance_scores[0].detach().cpu().numpy()
    # One shared colour range so the single colorbar describes every panel.
    vmin, vmax = float(heatmaps.min()), float(heatmaps.max())

    # Near-square grid large enough for all filters (16 -> 4x4, 6 -> 3x2).
    ncols = int(np.ceil(np.sqrt(num_filters)))
    nrows = int(np.ceil(num_filters / ncols))
    # squeeze=False keeps `axs` a 2-D array even for a 1x1 grid.
    fig, axs = plt.subplots(nrows=nrows, ncols=ncols,
                            figsize=(3 * ncols, 3 * nrows), squeeze=False)
    for i, ax in enumerate(axs.flat):
        if i < num_filters:
            im = ax.imshow(heatmaps[i], cmap='hot', interpolation='nearest',
                           vmin=vmin, vmax=vmax)
            ax.set_title(f'Feature Map {i+1}')
        ax.axis('off')  # hide axes on every panel, including unused ones
    fig.colorbar(im, ax=axs.ravel().tolist(), orientation='horizontal')
    fig.suptitle(f'Relevance Scores for {num_filters} Feature Maps in the Last Conv Layer')
    fig.savefig(filename)
    plt.close(fig)  # free the figure's memory
    print(f"Plot saved as {filename}")
def aggregate_and_plot_relevance(relevance_scores, filename='aggregated_relevance_scores.png', bar_color='blue'):
    """Save a bar chart of total relevance per feature map for the first sample.

    Each feature map's relevance is summed over its two spatial axes, and the
    totals of batch element 0 are drawn as bars numbered from 1.

    Parameters:
    - relevance_scores (torch.Tensor): 4-D tensor laid out as
      (batch_size, num_feature_maps, height, width).
    - filename (str): path the bar chart is written to.
    - bar_color (str): matplotlib colour used for the bars.

    Raises:
    - ValueError: if ``relevance_scores`` is not 4-dimensional.
    """
    expected_rank = 4
    if relevance_scores.dim() != expected_rank:
        raise ValueError("Expected relevance_scores to have 4 dimensions (batch_size, num_feature_maps, height, width)")

    print("Shape of relevance_scores:", relevance_scores.shape)

    # Collapse height and width, then keep only the first batch element.
    per_sample_totals = relevance_scores.sum(dim=[2, 3])
    per_map = per_sample_totals[0].detach().cpu().numpy()
    positions = range(1, per_map.size + 1)

    plt.figure(figsize=(10, 6))
    plt.bar(positions, per_map, color=bar_color)
    plt.xlabel('Feature Map')
    plt.ylabel('Aggregated Relevance Score')
    plt.title('Aggregated Relevance Scores for Feature Maps')
    plt.savefig(filename)
    plt.close()  # release the figure once it is on disk
    print(f"Plot saved as {filename}")
# Example usage (requires a trained model and an input batch):
# relevance_scores = apply_lrp_to_last_conv_layer(model, input_data)  # expected 4-D: (batch, maps, height, width)
# aggregate_and_plot_relevance(relevance_scores)
def get_layers(self):
    """Return the model's conv and fully connected layers, in forward order.

    Intended to be attached to LeNet5 as a method; reads the attributes
    conv1, conv2, fc1, fc2 and fc3 from ``self``.
    """
    layer_names = ("conv1", "conv2", "fc1", "fc2", "fc3")
    return [getattr(self, layer_name) for layer_name in layer_names]
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def visualize_mnist_samples(loader, num_samples=10):
        """Display the first ``num_samples`` images of one batch from ``loader``, titled by label.

        The figure is closed after ``plt.show()`` so later pyplot calls in this
        script do not draw into it (with a non-interactive backend
        ``plt.show()`` returns immediately and leaves this figure current).
        """
        images, labels = next(iter(loader))
        images = images.numpy()  # matplotlib needs host-side arrays
        count = min(num_samples, len(images))  # tolerate batches smaller than num_samples
        fig, axes = plt.subplots(1, count, figsize=(count, 2), squeeze=False)
        for idx, ax in enumerate(axes.flat):
            ax.imshow(images[idx][0], cmap='gray')  # channel 0 of each image
            ax.set_title(labels[idx].item())
            ax.axis('off')
        plt.show()
        plt.close(fig)

    # Load MNIST once and reuse these loaders for visualisation, training and LRP.
    data_loader = MNISTDataLoader()
    train_loader, test_loader = data_loader.load_data()

    # Visualize some samples from the training set
    visualize_mnist_samples(train_loader)

    # Attach get_layers to LeNet5 in case the class does not define it.
    LeNet5.get_layers = get_layers

    # Initialize and train model
    model = LeNet5()
    print(model.last_conv_layer)  # Should not be None
    print(hasattr(model.last_conv_layer, 'relprop'))  # Should be True
    model_utils = ModelUtils(model, device)
    model_utils.train(train_loader)
    model_utils.evaluate(test_loader)

    # One test batch for LRP analysis.
    images, _ = next(iter(test_loader))
    images = images.to(device)

    # NOTE(review): last_conv_output is presumably cached by LeNet5.forward, so this
    # reports the shape from the most recent forward pass (during evaluate), not `images`.
    print("Output before relevance propagation:", model.last_conv_output.shape)

    # Relevance scores for the last convolutional layer's feature maps (computed once).
    relevance_scores = apply_lrp_to_last_conv_layer(model, images)
    print("Relevance scores shape:", relevance_scores.shape)

    # Heatmap of the first sample's relevance summed over channels
    # (index 0 picks the sample, dim=0 of that slice is the channel axis).
    relevance_summed = relevance_scores[0].sum(dim=0).detach().cpu().numpy()
    plt.figure()  # fresh figure so nothing is drawn onto an earlier, still-open one
    plt.imshow(relevance_summed, cmap='hot', interpolation='nearest')
    plt.colorbar()
    plt.title("Relevance Scores Heatmap")
    plt.savefig('relevance_scores_heatmap.png')
    plt.close()

    # Distribution of the channel-summed relevance values.
    plt.figure()
    plt.hist(relevance_summed.flatten(), bins=50, color='blue', alpha=0.7)
    plt.title('Histogram of Relevance Scores')
    plt.xlabel('Relevance Score')
    plt.ylabel('Frequency')
    plt.grid(True)
    plt.savefig('relevance_histogram.png')
    plt.close()

    # Training curves recorded by ModelUtils.train.
    epochs = range(1, len(model_utils.train_losses) + 1)
    plt.figure(figsize=(12, 5))
    plt.subplot(1, 2, 1)
    plt.plot(epochs, model_utils.train_losses, 'b-', label='Loss')
    plt.title('Training Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.subplot(1, 2, 2)
    plt.plot(epochs, model_utils.train_accuracies, 'r-', label='Accuracy')
    plt.title('Training Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.legend()
    plt.savefig('training_metrics.png')
    plt.close()

    # Per-filter relevance heatmaps (recomputes LRP internally).
    visualize_relevance_scores(model, images)
    # Aggregated relevance per feature map.
    aggregate_and_plot_relevance(relevance_scores)
def list_all_layers(model):
    """Print each (qualified name, module) pair from ``model.named_modules()``.

    The first entry is the model itself, under the empty name.
    """
    for entry in model.named_modules():
        print(*entry)
if __name__ == "__main__":
    # Inspect the architecture of a fresh (untrained) LeNet5. The guard keeps
    # this from running on import, and the separate name avoids rebinding any
    # trained `model` created earlier in the script.
    inspection_model = LeNet5()
    list_all_layers(inspection_model)
    # Print the model's architecture
    print(inspection_model)
    # Top-level children only (named_children does not recurse).
    print("\nIterating over each layer in the model:")
    for name, module in inspection_model.named_children():
        print(f"{name}: {module}")