-
Notifications
You must be signed in to change notification settings - Fork 3
/
COCO_generate.py
96 lines (79 loc) · 2.51 KB
/
COCO_generate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import numpy as np
import os
from pycocotools.coco import COCO
import os.path
import json
import random
import uuid
def getLabel(label):
    """Render a label vector as a space-terminated string.

    Every element of ``label`` is converted with ``str`` and followed by a
    single space, so a non-empty result always ends with a trailing space
    and an empty iterable yields ``''``.

    Args:
        label: Iterable of values to render (this script passes lists of
            0/1 integers).

    Returns:
        The concatenated string, e.g. ``[1, 0, 1]`` -> ``'1 0 1 '``.
    """
    # One join pass instead of repeated ``+=``. The space after *every*
    # element (including the last) is kept so the emitted line format does
    # not change.
    return ''.join(str(item) + ' ' for item in label)
# Input locations: the two annotation JSON files and the image folders
# that the file names inside them are resolved against.
train_json = '/hhd12306/zhuxiaosu/cocoapi/PythonAPI/instances_train2014.json'
val_json = '/hhd12306/zhuxiaosu/cocoapi/PythonAPI/instances_val2014.json'
train_prefix = '/hhd12306/chendaiyuan/Data/coco/coco_official/train2014/'
val_prefix = '/hhd12306/chendaiyuan/Data/coco/coco_official/val2014/'

# Seed the global RNG from a freshly generated uuid1. The seed is kept in
# ``seed`` but never stored anywhere, so the shuffles later in the script
# cannot be replayed from a previous run.
seed = uuid.uuid1().int
random.seed(seed)

# Load the train annotations first, then the val annotations.
coco = COCO(train_json)
val = COCO(val_json)

# Category ids of the train split; their order defines the label columns
# used for the train images.
cats = coco.getCatIds()
def _add_split(api, cat_ids, prefix, table):
    """Merge one annotation split into ``table`` as multi-hot label rows.

    For every category in ``cat_ids`` the images containing that category
    are looked up, and the column for that category (its position in
    ``cat_ids``) is set to 1 in the image's label vector.

    Args:
        api: Object exposing ``getImgIds(catIds=...)`` and ``loadImgs(ids)``
            (this script passes ``pycocotools`` ``COCO`` instances).
        cat_ids: Category ids; the position in this list is the label column.
        prefix: Directory joined with each record's ``'file_name'``.
        table: Dict mutated in place, mapping image id to
            ``{'path': str, 'label': list}``.

    Raises:
        FileNotFoundError: If an image file referenced by the annotations
            is not present under ``prefix``.
    """
    for column, cat_id in enumerate(cat_ids):
        img_ids = api.getImgIds(catIds=[cat_id])
        # Assumes loadImgs returns records in the same order as the ids it
        # was given, so ids and records can be paired positionally.
        for img_id, record in zip(img_ids, api.loadImgs(img_ids)):
            fname = os.path.join(prefix, record['file_name'])
            if not os.path.isfile(fname):
                # Report the path that is actually missing, not the whole
                # annotation record.
                raise FileNotFoundError('{} not exists'.format(fname))
            if img_id not in table:
                table[img_id] = {'path': fname, 'label': [0] * len(cat_ids)}
            table[img_id]['label'][column] = 1


# Image id -> {'path': absolute file path, 'label': multi-hot category list}.
all_dict = dict()

# Train split: label columns follow the train category order in ``cats``.
_add_split(coco, cats, train_prefix, all_dict)

# Val split: label columns follow the val split's own category list.
# NOTE(review): this presumes val lists the same categories in the same
# order as train; otherwise the columns of the two splits would not line up.
cats = val.getCatIds()
_add_split(val, cats, val_prefix, all_dict)
# Persist the complete image table before splitting it.
with open('/hhd12306/zhuxiaosu/DSQ_NUS/data/coco/all_imgs.json', 'w') as fp:
    json.dump(all_dict, fp)

# Number of shuffled images that go to the query set; the rest form the
# database set.
QUERY_SIZE = 5000

# Shuffle the image ids with the RNG seeded earlier in the script.
all_key = list(all_dict.keys())
random.shuffle(all_key)

# One text line per image: "<path> <label values, space separated>\r\n".
# The '\r\n' terminator is kept exactly as the script has always written it.
all_lines = [
    all_dict[key]['path'] + ' ' + getLabel(all_dict[key]['label']) + '\r\n'
    for key in all_key
]

# Slicing replaces a hand-maintained counter: the first QUERY_SIZE lines
# are queries, everything after is the database (empty if there are not
# more than QUERY_SIZE images).
query_dict = all_lines[:QUERY_SIZE]
database_dict = all_lines[QUERY_SIZE:]

with open('/hhd12306/zhuxiaosu/DSQ_NUS/data/coco/query.txt', 'w') as fp:
    fp.writelines(query_dict)
with open('/hhd12306/zhuxiaosu/DSQ_NUS/data/coco/database.txt', 'w') as fp:
    fp.writelines(database_dict)
# Training subset: reshuffle the database lines in place and keep at most
# the first 10000 of them. Note that ``database_dict`` itself stays in the
# shuffled order afterwards (database.txt has already been written above).
random.shuffle(database_dict)
train_dict = database_dict[:10000]

with open('/hhd12306/zhuxiaosu/DSQ_NUS/data/coco/train.txt', 'w') as fp:
    fp.writelines(train_dict)