-
Notifications
You must be signed in to change notification settings - Fork 0
/
roc_auc_folio.py
80 lines (66 loc) · 2.88 KB
/
roc_auc_folio.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
from sklearn.metrics import roc_curve, auc, roc_auc_score
from sklearn.preprocessing import label_binarize
from scipy import interp
from itertools import cycle
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder
from sklearn.metrics import classification_report
from model import DeepHybridnet
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
def _build_test_loader(test_dir, num_classes, batch_size, num_workers):
    """Build an ordered (unshuffled) DataLoader over the ImageFolder test split.

    Raises:
        ValueError: if the split does not contain exactly ``num_classes`` class folders.
    """
    # 128x128 resize, then normalisation with the commonly used ImageNet channel statistics.
    transform = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    test_dataset = ImageFolder(root=test_dir, transform=transform)
    # ImageFolder assigns label indices from the sorted sub-folder names. If a class
    # folder is missing or extra, every later label shifts and no longer matches the
    # model's output column for that class, which silently corrupts the ROC curves.
    if len(test_dataset.classes) != num_classes:
        raise ValueError(
            'expected {0} class folders in {1!r}, found {2}'.format(
                num_classes, test_dir, len(test_dataset.classes)))
    return DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False,
                      num_workers=num_workers)


def _collect_predictions(model, loader, device):
    """Run ``model`` over ``loader`` and return ``(y_true, y_score)`` numpy arrays.

    ``y_true`` holds one label per sample. ``y_score`` holds the softmax over dim 1
    of the model output, one row per sample.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    true_batches = []
    score_batches = []
    with torch.no_grad():
        for images, labels in loader:
            probs = torch.softmax(model(images.to(device)), dim=1)
            true_batches.append(labels.cpu().numpy())
            # Keep every row of the batch. Keeping only row 0 drops samples (and
            # misaligns scores with labels) whenever batch_size > 1.
            score_batches.append(probs.cpu().numpy())
    if not true_batches:
        raise ValueError('test loader produced no samples')
    return np.concatenate(true_batches, axis=0), np.concatenate(score_batches, axis=0)


def _compute_roc(y_true, y_score, num_classes):
    """Return ``(fpr, tpr, roc_auc)`` dicts keyed by class index and by ``"micro"``.

    Class ``i`` is scored one-vs-rest from column ``i`` of ``y_score``. The
    ``"micro"`` entry pools every (sample, class) decision into a single curve.
    """
    # One-hot encode with np.eye instead of label_binarize: for two classes
    # label_binarize returns a single column, which would break y_bin[:, i] below.
    y_bin = np.eye(num_classes, dtype=int)[np.asarray(y_true, dtype=int)]
    fpr, tpr, roc_auc = {}, {}, {}
    for i in range(num_classes):
        # A class with no positive test samples gets an undefined (NaN) TPR and a
        # warning from sklearn. The resulting NaN AUC is left visible in the legend.
        fpr[i], tpr[i], _ = roc_curve(y_bin[:, i], y_score[:, i])
        roc_auc[i] = auc(fpr[i], tpr[i])
    fpr["micro"], tpr["micro"], _ = roc_curve(y_bin.ravel(), y_score.ravel())
    roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])
    return fpr, tpr, roc_auc


def _plot_roc(fpr, tpr, roc_auc, num_classes):
    """Draw the micro-average and per-class ROC curves in one figure and show it."""
    plt.figure(figsize=(10, 6))
    plt.plot(fpr["micro"], tpr["micro"],
             label='micro-average ROC curve (area = {0:0.2f})'.format(roc_auc["micro"]))
    for i in range(num_classes):
        plt.plot(fpr[i], tpr[i],
                 label='ROC curve of class {0} (area = {1:0.2f})'.format(i, roc_auc[i]))
    plt.plot([0, 1], [0, 1], 'k--')  # chance-level diagonal
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Some extension of Receiver operating characteristic to multi-class')
    # There are num_classes + 1 legend entries (31 by default). Inside the axes they
    # would hide the curves, so place the legend to the right in a small font.
    plt.legend(loc="center left", bbox_to_anchor=(1.0, 0.5), fontsize='x-small')
    plt.tight_layout()
    plt.show()


def main(test_dir='G:/vineet/medicinal_plant/Code/V1/folio_data/split_dataset/test',
         weights_path='G:/vineet/medicinal_plant/Code/V1/weights/mode.pth',
         num_classes=30, batch_size=1, num_workers=2):
    """Evaluate a trained DeepHybridnet on the test split and plot ROC curves.

    The function:
      1. loads the ImageFolder test set,
      2. runs the model to obtain softmax scores,
      3. computes one-vs-rest ROC curves per class plus a micro-average,
      4. shows all curves in one matplotlib figure.

    Args:
        test_dir: Root of the ImageFolder-style test split (one sub-folder per class).
        weights_path: Path of the saved ``state_dict`` loaded into ``DeepHybridnet``.
            NOTE(review): the default file name 'mode.pth' may be a typo for
            'model.pth'; confirm against the training script.
        num_classes: Number of model outputs. It must equal the number of class
            folders in ``test_dir``.
        batch_size: Inference batch size.
        num_workers: Number of DataLoader worker processes.

    Raises:
        ValueError: if any of the following holds:
            - the split's class count differs from ``num_classes``;
            - the split is empty;
            - the model's output width is not ``num_classes``.
    """
    test_loader = _build_test_loader(test_dir, num_classes, batch_size, num_workers)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = DeepHybridnet(num_classes=num_classes).to(device)
    model.load_state_dict(torch.load(weights_path, map_location=device))
    model.eval()

    y_true, y_score = _collect_predictions(model, test_loader, device)
    if y_score.ndim != 2 or y_score.shape[1] != num_classes:
        raise ValueError(
            'model produced scores of shape {0}, expected (n, {1})'.format(
                y_score.shape, num_classes))

    fpr, tpr, roc_auc = _compute_roc(y_true, y_score, num_classes)
    _plot_roc(fpr, tpr, roc_auc, num_classes)
if __name__ == '__main__':
    # Run the evaluation only when executed as a script, never on import.
    # The guard is also required for DataLoader worker processes on
    # spawn-based platforms such as Windows.
    main()