-
Notifications
You must be signed in to change notification settings - Fork 3
/
nus_pick.py
85 lines (61 loc) · 1.58 KB
/
nus_pick.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
# coding=utf-8
from __future__ import absolute_import
from __future__ import print_function, division
import os
from os import path
import numpy as np
import pickle as p
from operator import add
NUM_CLASSES = 21        # number of classes (label columns) that are sampled
QUERY_PER_CLASS = 100   # query images drawn for each class
TRAIN_PER_CLASS = 500   # training images drawn for each class


def _read_lines(filename):
    """Return the non-blank lines of *filename*, each ending in a newline.

    The trailing newline is enforced so that pooling two files (the first of
    which may lack a final newline) and writing them back with ``writelines``
    never fuses two records onto one line.  Blank lines are skipped because
    they carry no labels and would break label indexing.
    """
    lines = []
    with open(filename, 'r') as fp:
        for line in fp:
            if not line.strip():
                continue
            if not line.endswith('\n'):
                line += '\n'
            lines.append(line)
    return lines


def _parse_labels(lines):
    """Return the integer label vector of every line.

    The first whitespace-separated token of each line (presumably the image
    path) is dropped; every remaining token is converted with ``int``.
    """
    return [[int(tok) for tok in line.split()[1:]] for line in lines]


def _pick_queries(labels, order, num_classes, per_class):
    """Pick *per_class* samples for each class by walking *order* once.

    Classes are handled in ascending order.  For each class the walk resumes
    where the previous class stopped and takes the next *per_class* samples
    whose label for that class is 1.  The walk never rewinds, so no sample is
    picked twice, and samples skipped while filling one class are not
    reconsidered for later classes.

    Returns the picked indices (plain ints) in selection order.
    Raises ValueError if *order* is exhausted before a class is filled.
    """
    picked = []
    pos = 0
    for cls in range(num_classes):
        found = 0
        while found < per_class:
            if pos >= len(order):
                raise ValueError(
                    'not enough samples to draw {0} queries for class {1}'
                    .format(per_class, cls))
            idx = int(order[pos])
            pos += 1
            if labels[idx][cls] == 1:
                picked.append(idx)
                found += 1
    return picked


def _pick_train(labels, num_classes, per_class):
    """Pick *per_class* training samples per class from a shrinking pool.

    Classes are handled from the highest index down to 0.  For each class the
    first *per_class* samples (in pool order) labelled 1 for it are taken and
    removed from the pool, so later classes cannot reuse them.  This is done
    in a single pass per class instead of rescanning from the start and
    calling ``list.remove`` for every pick.

    Returns the picked indices into *labels*, in selection order.
    Raises ValueError if a class has fewer than *per_class* samples left
    (the previous implementation looped forever in that case).
    """
    pool = list(range(len(labels)))
    picked = []
    for cls in reversed(range(num_classes)):
        taken, kept = [], []
        for idx in pool:
            if len(taken) < per_class and labels[idx][cls] == 1:
                taken.append(idx)
            else:
                kept.append(idx)
        if len(taken) < per_class:
            raise ValueError(
                'only {0} samples left for class {1}, need {2}'
                .format(len(taken), cls, per_class))
        picked.extend(taken)
        pool = kept
    return picked


def main(src_dir='./nus21', dst_dir='./nus_eccv', seed=None,
         num_classes=NUM_CLASSES, query_per_class=QUERY_PER_CLASS,
         train_per_class=TRAIN_PER_CLASS):
    """Build the query / database / train split in *dst_dir* from *src_dir*.

    ``database.txt`` and ``query.txt`` of *src_dir* are pooled.  A random set
    of *query_per_class* samples per class is written to ``query.txt``; all
    other samples go to ``database.txt``.  ``train.txt`` holds
    *train_per_class* samples per class drawn from those database samples, so
    every training line also appears in ``database.txt``.

    *seed* makes the query draw reproducible; ``None`` gives a fresh,
    unseeded draw on every run.  *dst_dir* is created if missing.  All
    selections are computed before any file is written, so a ValueError
    (too few samples for some class) leaves no partial output behind.
    """
    lines = (_read_lines(path.join(src_dir, 'database.txt')) +
             _read_lines(path.join(src_dir, 'query.txt')))
    labels = _parse_labels(lines)

    order = np.random.RandomState(seed).permutation(len(lines))
    query_idx = _pick_queries(labels, order, num_classes, query_per_class)
    chosen = set(query_idx)
    query = [lines[k] for k in query_idx]

    # Everything not drawn as a query stays in the database, in input order.
    remain, remain_labels = [], []
    for k, (line, lab) in enumerate(zip(lines, labels)):
        if k not in chosen:
            remain.append(line)
            remain_labels.append(lab)

    train_idx = _pick_train(remain_labels, num_classes, train_per_class)
    train = [remain[k] for k in train_idx]

    if not path.isdir(dst_dir):
        os.makedirs(dst_dir)
    outputs = (('query.txt', query),
               ('database.txt', remain),
               ('train.txt', train))
    for name, rows in outputs:
        with open(path.join(dst_dir, name), 'w') as fp:
            fp.writelines(rows)


if __name__ == '__main__':
    main()