forked from Scalsol/mega.pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
rpn.py
262 lines (216 loc) · 9.71 KB
/
rpn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import torch
import torch.nn.functional as F
from torch import nn
from mega_core.modeling import registry
from mega_core.modeling.box_coder import BoxCoder
from mega_core.modeling.rpn.retinanet.retinanet import build_retinanet
from .loss import make_rpn_loss_evaluator
from .anchor_generator import make_anchor_generator
from .inference import make_rpn_postprocessor
class RPNHeadConvRegressor(nn.Module):
    """
    Prediction-only RPN head: two independent 1x1 convolutions applied
    directly to every incoming feature map, one emitting classification
    logits and one emitting box regression values.
    """

    def __init__(self, cfg, in_channels, num_anchors):
        """
        Arguments:
            cfg              : config (part of the common head signature;
                               not read by this class)
            in_channels (int): number of channels of the input feature
            num_anchors (int): number of anchors to be predicted; the
                               classifier emits ``num_anchors`` channels and
                               the regressor ``num_anchors * 4``
        """
        super().__init__()
        self.cls_logits = nn.Conv2d(in_channels, num_anchors, kernel_size=1, stride=1)
        self.bbox_pred = nn.Conv2d(in_channels, num_anchors * 4, kernel_size=1, stride=1)

        # Gaussian weights (std 0.01) and zero biases, classifier first and
        # regressor second so random-number consumption stays in that order.
        for conv in (self.cls_logits, self.bbox_pred):
            nn.init.normal_(conv.weight, std=0.01)
            nn.init.constant_(conv.bias, 0)

    def forward(self, x):
        """
        Arguments:
            x (list or tuple of Tensor): one feature map per level

        Returns:
            (logits, bbox_reg): two lists with one tensor per input feature
            map, in input order.
        """
        assert isinstance(x, (list, tuple))
        # All classification outputs are produced first, then all regressions.
        logits = list(map(self.cls_logits, x))
        bbox_reg = list(map(self.bbox_pred, x))
        return logits, bbox_reg
class RPNHeadFeatureSingleConv(nn.Module):
    """
    Feature-extraction stage of an RPN head: a single 3x3 convolution
    (stride 1, padding 1, channel count unchanged) followed by ReLU, applied
    to every feature map it is given.
    """

    def __init__(self, cfg, in_channels):
        """
        Arguments:
            cfg              : config (part of the common head signature;
                               not read by this class)
            in_channels (int): number of channels of the input feature
        """
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels, in_channels, kernel_size=3, stride=1, padding=1
        )
        # Gaussian weights (std 0.01), zero bias.
        nn.init.normal_(self.conv.weight, std=0.01)
        nn.init.constant_(self.conv.bias, 0)

        # Exposed so that downstream modules can size their inputs.
        self.out_channels = in_channels

    def forward(self, x):
        """
        Arguments:
            x (list or tuple of Tensor): one feature map per level

        Returns:
            list of Tensor: ``relu(conv(level))`` for each input, in order.
        """
        assert isinstance(x, (list, tuple))
        return [F.relu(convolved) for convolved in map(self.conv, x)]
@registry.RPN_HEADS.register("SingleConvRPNHead")
class RPNHead(nn.Module):
    """
    RPN head made of a shared 3x3 convolution + ReLU whose output feeds two
    1x1 convolutions: one for classification logits and one for box
    regression. Registered in ``registry.RPN_HEADS`` under the name
    "SingleConvRPNHead".
    """

    def __init__(self, cfg, in_channels, num_anchors):
        """
        Arguments:
            cfg              : config (part of the common head signature;
                               not read by this class)
            in_channels (int): number of channels of the input feature
            num_anchors (int): number of anchors to be predicted; the
                               classifier emits ``num_anchors`` channels and
                               the regressor ``num_anchors * 4``
        """
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels, in_channels, kernel_size=3, stride=1, padding=1
        )
        self.cls_logits = nn.Conv2d(in_channels, num_anchors, kernel_size=1, stride=1)
        self.bbox_pred = nn.Conv2d(in_channels, num_anchors * 4, kernel_size=1, stride=1)

        # Gaussian weights (std 0.01) and zero biases for all three layers,
        # initialised in declaration order.
        for layer in (self.conv, self.cls_logits, self.bbox_pred):
            nn.init.normal_(layer.weight, std=0.01)
            nn.init.constant_(layer.bias, 0)

    def forward(self, x):
        """
        Arguments:
            x (iterable of Tensor): one feature map per level

        Returns:
            (logits, bbox_reg): two lists with one tensor per input feature
            map, in input order.
        """
        # Each level goes through the shared conv once; both predictors then
        # consume that same intermediate tensor.
        per_level = []
        for feature_map in x:
            hidden = F.relu(self.conv(feature_map))
            per_level.append((self.cls_logits(hidden), self.bbox_pred(hidden)))

        logits = [cls_out for cls_out, _ in per_level]
        bbox_reg = [reg_out for _, reg_out in per_level]
        return logits, bbox_reg
class RPNModule(torch.nn.Module):
    """
    Module for RPN computation. Takes feature maps from the backbone and outputs
    RPN proposals and losses. Works for both FPN and non-FPN.
    """

    def __init__(self, cfg, in_channels):
        """
        Arguments:
            cfg              : config; a clone is kept on ``self.cfg`` while the
                               original object is handed to the factory helpers
            in_channels (int): number of channels of the input feature
        """
        super().__init__()
        self.cfg = cfg.clone()

        anchor_gen = make_anchor_generator(cfg)
        # The head class is looked up by name in the RPN head registry.
        head_factory = registry.RPN_HEADS[cfg.MODEL.RPN.RPN_HEAD]

        self.anchor_generator = anchor_gen
        # The head is sized from the first entry of the per-location anchor counts.
        self.head = head_factory(
            cfg, in_channels, anchor_gen.num_anchors_per_location()[0]
        )

        # One box coder instance is shared by both post-processors and the loss.
        coder = BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))
        self.box_selector_train = make_rpn_postprocessor(cfg, coder, is_train=True)
        self.box_selector_test = make_rpn_postprocessor(cfg, coder, is_train=False)
        self.loss_evaluator = make_rpn_loss_evaluator(cfg, coder)

    def forward(self, images, features, targets=None):
        """
        Arguments:
            images (ImageList): images for which we want to compute the predictions
            features (list[Tensor]): features computed from the images that are
                used for computing the predictions. Each tensor in the list
                corresponds to a different feature level
            targets (list[BoxList]): ground-truth boxes present in the image (optional)

        Returns:
            boxes (list[BoxList]): the predicted boxes from the RPN, one BoxList per
                image.
            losses (dict[Tensor]): the losses for the model during training. During
                testing, it is an empty dict.
        """
        objectness, rpn_box_regression = self.head(features)
        anchors = self.anchor_generator(images, features)

        if not self.training:
            return self._forward_test(anchors, objectness, rpn_box_regression)
        return self._forward_train(anchors, objectness, rpn_box_regression, targets)

    def _forward_train(self, anchors, objectness, rpn_box_regression, targets):
        """Training path: returns ``(boxes, losses)`` with both RPN loss terms."""
        if not self.cfg.MODEL.RPN_ONLY:
            # End-to-end training: decode anchors into proposals and sample
            # them into a training batch, without tracking gradients.
            with torch.no_grad():
                boxes = self.box_selector_train(
                    anchors, objectness, rpn_box_regression, targets
                )
        else:
            # RPN-only training: the loss depends only on the raw objectness
            # and regression outputs, so decoding is skipped and the anchors
            # themselves are handed back.
            boxes = anchors

        loss_objectness, loss_rpn_box_reg = self.loss_evaluator(
            anchors, objectness, rpn_box_regression, targets
        )
        return boxes, {
            "loss_objectness": loss_objectness,
            "loss_rpn_box_reg": loss_rpn_box_reg,
        }

    def _forward_test(self, anchors, objectness, rpn_box_regression):
        """Inference path: returns ``(boxes, {})``."""
        proposals = self.box_selector_test(anchors, objectness, rpn_box_regression)
        if not self.cfg.MODEL.RPN_ONLY:
            # In an end-to-end model the proposals are only an intermediate
            # result, so they are passed on without re-ordering.
            return proposals, {}

        # For an RPN-only model the proposals are the final output: return
        # each image's boxes from highest to lowest objectness.
        ranked = []
        for per_image in proposals:
            _, order = per_image.get_field("objectness").sort(descending=True)
            ranked.append(per_image[order])
        return ranked, {}
class RPNWithRefModule(RPNModule):
    """
    RPN module with an additional proposal path for reference frames.

    Key frames (``version="key"``) go through the regular RPNModule train/test
    paths. Any other ``version`` value is handled by a dedicated post-processor
    (``box_selector_ref``) that only produces proposals, with no losses.
    """

    def __init__(self, cfg, in_channels):
        """
        Arguments:
            cfg              : config
            in_channels (int): number of channels of the input feature
        """
        super().__init__(cfg, in_channels)
        # The reference-frame selector gets its own box coder and is always
        # built with is_train=True and version="ref".
        self.box_selector_ref = make_rpn_postprocessor(
            cfg,
            BoxCoder(weights=(1.0, 1.0, 1.0, 1.0)),
            is_train=True,
            version="ref",
        )

    def forward(self, images, features, targets=None, version="key"):
        """
        Arguments:
            images (ImageList): images for which we want to compute the predictions
            features (list[Tensor]): features computed from the images that are
                used for computing the predictions. Each tensor in the list
                corresponds to a different feature level
            targets (list[BoxList]): ground-truth boxes present in the image (optional)
            version (str): "key" for the key frame; any other value selects the
                reference-frame path

        Returns:
            For ``version == "key"``: ``(boxes, losses)``, where boxes is the
                list of predicted boxes from the RPN (one BoxList per image)
                and losses is the loss dict during training or an empty dict
                during testing.
            Otherwise: only the boxes produced by the reference-frame
                selector (no losses dict).
        """
        objectness, rpn_box_regression = self.head(features)
        anchors = self.anchor_generator(images, features)

        if version != "key":
            return self._forward_ref(anchors, objectness, rpn_box_regression)
        if self.training:
            return self._forward_train(anchors, objectness, rpn_box_regression, targets)
        return self._forward_test(anchors, objectness, rpn_box_regression)

    def _forward_ref(self, anchors, objectness, rpn_box_regression):
        """Reference-frame path: select proposals without tracking gradients."""
        with torch.no_grad():
            return self.box_selector_ref(anchors, objectness, rpn_box_regression)
def build_rpn(cfg, in_channels):
    """
    Build the proposal module selected by the config.

    Selection order:
      1. ``cfg.MODEL.RETINANET_ON``: the result of ``build_retinanet``.
      2. ``cfg.MODEL.VID.ENABLE``: chosen by ``cfg.MODEL.VID.METHOD``, with
         "base", "fgfa" and "dff" mapping to ``RPNModule`` and "rdn" and
         "mega" mapping to ``RPNWithRefModule``. Any other method name
         raises ``KeyError``.
      3. Otherwise: a plain ``RPNModule``.

    Arguments:
        cfg              : config
        in_channels (int): number of channels of the input feature
    """
    if cfg.MODEL.RETINANET_ON:
        return build_retinanet(cfg, in_channels)

    if not cfg.MODEL.VID.ENABLE:
        return RPNModule(cfg, in_channels)

    # Video-detection methods and the RPN variant each one uses.
    rpn_by_method = {
        "base": RPNModule,
        "fgfa": RPNModule,
        "dff": RPNModule,
        "rdn": RPNWithRefModule,
        "mega": RPNWithRefModule,
    }
    rpn_cls = rpn_by_method[cfg.MODEL.VID.METHOD]
    return rpn_cls(cfg, in_channels)