-
Notifications
You must be signed in to change notification settings - Fork 0
/
search.py
41 lines (34 loc) · 1.76 KB
/
search.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import os
import time
import pickle
import argparse
import numpy as np
from tqdm import tqdm
from yoursearch import searchfunction
# Command-line interface. Building the parser has no side effects; arguments
# are only parsed inside main(), so importing this module does not consume argv.
parser = argparse.ArgumentParser()
parser.description = (
    "Rank gallery embeddings by Euclidean distance to each query embedding "
    "and write the top-K gallery indices per query to a CSV file."
)
parser.add_argument("-q", "--query_emb_path", help="this is the query embeddings path", type=str, default="./test_a/query_emb.npy")
parser.add_argument("-g", "--gallery_emb_path", help="this is the gallery embeddings path", type=str, default="./test_a/gallery_emb.npy")
parser.add_argument("-o", "--output_path", help="this is the output file path", type=str, default="./submissions/output.csv")

K = 10  # number of nearest gallery indices written per query


def brute_force_topk(query_embedding, gallery_embeddings, k=K):
    """Return the indices of the ``k`` gallery rows closest to ``query_embedding``.

    Distance is the Euclidean (L2/Frobenius) norm of the difference between
    the query and each gallery row (a row is ``gallery_embeddings[i]``, of any
    trailing shape). Indices are ordered by increasing distance; rows at equal
    distance keep ascending index order because the sort is stable. If the
    gallery has fewer than ``k`` rows, every index is returned.

    Args:
        query_embedding: one embedding, shaped like a single gallery row.
        gallery_embeddings: array-like whose first axis enumerates gallery rows.
        k: maximum number of indices to return.

    Returns:
        A list of Python ``int`` gallery indices (at most ``k`` of them).
    """
    gallery = np.asarray(gallery_embeddings)
    n_rows = len(gallery)
    if n_rows == 0:
        return []
    # Same operand order as a per-row ``query - row``; broadcasting computes all
    # rows at once. Note: this allocates a temporary the size of the gallery.
    diffs = np.asarray(query_embedding) - gallery
    # Flatten each row so rows of any trailing shape get their Frobenius norm,
    # which is what np.linalg.norm(row) returns for a single row.
    dists = np.linalg.norm(diffs.reshape(n_rows, -1), axis=1)
    # Stable sort => ties are broken by ascending gallery index.
    return np.argsort(dists, kind="stable")[:k].tolist()


def main():
    """Search every query against the gallery, writing results and timings.

    Writes one CSV line per query, ``<query_index>,<idx_1>,...,<idx_K>``, to
    ``--output_path``. The per-query search times (seconds) are pickled to
    ``./submissions/total_time.pkl``.
    """
    args = parser.parse_args()

    # ./submissions is always needed for the timing pickle. The output CSV may
    # live elsewhere, so create its parent directory too. exist_ok avoids the
    # check-then-create race.
    os.makedirs("./submissions", exist_ok=True)
    output_dir = os.path.dirname(args.output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Load inputs before opening (and truncating) the output file, so a bad
    # input path does not clobber a previous result.
    query_embeddings = np.load(args.query_emb_path)
    gallery_embeddings = np.load(args.gallery_emb_path)

    total_time = []
    with open(args.output_path, "w") as csv_writer:
        for ind, query_embedding in enumerate(tqdm(query_embeddings)):
            start_time = time.perf_counter()  # monotonic clock for intervals
            # ⭐ ⭐ ⭐ ⬇ ⬇ ⬇ search part; you can replace here with any method to search ⬇ ⬇ ⬇ ⭐ ⭐ ⭐
            topk_rank_list = brute_force_topk(query_embedding, gallery_embeddings, K)
            # ⭐ ⭐ ⭐ ⬆ ⬆ ⬆ search part; you can replace here with any method to search ⬆ ⬆ ⬆ ⭐ ⭐ ⭐
            total_time.append(time.perf_counter() - start_time)
            csv_writer.write(f"{ind},{','.join(map(str, topk_rank_list))}\n")

    # NOTE(review): the result of this batch call is never used and it is not
    # timed; it appears to be a hook for a user-supplied search — confirm intent.
    searchfunction(query_embeddings, gallery_embeddings, K=K)

    with open("./submissions/total_time.pkl", "wb") as time_file:
        pickle.dump(total_time, time_file)


if __name__ == "__main__":
    main()