-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtemp_test.py
90 lines (69 loc) · 2.84 KB
/
temp_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import os
import sys
import random
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torchvision.utils
import imgaug as ia
from torch.utils.data import DataLoader,Dataset
from torch.autograd import Variable
from imgaug import augmenters as iaa
from PIL import Image
from torchsummaryX import summary
from siamese_network_defect import SiameseNetwork, DefectDataset, imshow, Augmenter, Config
# Visualize feature maps
# Latest hooked outputs, keyed by the name given to get_activation().
activation = {}


def get_activation(name):
    """Return a forward hook that stores a module's output under ``name``.

    The returned callable takes the three positional arguments of a
    forward hook (module, inputs, output). On each call it overwrites
    ``activation[name]`` with the detached output, so only the most
    recent forward pass is kept.
    """
    def _store_output(module, inputs, output):
        activation.update({name: output.detach()})

    return _store_output
# Run on the first GPU when CUDA is available, otherwise on the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

net = SiameseNetwork(size=(250, 250))
net.summary()
# Record the output of the cnn1 submodule under the key 'conv1' on every
# forward pass (read back later from the `activation` dict).
net.cnn1.register_forward_hook(get_activation('conv1'))

# The network is wrapped in DataParallel on CPU as well as GPU.
# NOTE(review): presumably the checkpoint was saved from a DataParallel model
# (keys prefixed with "module."), which would be why the wrapper is needed
# even on CPU - confirm against the training script.
#
# Use at most two GPUs, but only ones that actually exist: a hard-coded
# device_ids=[0, 1] fails on a single-GPU machine. On a CPU-only host the
# list is empty and DataParallel falls back to its default (device_ids=None).
gpu_ids = list(range(min(2, torch.cuda.device_count())))
model = torch.nn.DataParallel(net, device_ids=gpu_ids or None)
model.to(device)

# map_location lets a checkpoint saved on GPU be loaded on a CPU-only host.
model.load_state_dict(torch.load("./result.pth.tar", map_location=device)["state_dict"])
model.eval()
# Test-time pipeline: the only imgaug step is a resize to Config.RESIZE
# (index 0 is used as the height, index 1 as the width).
resize_step = iaa.Resize({"height": Config.RESIZE[0], "width": Config.RESIZE[1]})
seq = iaa.Sequential([resize_step])
composed = transforms.Compose([Augmenter(seq)])
dataset = DefectDataset(root=Config.testing_dir, transform=composed)

# Pull one shuffled batch of 8 items as a quick visual sanity check.
vis_dataloader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=0)
example_batch = next(iter(vis_dataloader))
first_images = example_batch[0]
second_images = example_batch[1]
pair_labels = example_batch[2]

# Stack both halves of every pair along the batch axis and show them as a grid.
concatenated = torch.cat((first_images, second_images), 0)
imshow(torchvision.utils.make_grid(concatenated))
print(pair_labels.numpy())
print(first_images.shape)
# Evaluate one pair per batch so each pair gets its own title, saved image
# and set of activation plots.
test_dataloader = DataLoader(dataset, num_workers=0, batch_size=1, shuffle=True)

# Pairs whose embedding distance is below this value are predicted "Same".
SAME_THRESHOLD = 1.5

# Two shuffled passes over the dataset; j identifies the pass and i the pair
# within it, and both are used to build the output file names.
for j in range(2):
    for i, (x0, x1, label2) in enumerate(test_dataloader):
        concatenated = torch.cat((x0, x1), 0)

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            output1, output2 = model(x0.to(device), x1.to(device))
            euclidean_distance = F.pairwise_distance(output1, output2)
        distance = euclidean_distance.item()

        title = 'Pred : {}, Label : {}, Dissimilarity: {:.2f}'.format(
            "Same" if distance < SAME_THRESHOLD else "Differ",
            "Same" if label2.item() == 0 else "Differ",
            distance,
        )
        imshow(torchvision.utils.make_grid(concatenated), title,
               should_save=True, name=str(j) + str(i))

        # The hook stored the latest cnn1 output. Drop only the batch axis
        # (a bare squeeze() would also drop a size-1 channel or spatial axis)
        # and move it to the CPU, since matplotlib cannot read CUDA tensors.
        # NOTE(review): if the network calls cnn1 once per input, this is the
        # activation of the most recent call only - confirm in SiameseNetwork.
        act = activation['conv1'].squeeze(0).cpu()
        for idx in range(act.size(0)):
            fig = plt.figure()
            plt.imshow(act[idx], cmap='gray')
            plt.savefig(str(j) + str(i) + "_" + str(idx) + "_activation")
            # Close each figure once saved; otherwise one stays open per
            # channel per pair for the whole run.
            plt.close(fig)