-
Notifications
You must be signed in to change notification settings - Fork 12
/
Copy pathtest_de.py
83 lines (68 loc) · 3.42 KB
/
test_de.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import os
import sys
import math
import torch
import pickle
import argparse
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torchvision.models as models
from loguru import logger
from tqdm import tqdm
from torchvision import transforms
from lib.utils.exp import (
get_model,
get_modeldir_ens,
get_transform,
get_mean,
get_std,
get_dataloader,
)
from lib.model import resnet,lenet
from lib.utils import split_dataloader
# ----- command-line interface -----
# Each entry is (flag names, keyword options) and is registered in the order
# listed, so the generated usage/help text keeps the same option order.
# NOTE(review): the --dataroot help text contains typos; it is reproduced
# verbatim so the CLI output stays unchanged.
_CLI_OPTIONS = (
    (('-i', '--ind'), dict(type=str, help='in distribution dataset', required=True)),
    (('-o', '--ood'), dict(type=str, help='out of distribution dataset', required=True)),
    (('-m', '--model_arch'), dict(type=str, help='model architecture', required=True)),
    (('-b', '--batch_size'), dict(type=int, default=64)),
    (('--model_num',), dict(type=int, help='the number of classifiers for ensemble', default=5)),
    (('--dataroot',), dict(type=str, help='datatset stroage directory', default='/data/datasets')),
)

parser = argparse.ArgumentParser(description='Description of your program')
for _flags, _options in _CLI_OPTIONS:
    parser.add_argument(*_flags, **_options)

# The parsed namespace is converted to a plain dict; the rest of the script
# looks options up by key (e.g. args['ind']).
args = vars(parser.parse_args())
print(args)
# Directory holding the ensemble checkpoints for this dataset/architecture pair,
# and the number of ensemble members requested on the command line.
modeldir = get_modeldir_ens(args['ind'], args['model_arch'])
ensemble_num = args['model_num']

# ----- load dataset -----
# The transform and std are both looked up for the in-distribution dataset;
# the same transform is then applied to the in-distribution and OoD test sets.
_ind_name = args['ind']
_ood_name = args['ood']
transform = get_transform(_ind_name)
std = get_std(_ind_name)

# Options shared by both test loaders.
_loader_options = {
    'dataroot': args['dataroot'],
    'batch_size': args['batch_size'],
}
ind_test_loader = get_dataloader(_ind_name, transform, "test", **_loader_options)
ood_test_loader = get_dataloader(_ood_name, transform, "test", **_loader_options)

# Each test loader is split into three parts using sizes [500, 500, -1]
# (presumably -1 means "everything that remains" -- confirm against
# split_dataloader). A fresh list literal is passed to each call.
(
    ind_dataloader_val_for_train,
    ind_dataloader_val_for_test,
    ind_dataloader_test,
) = split_dataloader(_ind_name, ind_test_loader, [500, 500, -1])
(
    ood_dataloader_val_for_train,
    ood_dataloader_val_for_test,
    ood_dataloader_test,
) = split_dataloader(_ood_name, ood_test_loader, [500, 500, -1])
# ----- Calculating and averaging maximum softmax probabilities -----
from lib.inference.ODIN import get_ODIN_score

# The ODIN scorer is called with temperature 1.0 and perturbation magnitude 0.0.
# NOTE(review): per the section title this is meant to yield the plain maximum
# softmax probability -- confirm against get_ODIN_score.
best_temperature = 1.0
best_magnitude = 0.0

# One entry per ensemble member; each entry is that member's score output.
ind_ensemble_val = []
ood_ensemble_val = []
ind_ensemble_test = []
ood_ensemble_test = []


def take_mean_and_reshape(x):
    """Average the per-member scores over axis 0 and return an (N, 1) array.

    `x` is a list with one score array per ensemble member.
    Assumes every member produced scores of the same shape -- TODO confirm.
    """
    return np.array(x).mean(axis=0).reshape(-1, 1)


# The directory listing is used only to cap the ensemble size: checkpoint paths
# are rebuilt from the member index below, not taken from the listed names.
num_members = len(os.listdir(modeldir)[:ensemble_num])
if num_members == 0:
    # Without this guard the mean over an empty list is NaN and the failure
    # would only surface later, far from its cause.
    raise RuntimeError(
        f"no ensemble members to evaluate (modeldir={modeldir!r}, "
        f"model_num={ensemble_num!r})"
    )

for member_idx in range(num_members):
    # Checkpoints are expected at <modeldir>/<ind>_<model_arch>_<index>.pth.
    # os.path.join gives the same path as plain concatenation when modeldir
    # ends with a separator, and a correct one when it does not.
    model_path = os.path.join(
        modeldir,
        args['ind'] + '_' + args['model_arch'] + f'_{member_idx}.pth',
    )
    model = get_model(args['ind'], args['model_arch'], target_model_path=model_path)

    # Scores on the validation split (used later to fit the detector).
    ind_scores_val_for_train = get_ODIN_score(
        model, ind_dataloader_val_for_train, best_magnitude, best_temperature, std=std
    )
    ood_scores_val_for_train = get_ODIN_score(
        model, ood_dataloader_val_for_train, best_magnitude, best_temperature, std=std
    )
    ind_ensemble_val.append(ind_scores_val_for_train)
    ood_ensemble_val.append(ood_scores_val_for_train)

    # Scores on the held-out test split (used later to compute metrics).
    ind_scores_test = get_ODIN_score(
        model, ind_dataloader_test, best_magnitude, best_temperature, std=std
    )
    ood_scores_test = get_ODIN_score(
        model, ood_dataloader_test, best_magnitude, best_temperature, std=std
    )
    ind_ensemble_test.append(ind_scores_test)
    ood_ensemble_test.append(ood_scores_test)

# Collapse the ensemble dimension: one averaged score per sample, as a column.
ind_val = take_mean_and_reshape(ind_ensemble_val)
ood_val = take_mean_and_reshape(ood_ensemble_val)
ind_test = take_mean_and_reshape(ind_ensemble_test)
ood_test = take_mean_and_reshape(ood_ensemble_test)
# ----- Train OoD detector using validation data -----
from lib.metric import get_metrics, train_lr

# The detector is fitted on the ensemble-averaged validation scores, passed as
# (in-distribution, out-of-distribution).
# NOTE(review): `lr` presumably stands for logistic regression -- confirm
# against train_lr.
_validation_scores = (ind_val, ood_val)
lr = train_lr(*_validation_scores)

# ----- Calculating metrics using test data -----
# The fitted detector is evaluated on the ensemble-averaged test scores with
# acc_type="best", and the result is printed.
_test_scores = (ind_test, ood_test)
metrics = get_metrics(lr, *_test_scores, acc_type="best")
print("metrics: ", metrics)