-
Notifications
You must be signed in to change notification settings - Fork 32
/
Copy pathretrieval.py
128 lines (100 loc) · 4.55 KB
/
retrieval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import os
import sys
import argparse
import logging
import time as t
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from modules import fpn
from commons import *
from utils import *
from train_picie import *
def initialize_classifier(args, n_query, centroids):
    """Build a frozen 1x1-convolution classifier head over feature maps.

    Args:
        args: namespace providing ``in_dim`` (feature channel count).
        n_query: number of output channels (one per query centroid).
        centroids: optional ``(n_query, in_dim)`` tensor; when given, each
            row becomes one 1x1 convolution filter of the head.

    Returns:
        A DataParallel-wrapped, CUDA-resident ``nn.Conv2d`` with all
        parameters frozen (evaluation-only head).
    """
    head = nn.Conv2d(args.in_dim, n_query, kernel_size=1, stride=1, padding=0, bias=False)
    head = nn.DataParallel(head).cuda()
    if centroids is not None:
        # Reshape (n_query, in_dim) -> (n_query, in_dim, 1, 1) conv weights.
        head.module.weight.data = centroids[..., None, None]
    freeze_all(head)
    return head
def get_testloader(args):
    """Create the evaluation DataLoader over the test split.

    No shuffling is applied so retrieval results are deterministic across
    runs for a fixed dataset ordering.
    """
    eval_set = EvalDataset(
        args.data_root,
        dataset=args.dataset,
        res=args.res1,
        split=args.val_type,
        mode='test',
        stuff=args.stuff,
        thing=args.thing,
    )
    loader = torch.utils.data.DataLoader(
        eval_set,
        batch_size=args.batch_size_eval,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=True,
        collate_fn=collate_eval,
    )
    return loader
def compute_dist(featmap, metric_function, euclidean_train=True):
    """Score every pixel feature against every centroid.

    Args:
        featmap: ``(B, C, H, W)`` feature map; assumed L2-normalized along
            the channel dim when ``euclidean_train`` is True (callers in
            this file normalize before calling).
        metric_function: DataParallel-wrapped 1x1 conv whose weights hold
            the centroids, so applying it yields the dot product f·c.
        euclidean_train: if True, return the negative squared Euclidean
            distance; otherwise the raw inner-product response.

    Returns:
        ``(B, K, H, W)`` score tensor where larger means nearer.
    """
    if not euclidean_train:
        return metric_function(featmap)
    # ||f - c||^2 = ||f||^2 - 2 f·c + ||c||^2 with ||f||^2 == 1 for
    # normalized features; negate so higher scores mean closer centroids.
    weights = metric_function.module.weight.data
    centroid_sq_norms = (weights * weights).sum(dim=1).unsqueeze(0)
    return 2 * metric_function(featmap) - 1 - centroid_sq_norms
def get_nearest_neighbors(n_query, dataloader, model, classifier, k=10):
    """Scan the dataloader and keep k retrieval candidates per query.

    For each batch, the single best-scoring pixel per query channel is
    taken (scores from ``compute_dist`` are negative squared distances, so
    the ``topk(1)`` maximum is the nearest pixel in that batch), and a
    running pool of k candidates per query is maintained across batches.

    Args:
        n_query: number of query centroids (classifier output channels).
        dataloader: yields ``(indices, images, labels)``; its ``dataset``
            must expose ``transform_data`` / ``load_data`` / ``imdb``.
        model: feature extractor producing (B, C, H, W) maps.
        classifier: frozen 1x1-conv head holding the query centroids.
        k: number of images retained per query.

    Returns:
        ``(imglist, loclist)``: per-query lists of k raw images and their
        ``(batch_index, row, col)`` pixel locations.
    """
    model.eval()
    classifier.eval()
    min_dsts = [[] for _ in range(n_query)]  # running scores per query
    min_locs = [[] for _ in range(n_query)]  # (ib, ih, iw) per kept candidate
    min_imgs = [[] for _ in range(n_query)]  # dataset indices of kept images
    with torch.no_grad():
        for indice, image, label in dataloader:
            image = image.cuda(non_blocking=True)
            feats = model(image)
            # Normalize so compute_dist's ||f||^2 == 1 assumption holds.
            feats = F.normalize(feats, dim=1, p=2)
            dists = compute_dist(feats, classifier)  # (B x C x H x W)
            B, _, H, W = dists.shape
            for c in range(n_query):
                # Only ONE candidate per batch per query is considered, even
                # if a batch contains several strong pixels for this query.
                dst, idx = dists[:, c].flatten().topk(1)
                idx = idx.item()
                # Unravel the flat argmax into (batch, row, col).
                ib = idx//(H*W)
                ih = idx%(H*W)//W
                iw = idx%(H*W)%W
                if len(min_dsts[c]) < k:
                    min_dsts[c].append(dst)
                    min_locs[c].append((ib, ih, iw))
                    min_imgs[c].append(indice[ib])
                elif dst < max(min_dsts[c]):
                    # NOTE(review): scores are negative distances (higher =
                    # nearer), yet this pool keeps the k SMALLEST scores by
                    # evicting the current maximum. A true nearest-k pool
                    # would keep the k largest (evict the minimum). Confirm
                    # whether this orientation is intended.
                    imax = np.argmax(min_dsts[c])
                    # Drop the evicted entry from all three parallel lists.
                    min_dsts[c] = min_dsts[c][:imax] + min_dsts[c][imax+1:]
                    min_locs[c] = min_locs[c][:imax] + min_locs[c][imax+1:]
                    min_imgs[c] = min_imgs[c][:imax] + min_imgs[c][imax+1:]
                    min_dsts[c].append(dst)
                    min_locs[c].append((ib, ih, iw))
                    min_imgs[c].append(indice[ib])
    loclist = min_locs
    dataset = dataloader.dataset
    # Re-load the raw (untransformed-for-training) images for visualization.
    imglist = [[dataset.transform_data(*dataset.load_data(dataset.imdb[i]), i, raw_image=True) for i in ids] for ids in min_imgs]
    return imglist, loclist
if __name__ == '__main__':
    args = parse_arguments()
    # Use random seed.
    fix_seed_for_reproducability(args.seed)
    # Init model.
    model = fpn.PanopticFPN(args)
    model = nn.DataParallel(model)
    model = model.cuda()
    # Load weights.
    checkpoint = torch.load(args.eval_path)
    model.load_state_dict(checkpoint['state_dict'])
    # Init classifier (for eval only.)
    # NOTE(review): query centroids are loaded from a hard-coded relative
    # path 'querys.npy' — the file must exist in the working directory.
    queries = torch.tensor(np.load('querys.npy')).cuda()
    classifier = initialize_classifier(args, queries.size(0), queries)
    # Prepare dataloader.
    dataset = get_dataset(args, mode='eval_test')
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=args.batch_size_test,
                                             shuffle=False,
                                             num_workers=args.num_workers,
                                             pin_memory=True,
                                             collate_fn=collate_eval,
                                             worker_init_fn=worker_init_fn(args.seed))
    # Retrieve 10-nearest neighbors.
    imglist, loclist = get_nearest_neighbors(queries.size(0), dataloader, model, classifier, k=args.K_test)
    # Save the result.
    # Despite the .pkl extension, this is a torch serialization archive;
    # reload it with torch.load, not pickle.load.
    torch.save([imglist, loclist], args.save_root + '/picie_retrieval_result_coco.pkl')
    print('-Done.', flush=True)