forked from facebookresearch/PartDistillation
-
Notifications
You must be signed in to change notification settings - Fork 0
/
continuously_postprocess_dcrf.py
201 lines (160 loc) · 7.97 KB
/
continuously_postprocess_dcrf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
import copy
import argparse
import torch
import numpy as np
import matplotlib.pyplot as plt
import pydensecrf.densecrf as dcrf
import pydensecrf.utils as dcrf_utils
import time
from pycocotools import mask as coco_mask
from detectron2.structures import BoxMode
from detectron2.data import detection_utils as utils
from detectron2.data import transforms as T
from detectron2.structures import BitMasks, Instances
def dense_crf(
    image,
    label,
    n_labels,
    p=0.7,
    t=10,
    sd1=3,
    sd2=20,
    sc=13,
    compat1=3,
    compat2=10
):
    """Refine a per-pixel label map with a fully connected DenseCRF.

    Args:
        image: H x W x 3 image array used by the colour-dependent pairwise term
            (presumably uint8 RGB as produced by the caller's image reader --
            TODO confirm the dtype pydensecrf expects).
        label: torch tensor holding one integer label per pixel (H x W).
        n_labels: number of labels the CRF distinguishes; must be at least the
            number of distinct values present in ``label``.
        p: confidence assigned to the given label of each pixel (``gt_prob``).
        t: number of mean-field inference iterations.
        sd1: spatial standard deviation of the location-only Gaussian kernel.
        sd2: spatial standard deviation of the bilateral kernel.
        sc: per-channel colour standard deviation of the bilateral kernel.
        compat1: compatibility weight of the Gaussian pairwise term.
        compat2: compatibility weight of the bilateral pairwise term.

    Returns:
        H x W int64 numpy array with the refined label of every pixel, expressed
        in the same label values that appear in ``label``.
    """
    annotated_label = label.to(torch.int32).numpy()
    # np.unique renumbers the labels to 0..k-1 (``labels``) and keeps the original
    # values in ``colors``; the CRF works on the compact numbering.
    colors, labels = np.unique(annotated_label, return_inverse=True)
    labels = labels.ravel()
    h, w = image.shape[:2]

    d = dcrf.DenseCRF2D(w, h, n_labels)
    U = dcrf_utils.unary_from_labels(labels, n_labels, gt_prob=p, zero_unsure=False)
    d.setUnaryEnergy(U)

    # This adds the color-independent term, features are the locations only.
    feats = dcrf_utils.create_pairwise_gaussian(sdims=(sd1, sd1), shape=(h, w))
    d.addPairwiseEnergy(feats, compat=compat1, kernel=dcrf.DIAG_KERNEL,
                        normalization=dcrf.NORMALIZE_SYMMETRIC)

    # This adds the color-dependent term, i.e. features are (x,y,r,g,b).
    feats = dcrf_utils.create_pairwise_bilateral(sdims=(sd2, sd2), schan=(sc, sc, sc),
                                                 img=image,
                                                 chdim=2)
    d.addPairwiseEnergy(feats, compat=compat2,
                        kernel=dcrf.DIAG_KERNEL,
                        normalization=dcrf.NORMALIZE_SYMMETRIC)

    Q = d.inference(t)
    compact = np.array(Q).reshape((n_labels, h, w)).argmax(axis=0)

    # Translate the compact CRF indices back to the caller's label values.
    # Without this, an input that lacks some label (e.g. no background pixel)
    # comes back with every id shifted, and the first part is mistaken for
    # background. Indices beyond the labels that were present keep their own
    # value, so the lookup can never go out of range.
    lookup = np.arange(n_labels, dtype=np.int64)
    lookup[:len(colors)] = colors
    return lookup[compact]
def proposals_to_coco_json(binary_mask):
    """Encode a stack of binary part masks as COCO-style RLE annotations.

    Args:
        binary_mask: sequence of 2-D masks (one per part); may be empty.

    Returns:
        list[dict]: one ``{"segmentation": rle}`` entry per mask, where ``rle``
        is the pycocotools run-length encoding with ``"counts"`` stored as a
        ``str``. An empty input yields an empty list.
    """
    if len(binary_mask) == 0:
        return []

    annotations = []
    for part_mask in binary_mask:
        # pycocotools expects a Fortran-ordered uint8 array of shape (H, W, N).
        fortran_mask = np.array(part_mask[:, :, None], order="F", dtype="uint8")
        encoded = coco_mask.encode(fortran_mask)[0]
        # "counts" comes back as bytes, which the json writer cannot serialize;
        # utf-8 decoding matches what pycocotools/_mask.pyx itself does.
        encoded["counts"] = encoded["counts"].decode("utf-8")
        annotations.append({"segmentation": encoded})
    return annotations
def get_argparse():
    """Parse the command-line options of the pseudo-label post-processing script.

    Returns:
        argparse.Namespace: the parsed options (read from ``sys.argv``).
    """
    parser = argparse.ArgumentParser(description='Postprocess pseudo-labels')

    # (flag, add_argument keyword arguments), registered in this order.
    option_table = (
        ('--parallel_job_id', dict(type=int, default=-1)),
        ('--num_parallel_jobs', dict(type=int, default=-1)),
        ('--dataset_name', dict(type=str, default="imagenet_1k_train")),
        ('--mining_metric', dict(type=str, default="iou_based")),
        ('--dist_metric', dict(type=str, default="dot")),
        ('--res', dict(type=str, default="res3_res4")),
        ('--num_k', dict(type=int, default=4)),
        ('--feat_norm', dict(action="store_true", default=False)),
        ('--check_broken_files', dict(action="store_true", default=False)),
        ('--root_folder_name', dict(type=str, default='pseudo_labels')),
        ('--label_mode', dict(type=str, default='max-gt-label')),
        ('--debug', dict(action="store_true")),
    )
    for flag, options in option_table:
        parser.add_argument(flag, **options)

    return parser.parse_args()
# Augmentations applied to every image before DenseCRF. The CRF is run at this
# resized resolution rather than at the original image size (the original note
# gives performance as the reason). Per that note the stored part predictions
# were already resized to 640, so the image is resized the same way here; the
# main loop asserts that image and mask spatial shapes agree.
augs = [T.ResizeScale(min_scale=1.0, max_scale=1.0, target_height=640, target_width=640),
    # T.FixedSizeCrop(crop_size=(640, 640)),
]
def _select_job_classes(code_list, job_id, num_jobs):
    """Return the contiguous slice of ``code_list`` handled by one parallel job.

    Every job gets ``len(code_list) // num_jobs`` classes; the last job also
    takes the classes left over by the integer division.

    Raises:
        ValueError: if ``job_id`` is not in ``[0, num_jobs)``.
    """
    if not 0 <= job_id < num_jobs:
        raise ValueError(
            "--parallel_job_id must be in [0, {}) when --num_parallel_jobs is set (got {})."
            .format(num_jobs, job_id))
    num_classes_per_job = len(code_list) // num_jobs
    start_i = num_classes_per_job * job_id
    if job_id + 1 == num_jobs:
        end_i = len(code_list)
    else:
        end_i = start_i + num_classes_per_job
    return code_list[start_i:end_i]


def _refine_part_masks(image, part_annotations):
    """Run DenseCRF over the decoded part masks and re-encode the surviving parts.

    Args:
        image: resized H x W x 3 image the masks refer to.
        part_annotations: non-empty list of dicts whose "segmentation" entry is
            a COCO RLE of one part mask.

    Returns:
        list[dict]: COCO-style annotations of the parts that still own at least
        one pixel after the CRF.
    """
    bmask = torch.tensor(np.array(
        [coco_mask.decode(segm["segmentation"]) for segm in part_annotations]))
    assert image.shape[:2] == bmask.shape[1:], "tensor shapes do not match. ({} != {})"\
        .format(image.shape[:2], bmask.shape[1:])

    num_c = bmask.shape[0]
    part_ids = (torch.arange(num_c) + 1)[:, None, None]
    # Collapse the stack into one label map (0 = background, i + 1 = part i).
    # Taking the max equals the sum for disjoint masks, but keeps a valid part
    # id where masks overlap instead of adding two ids together.
    cmask = (bmask * part_ids).max(dim=0).values
    cmask = torch.tensor(dense_crf(image, cmask, num_c + 1))

    # Keep only the non-background labels that survived the CRF.
    o_cls = cmask.unique()
    o_cls = o_cls[o_cls != 0]
    refined = cmask[None] == o_cls[:, None, None]
    return proposals_to_coco_json(refined)


if __name__ == "__main__":
    args = get_argparse()

    config_name = "{}_{}_norm_{}".format(args.dist_metric, args.num_k, args.feat_norm)
    source_root = os.path.join(
        f"{args.root_folder_name}/part_labels/proposal_generation/{args.label_mode}/",
        args.dataset_name, "detic", args.res, config_name)
    target_root = os.path.join(
        f"{args.root_folder_name}/part_labels/processed_proposals/{args.label_mode}/",
        args.dataset_name, "detic", args.res, config_name)

    # Partition the list of ImageNet classes for each process. os.listdir returns
    # entries in arbitrary order, so sort first: every job must slice the same
    # ordering for the partitions to be disjoint and complete.
    code_list = sorted(os.listdir(source_root))
    if args.num_parallel_jobs > 0:
        code_list = _select_job_classes(code_list, args.parallel_job_id, args.num_parallel_jobs)

    # Make the output folders.
    for code in code_list:
        os.makedirs(os.path.join(target_root, code), exist_ok=True)

    # Count the total number of files to make.
    num_total = sum(len(os.listdir(os.path.join(source_root, code))) for code in code_list)

    # Pause between passes so that waiting for unreadable (e.g. still being
    # written) source files does not spin on the filesystem.
    retry_interval_sec = 5.0

    t0 = time.time()
    while True:
        count = 0           # source files whose output exists after this pass
        num_unreadable = 0  # source files that could not be loaded in this pass
        for code in code_list:
            for fname in os.listdir(os.path.join(source_root, code)):
                source_path = os.path.join(source_root, code, fname)
                target_path = os.path.join(target_root, code, fname)
                if not os.path.exists(target_path):
                    try:
                        # NOTE: torch.load unpickles the file; only point this
                        # script at output of the trusted labelling pipeline.
                        data = torch.load(source_path, "cpu")
                    except Exception:
                        # Not a bare ``except`` so Ctrl-C / SystemExit still stop
                        # the script. The file is retried on the next pass.
                        print("broken file:", source_path)
                        num_unreadable += 1
                        continue

                    mask = data["part_mask"]
                    # Files without any part proposal are copied through unchanged.
                    if mask is not None and len(mask) > 0:
                        image = utils.read_image(data["file_path"], format="RGB")
                        # Resizing
                        aug_input = T.AugInput(image)
                        aug_input, transforms = T.apply_transform_gens(augs, aug_input)
                        image = aug_input.image

                        refined = _refine_part_masks(image, mask)
                        del data['part_mask']
                        data["part_mask"] = refined

                    if args.debug:
                        assert False, "debug. "
                    torch.save(data, target_path)

                    # ``== 1`` (not ``== 0``) keeps ``count`` non-zero in the division.
                    if count % 100 == 1:
                        print("{} ({:.2f} %) images processed on process {} ({:.2f} / image)"\
                            .format(count, count/num_total*100, args.parallel_job_id, (time.time()-t0)/count), flush=True)
                count += 1

        # Done once every visible source file has an output. ``>=`` rather than
        # ``==``: files added to the source folders after ``num_total`` was
        # computed would otherwise make the loop run forever.
        if num_unreadable == 0 and count >= num_total:
            print("process {} done. (processed {}/{} images)".format(args.parallel_job_id, count, num_total), flush=True)
            break
        time.sleep(retry_interval_sec)