-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathUNETDRIVE_TRAINING_TESTING.py
183 lines (141 loc) · 6.71 KB
/
UNETDRIVE_TRAINING_TESTING.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
import os
import numpy as np
import glob
import PIL.Image as Image
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision import models
from torchsummary import summary
import torch.optim as optim
from time import time
import matplotlib.pyplot as plt
from IPython.display import clear_output
def train(model, opt, loss_fn, epochs, train_loader, val_loader, *, plot_every=2, num_images=4):
    """Train a segmentation model and periodically plot predictions on a validation batch.

    The first batch of ``val_loader`` is drawn once up front and reused as a
    fixed visual sample for every plot. Each batch is moved to the
    module-level ``device`` before the forward pass.

    Args:
        model: network being trained. Its output is passed through
            ``torch.sigmoid`` for display, so it presumably returns logits.
        opt: optimizer; ``zero_grad()`` and ``step()`` are called once per batch.
        loss_fn: called as ``loss_fn(y_true, y_pred)`` (target first).
            NOTE(review): built-in ``torch.nn`` losses take ``(input, target)``;
            confirm the loss passed here really expects the target first.
        epochs: number of passes over ``train_loader``.
        train_loader: iterable of ``(X, Y)`` training batches.
        val_loader: iterable of ``(X, Y)`` batches; only its first batch is used.
        plot_every: plot every ``plot_every`` epochs (default 2, the previous
            hard-coded value); a falsy value disables plotting.
        num_images: number of validation samples to show (default 4); clamped
            to the batch size so small batches no longer raise ``IndexError``.
    """
    # Fixed validation batch, drawn once and used only for visualisation.
    X_test, _ = next(iter(val_loader))
    n_batches = len(train_loader)

    for epoch in range(epochs):
        print('* Epoch %d/%d' % (epoch + 1, epochs))
        avg_loss = 0
        model.train()  # switch to training mode (visualisation below sets eval mode)

        for X_batch, Y_batch in train_loader:
            X_batch = X_batch.to(device)
            Y_batch = Y_batch.to(device)

            opt.zero_grad()
            Y_pred = model(X_batch)
            loss = loss_fn(Y_batch, Y_pred)  # target first, prediction second
            loss.backward()
            opt.step()

            # Running mean of the per-batch losses over this epoch.
            avg_loss += loss.item() / n_batches

        print(' - loss: %f' % avg_loss)

        # Plot results every `plot_every` epochs.
        if plot_every and (epoch + 1) % plot_every == 0:
            model.eval()
            with torch.no_grad():  # inference only; no autograd graph needed
                Y_hat = torch.sigmoid(model(X_test.to(device))).cpu()

            clear_output(wait=True)  # replace the previous figure (Jupyter only)

            n_show = min(num_images, len(X_test))
            # Two-row grid: inputs on top, predictions below. Six columns as
            # before; widened only if more than six images are requested.
            n_cols = max(6, n_show)
            for k in range(n_show):
                plt.subplot(2, n_cols, k + 1)
                # Move axis 0 to the end for imshow (assumes channel-first images).
                plt.imshow(np.rollaxis(X_test[k].cpu().numpy(), 0, 3), cmap='gray')
                plt.title('Real')
                plt.axis('off')

                plt.subplot(2, n_cols, n_cols + k + 1)
                plt.imshow(Y_hat[k, 0].numpy(), cmap='gray')
                plt.title('Output')
                plt.axis('off')
            plt.suptitle('%d / %d - loss: %f' % (epoch + 1, epochs, avg_loss))
            plt.show()
import sys
import os  # NOTE: duplicate of the `import os` at the top of the file; harmless.
# Make the project-local modules below importable. This is a hard-coded absolute
# path; it must be adjusted to run the script in any other environment.
sys.path.append('/zhome/45/0/155089/Deeplearning_in_computer_vision/Segmentation_project/Asignments_DeepLearningForCV/')
# `device` is used as a module-level global by train() and test() in this file.
from Unet_DRIVE import device
from Performance_Metrics import dice_coefficient, intersection_over_union, accuracy, sensitivity, specificity
# NOTE(review): this rebinds the name `time` to the module, shadowing the earlier
# `from time import time`; from here on `time` is the module, not the function.
import time
def test(model, test_loader, loss_fn, *, num_images=4):
    """Evaluate ``model`` on ``test_loader``, print loss and metrics, and plot sample predictions.

    Every batch is moved to the module-level ``device``. The reported loss is
    the mean of the per-batch losses. Dice, IoU, accuracy, sensitivity and
    specificity are computed once over all concatenated predictions.

    Args:
        model: network to evaluate; put into eval mode and left there.
        test_loader: iterable of ``(X, Y)`` batches.
        loss_fn: called as ``loss_fn(y_true, y_pred)`` (target first).
        num_images: number of samples from the first batch to plot (default 4);
            clamped to the batch size.
    """
    model.eval()
    test_loss = 0
    all_y_true = []
    all_y_pred = []
    with torch.no_grad():
        for X_test_batch, Y_test_batch in test_loader:
            X_test_batch = X_test_batch.to(device)
            Y_test_batch = Y_test_batch.to(device)
            Y_test_pred = model(X_test_batch)
            loss = loss_fn(Y_test_batch, Y_test_pred)  # target first, prediction second
            test_loss += loss.item()
            all_y_true.append(Y_test_batch.cpu())
            all_y_pred.append(Y_test_pred.cpu())

    # Clear earlier cell output (e.g. training plots) *before* reporting. With
    # wait=True the clear is deferred until new output arrives. Calling it after
    # the prints (as before) erased the metrics as soon as the figure was drawn.
    clear_output(wait=True)

    avg_test_loss = test_loss / len(test_loader)
    print('Test Loss: %f' % avg_test_loss)

    all_y_true = torch.cat(all_y_true)
    all_y_pred = torch.cat(all_y_pred)

    # NOTE(review): all_y_pred holds raw model outputs; no sigmoid/threshold is
    # applied here (only the plot below applies sigmoid). Confirm that the
    # Performance_Metrics functions handle that themselves.
    dice = dice_coefficient(all_y_true, all_y_pred)
    iou = intersection_over_union(all_y_true, all_y_pred)
    acc = accuracy(all_y_true, all_y_pred)
    sens = sensitivity(all_y_true, all_y_pred)
    spec = specificity(all_y_true, all_y_pred)
    print(f'Dice: {dice:.4f}, IoU: {iou:.4f}, Accuracy: {acc:.4f}, Sensitivity: {sens:.4f}, Specificity: {spec:.4f}')

    # Visualise the first batch. no_grad avoids building an autograd graph for
    # a pure inference pass.
    X_vis, _ = next(iter(test_loader))
    with torch.no_grad():
        Y_vis = torch.sigmoid(model(X_vis.to(device))).cpu()

    n_show = min(num_images, len(X_vis))
    # Two-row grid: inputs on top, predictions below; six columns unless more
    # images are requested.
    n_cols = max(6, n_show)
    for k in range(n_show):
        plt.subplot(2, n_cols, k + 1)
        # Move axis 0 to the end for imshow (assumes channel-first images).
        plt.imshow(np.rollaxis(X_vis[k].cpu().numpy(), 0, 3), cmap='gray')
        plt.axis('off')

        plt.subplot(2, n_cols, n_cols + k + 1)
        plt.imshow(Y_vis[k, 0], cmap='gray')
        plt.axis('off')
    plt.tight_layout()  # once, after all subplots exist
    plt.suptitle('Test - Loss: %f' % avg_test_loss)
    plt.show()  # blocks until the window is closed when not running inline
# Plot the original test images, their ground-truth masks, and the model's predictions.
def visualize_test_predictions(model, test_loader, device, num_images=4, *, pred_cmap='gray'):
    """Show input images, ground-truth masks and predicted probabilities for one test batch.

    Draws a 3-row figure: row 0 holds the inputs, row 1 the ground truth, and
    row 2 the sigmoid of the model output.

    Args:
        model: network to run; put into eval mode and left there.
        test_loader: iterable of ``(X, Y)`` batches; only the first is used.
        device: device the inputs are moved to before the forward pass.
        num_images: number of columns to draw; clamped to the batch size.
        pred_cmap: colormap for the prediction row (default ``'gray'``, as
            before; e.g. ``'inferno'`` or ``'magma'`` for a colour map).
    """
    model.eval()
    X_test_batch, Y_test_batch = next(iter(test_loader))
    X_test_batch = X_test_batch.to(device)

    with torch.no_grad():
        Y_test_pred = torch.sigmoid(model(X_test_batch)).cpu()

    n_show = min(num_images, len(X_test_batch))
    # squeeze=False keeps `axs` 2-D even for a single column, so axs[row, k] works.
    fig, axs = plt.subplots(3, n_show, figsize=(16, 10), squeeze=False)

    for k in range(n_show):
        # Original image (axis 0 moved to the end; assumes channel-first input).
        axs[0, k].imshow(np.rollaxis(X_test_batch[k].cpu().numpy(), 0, 3), cmap='gray')
        axs[0, k].axis('off')
        # Ground truth; squeeze drops singleton dimensions such as a channel axis of size 1.
        axs[1, k].imshow(Y_test_batch[k].cpu().numpy().squeeze(), cmap='gray')
        axs[1, k].axis('off')
        # Predicted probabilities.
        axs[2, k].imshow(Y_test_pred[k, 0], cmap=pred_cmap)
        axs[2, k].axis('off')

    # Row labels drawn in figure coordinates near the left edge.
    # NOTE(review): white text is only visible over dark image regions or a
    # dark figure style; on a default white figure background it is invisible.
    fig.text(0.04, 0.82, 'Original', fontsize=30, rotation=90, va='center', color='white')
    fig.text(0.04, 0.5, 'Ground Truth', fontsize=30, rotation=90, va='center', color='white')
    fig.text(0.04, 0.18, 'Prediction', fontsize=30, rotation=90, va='center', color='white')

    # NOTE(review): tight_layout recomputes subplot spacing, so these wspace/hspace
    # values are likely overridden by the call below.
    plt.subplots_adjust(wspace=0.1, hspace=0.2)
    plt.tight_layout(rect=[0, 0, 0.9, 1])  # keep the right 10% of the figure free of axes
    plt.show()