-
Notifications
You must be signed in to change notification settings - Fork 323
/
topformer.py
505 lines (427 loc) · 17.8 KB
/
topformer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
# Copyright (c) 2021 PPViT Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
TopFormer in Paddle
A Paddle Implementation of Token Pyramid Transformer (TopFormer) as described in:
"TopFormer: Token Pyramid Transformer for Mobile Semantic Segmentation"
- Paper Link: https://arxiv.org/pdf/2204.05525.pdf
Note: This implementation only contains the image classification model.
"""
import paddle
import paddle.nn as nn
from droppath import DropPath
def _make_divisible(v, divisor, min_value=None):
"""
This function is taken from the original tf repo.
It ensures that all layers have a channel number that is divisible by 8
It can be seen here:
https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
:param v:
:param divisor:
:param min_value:
:return:
"""
if min_value is None:
min_value = divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
# Make sure that round down does not go down by more than 10%.
if new_v < 0.9 * v:
new_v += divisor
return new_v
class Identity(nn.Layer):
    """Pass-through layer.

    ``forward`` hands its argument back untouched, so this layer can stand in
    for an optional sublayer and spare the caller an ``if`` inside ``forward``.
    """

    def forward(self, inputs):
        """Return ``inputs`` unchanged."""
        return inputs
class Mlp(nn.Layer):
    """Convolutional feed-forward block.

    Pipeline: 1x1 ConvNormAct (expand) -> 3x3 depthwise Conv2D -> ReLU6 ->
    dropout -> 1x1 ConvNormAct (project back) -> dropout.

    Args:
        embed_dim: channel count of both the input and the output.
        mlp_ratio: the hidden width is ``int(embed_dim * mlp_ratio)``.
        dropout: probability handed to the single ``nn.Dropout`` layer that is
            applied twice in ``forward``.
    """

    def __init__(self, embed_dim, mlp_ratio, dropout=0.):
        super().__init__()
        hidden_dim = int(embed_dim * mlp_ratio)
        self.fc1 = ConvNormAct(embed_dim, hidden_dim, kernel_size=1, act=None)
        # Depthwise 3x3 convolution: one group per hidden channel.
        self.dwconv = nn.Conv2D(hidden_dim,
                                hidden_dim,
                                kernel_size=3,
                                stride=1,
                                padding=1,
                                groups=hidden_dim)
        self.fc2 = ConvNormAct(hidden_dim, embed_dim, kernel_size=1, act=None)
        self.act = nn.ReLU6()
        self.dropout = nn.Dropout(dropout)

    def _init_weights_linear(self):
        """Return ``(weight_attr, bias_attr)``: TruncatedNormal(std=0.02) and zeros.

        Not called anywhere inside this class.
        """
        initializers = paddle.nn.initializer
        weight_attr = paddle.ParamAttr(initializer=initializers.TruncatedNormal(std=.02))
        bias_attr = paddle.ParamAttr(initializer=initializers.Constant(0.0))
        return weight_attr, bias_attr

    def forward(self, x):
        """Apply the feed-forward pipeline to ``x`` and return the result."""
        hidden = self.act(self.dwconv(self.fc1(x)))
        hidden = self.dropout(hidden)
        return self.dropout(self.fc2(hidden))
class Attention(nn.Layer):
    """Multi-head self-attention applied directly to a 4-D feature map.

    Queries, keys and values come from 1x1 ConvNormAct projections. The
    spatial positions (H*W) act as the token axis.

    Args:
        embed_dim: channel count of the input and of the output.
        key_dim: per-head channel count of queries and keys.
        num_heads: number of attention heads.
        attn_ratio: per-head value width is ``int(key_dim * attn_ratio)``.
    """

    def __init__(self,
                 embed_dim,
                 key_dim,
                 num_heads,
                 attn_ratio=2):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.attn_head_size = key_dim
        self.all_head_size = self.attn_head_size * num_heads
        # Total value width across all heads.
        self.dh = int(self.attn_head_size * attn_ratio) * num_heads

        self.q = ConvNormAct(embed_dim, self.all_head_size, kernel_size=1, act=None)
        self.k = ConvNormAct(embed_dim, self.all_head_size, kernel_size=1, act=None)
        self.v = ConvNormAct(embed_dim, self.dh, kernel_size=1, act=None)
        # Stored but not applied in forward().
        self.scales = self.attn_head_size ** -0.5

        self.proj = nn.Sequential(
            nn.ReLU6(),
            ConvNormAct(self.dh, self.embed_dim, kernel_size=1, act=None))
        self.softmax = nn.Softmax(-1)

    def transpose_multihead(self, x):
        """Reshape [N, C, H, W] into [N, num_heads, H*W, C // num_heads]."""
        batch, _, height, width = x.shape
        per_head = x.reshape([batch, self.num_heads, -1, height, width])
        tokens = per_head.flatten(-2)  # [N, num_heads, head_size, H*W]
        return tokens.transpose([0, 1, 3, 2])  # [N, num_heads, H*W, head_size]

    def forward(self, x):
        """Return the attended and projected map, shape [N, embed_dim, H, W]."""
        batch, _, height, width = x.shape

        query = self.transpose_multihead(self.q(x))
        key = self.transpose_multihead(self.k(x))
        value = self.transpose_multihead(self.v(x))

        # Unscaled dot-product scores, normalised over the key axis.
        weights = self.softmax(paddle.matmul(query, key, transpose_y=True))

        context = paddle.matmul(weights, value)  # [N, num_heads, H*W, dh // num_heads]
        context = context.transpose([0, 1, 3, 2])
        context = context.reshape([batch, self.dh, height, width])
        return self.proj(context)
class EncoderLayer(nn.Layer):
    """One transformer block: attention and MLP, each inside a residual branch.

    Args:
        embed_dim: channel count of the input and of the output.
        key_dim: per-head query/key width forwarded to ``Attention``.
        num_heads: number of attention heads.
        mlp_ratio: hidden-width multiplier forwarded to ``Mlp``.
        attn_ratio: value-width multiplier forwarded to ``Attention``.
        dropout: dropout probability forwarded to ``Mlp``.
        attention_dropout: accepted but not used by this layer.
        droppath: DropPath rate; a value that is not > 0 selects ``Identity``.
    """

    def __init__(self,
                 embed_dim,
                 key_dim,
                 num_heads=8,
                 mlp_ratio=2.0,
                 attn_ratio=2.0,
                 dropout=0.,
                 attention_dropout=0.,
                 droppath=0.):
        super().__init__()
        self.attn = Attention(embed_dim, key_dim, num_heads, attn_ratio)
        if droppath > 0.:
            self.drop_path = DropPath(droppath)
        else:
            self.drop_path = Identity()
        self.mlp = Mlp(embed_dim, mlp_ratio, dropout)

    def forward(self, x):
        """Return ``x`` after the attention and MLP residual branches."""
        # Both residual branches share the same drop_path sublayer.
        x = x + self.drop_path(self.attn(x))
        return self.drop_path(self.mlp(x)) + x
class Transformer(nn.Layer):
    """Stack of ``depth`` EncoderLayer blocks applied in sequence.

    Args:
        embed_dim: channel count of the feature map passed through the stack.
        key_dim: per-head query/key width.
        num_heads: number of attention heads in every layer.
        depth: number of encoder layers.
        qkv_bias: accepted for signature compatibility; not used by this class.
        mlp_ratio: hidden-width multiplier of each layer's MLP.
        attn_ratio: value-width multiplier of each layer's attention.
        dropout: dropout probability forwarded to every layer.
        attention_dropout: forwarded to every layer.
        droppath: maximum DropPath rate. Layer ``i`` receives the ``i``-th
            value of ``linspace(0, droppath, depth)``, so the rate grows
            linearly from 0 at the first layer to ``droppath`` at the last.
    """

    def __init__(self,
                 embed_dim,
                 key_dim,
                 num_heads,
                 depth,
                 qkv_bias=True,
                 mlp_ratio=2.0,
                 attn_ratio=2.0,
                 dropout=0.,
                 attention_dropout=0.,
                 droppath=0.):
        super().__init__()
        # Per-layer DropPath rates. The earlier code computed this list but
        # then handed the full `droppath` to every layer.
        depth_decay = [rate.item() for rate in paddle.linspace(0, droppath, depth)]
        self.layers = nn.LayerList([
            EncoderLayer(embed_dim,
                         key_dim,
                         num_heads,
                         mlp_ratio,
                         attn_ratio,
                         dropout,
                         attention_dropout,
                         layer_droppath)
            for layer_droppath in depth_decay
        ])

    def forward(self, x):
        """Run ``x`` through every encoder layer in order and return the result."""
        for layer in self.layers:
            x = layer(x)
        return x
class ConvNormAct(nn.Layer):
    """Conv2D followed by an optional norm layer and an optional activation.

    Args:
        in_channels: input channel count of the convolution.
        out_channels: output channel count of the convolution.
        kernel_size: convolution kernel size.
        stride: convolution stride.
        padding: convolution padding.
        bias_attr: ``bias_attr`` passed to ``nn.Conv2D``.
        groups: convolution group count.
        act: activation layer instance, or None for no activation. The default
            instance is created once at definition time and is therefore
            shared by every ConvNormAct that relies on the default.
        norm: callable building the norm layer from ``out_channels``, or None
            for no normalisation.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 stride=1,
                 padding=0,
                 bias_attr=False,
                 groups=1,
                 act=nn.ReLU(),
                 norm=nn.BatchNorm2D):
        super().__init__()
        kaiming_weights = paddle.ParamAttr(initializer=nn.initializer.KaimingUniform())
        self.conv = nn.Conv2D(in_channels=in_channels,
                              out_channels=out_channels,
                              kernel_size=kernel_size,
                              stride=stride,
                              padding=padding,
                              groups=groups,
                              weight_attr=kaiming_weights,
                              bias_attr=bias_attr)
        self.norm = norm(out_channels) if norm is not None else Identity()
        self.act = act if act is not None else Identity()

    def forward(self, inputs):
        """Return ``act(norm(conv(inputs)))``."""
        return self.act(self.norm(self.conv(inputs)))
class MobileV2Block(nn.Layer):
    """MobileNetV2-style inverted residual block.

    Layout: optional 1x1 expansion (skipped when ``expansion == 1``) ->
    depthwise ConvNormAct -> linear 1x1 Conv2D + BatchNorm2D. A skip connection
    is added only when ``stride == 1`` and ``inp == oup``.

    Args:
        inp: input channel count.
        oup: output channel count.
        kernel_size: kernel size of the depthwise convolution.
        stride: stride of the depthwise convolution; must be 1 or 2.
        expansion: the hidden width is ``int(round(inp * expansion))``.
    """

    def __init__(self, inp, oup, kernel_size=3, stride=1, expansion=4):
        super().__init__()
        self.stride = stride
        assert stride in [1, 2]

        hidden_dim = int(round(inp * expansion))
        self.use_res_connect = self.stride == 1 and inp == oup

        # Pointwise expansion, only when the block actually widens the input.
        expand = [ConvNormAct(inp, hidden_dim, kernel_size=1)] if expansion != 1 else []
        depthwise = ConvNormAct(hidden_dim,
                                hidden_dim,
                                kernel_size=kernel_size,
                                stride=stride,
                                groups=hidden_dim,
                                padding=kernel_size//2)
        # Linear pointwise projection: no activation after the norm.
        project = [
            nn.Conv2D(hidden_dim, oup, 1, 1, 0, bias_attr=False),
            nn.BatchNorm2D(oup),
        ]
        self.conv = nn.Sequential(*expand, depthwise, *project)
        self.out_channels = oup

    def forward(self, x):
        """Return the block output, adding ``x`` when the skip connection is active."""
        out = self.conv(x)
        return x + out if self.use_res_connect else out
class TokenPyramidModule(nn.Layer):
    """Stem convolution followed by a chain of MobileV2Block layers.

    Args:
        cfgs: iterable of ``(k, t, c, s)`` tuples, one per block: kernel size,
            expansion ratio, output channels (before ``width_mult``), stride.
        out_indices: indices of the blocks whose outputs ``forward`` collects.
        input_channel: output channel count of the stem convolution.
        width_mult: multiplier applied to every ``c`` before it is rounded to
            a multiple of 8 by ``_make_divisible``.
    """

    def __init__(self,
                 cfgs,
                 out_indices,
                 input_channel=16,
                 width_mult=1.):
        super().__init__()
        self.out_indices = out_indices
        self.stem = ConvNormAct(3, input_channel, kernel_size=3, stride=2, padding=1)
        self.layers = nn.LayerList()
        for k, t, c, s in cfgs:
            output_channel = _make_divisible(c * width_mult, 8)
            # MobileV2Block derives its hidden width from `expansion`, so no
            # separate expanded size is computed here.
            self.layers.append(MobileV2Block(input_channel,
                                             output_channel,
                                             kernel_size=k,
                                             stride=s,
                                             expansion=t))
            input_channel = output_channel

    def forward(self, x):
        """Return the list of block outputs whose index is in ``out_indices``."""
        x = self.stem(x)
        outs = []
        for i, layer in enumerate(self.layers):
            x = layer(x)
            if i in self.out_indices:
                outs.append(x)
        return outs
class PyramidPoolAgg(nn.Layer):
    """Pool several feature maps to one common size and concatenate them.

    Args:
        stride: downsampling factor applied to the spatial size of the last
            input map to obtain the common target size.
    """

    def __init__(self, stride):
        super().__init__()
        self.stride = stride

    def forward(self, inputs):
        """Return the channel-wise concatenation of the pooled ``inputs``.

        The target size is ``(dim - 1) // stride + 1`` for the height and
        width of ``inputs[-1]``.
        """
        _, _, height, width = inputs[-1].shape
        target_size = ((height - 1) // self.stride + 1,
                       (width - 1) // self.stride + 1)
        pooled = [nn.functional.adaptive_avg_pool2d(feature, target_size)
                  for feature in inputs]
        return paddle.concat(pooled, axis=1)
class InjectionMultiSum(nn.Layer):
    """Fuse local and global features as ``local * gate + global``.

    ``gate`` is the hard-sigmoid of a 1x1 projection of the global input.
    Both the gate and the projected global feature are bilinearly resized to
    the spatial size of the local input before fusion.

    Args:
        in_channels: channel count expected by all three 1x1 projections.
        out_channels: channel count of the fused output.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.local_embedding = ConvNormAct(in_channels, out_channels, kernel_size=1, act=None)
        self.global_embedding = ConvNormAct(in_channels, out_channels, kernel_size=1, act=None)
        self.global_act = ConvNormAct(in_channels, out_channels, kernel_size=1, act=None)
        self.act = nn.Hardsigmoid()

    def forward(self, x_local, x_global):
        """Return the fused feature map at the spatial size of ``x_local``.

        Args:
            x_local: local feature map; its last two dims set the output size.
            x_global: global feature map that is resized to match ``x_local``.
        """
        _, _, H, W = x_local.shape
        local_feature = self.local_embedding(x_local)

        # Gate computed from the global branch, resized to the local resolution.
        global_act = self.act(self.global_act(x_global))
        global_act = nn.functional.interpolate(
            global_act, size=(H, W), mode='bilinear', align_corners=False)

        global_feature = self.global_embedding(x_global)
        global_feature = nn.functional.interpolate(
            global_feature, size=(H, W), mode='bilinear', align_corners=False)

        # The earlier code used an undefined name `sigmoid_act` here; the gate
        # is `global_act`.
        return local_feature * global_act + global_feature
class InjectionMultiSumCBR(InjectionMultiSum):
    """InjectionMultiSum variant with different 1x1 projections.

    After the parent constructor runs, the three projections are replaced:
    the local and global embeddings use ConvNormAct's default norm and
    activation, and the gate projection has neither norm nor activation.

    Args:
        in_channels: channel count expected by the projections.
        out_channels: channel count of the fused output.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__(in_channels, out_channels)

        def conv1x1(**overrides):
            # All three replacements share the same channel and kernel setup.
            return ConvNormAct(in_channels, out_channels, kernel_size=1, **overrides)

        self.local_embedding = conv1x1()
        self.global_embedding = conv1x1()
        self.global_act = conv1x1(act=None, norm=None)
class FuseBlockSum(nn.Layer):
    """Fuse local and global features by element-wise addition.

    Args:
        in_channels: channel count expected by both 1x1 projections.
        out_channels: channel count of the fused output.
        act: activation applied to the projected global feature before it is
            resized; None disables it. The default instance is created once
            at definition time.
    """

    def __init__(self, in_channels, out_channels, act=nn.ReLU6()):
        super().__init__()
        self.local_embedding = ConvNormAct(in_channels,
                                           out_channels,
                                           kernel_size=1,
                                           act=None)
        self.global_embedding = ConvNormAct(in_channels,
                                            out_channels,
                                            kernel_size=1,
                                            act=None)
        self.act = Identity() if act is None else act

    def forward_features(self, x_local, x_global):
        """Project both inputs and resize the global one to the local size.

        Returns:
            ``(local_feature, global_feature)``, where the global feature has
            been activated and bilinearly resized to the last two dims of
            ``x_local``.
        """
        _, _, H, W = x_local.shape
        local_feature = self.local_embedding(x_local)
        global_feature = self.act(self.global_embedding(x_global))
        global_feature = nn.functional.interpolate(
            global_feature, size=(H, W), mode='bilinear', align_corners=False)
        return local_feature, global_feature

    def forward(self, x_local, x_global):
        """Return ``local_feature + global_feature``."""
        # The earlier code bound plural names and then read singular ones,
        # which raised NameError.
        local_feature, global_feature = self.forward_features(x_local, x_global)
        return local_feature + global_feature
class FuseBlockMulti(FuseBlockSum):
    """Fuse local and global features by element-wise multiplication.

    Reuses the projections and ``forward_features`` of ``FuseBlockSum``; only
    the default activation (hard-sigmoid) and the combining operation differ.
    The earlier code derived from ``nn.Layer``, whose constructor does not
    take these arguments and which has no ``forward_features``.

    Args:
        in_channels: channel count expected by both 1x1 projections.
        out_channels: channel count of the fused output.
        act: activation applied to the projected global feature; None
            disables it. The default instance is created once at definition
            time.
    """

    def __init__(self, in_channels, out_channels, act=nn.Hardsigmoid()):
        super().__init__(in_channels, out_channels, act)

    def forward(self, x_local, x_global):
        """Return ``local_feature * global_feature``."""
        local_feature, global_feature = self.forward_features(x_local, x_global)
        return local_feature * global_feature
class Topformer(nn.Layer):
    """TopFormer image-classification model.

    Data flow: TokenPyramidModule -> PyramidPoolAgg -> Transformer ->
    (optional) semantic-injection blocks -> global average pool ->
    BatchNorm1D + Linear head.

    Args:
        cfgs: per-block ``(k, t, c, s)`` settings for TokenPyramidModule.
        channels: per-scale channel counts; their sum is the transformer
            width, and the transformer output is split along axis 1 using
            this list.
        out_channels: output channel count of each injection block.
        embed_out_indice: indices of pyramid blocks whose outputs are kept.
        decode_out_indices: scale indices for which an injection block is
            built and applied.
        depth: number of transformer layers.
        key_dim: per-head query/key width.
        num_heads: number of attention heads.
        attn_ratio: value-width multiplier of the attention.
        mlp_ratio: hidden-width multiplier of the MLP.
        c2t_stride: stride used by PyramidPoolAgg.
        droppath: DropPath rate forwarded to the transformer.
        injection_type: name of the fusion block. Accepted values are
            "fuse_sum", "fuse_multi", "multi_sum" / "muli_sum", and
            "multi_sum_cbr" / "muli_sum_cbr" / "multi_sim_cbr".
        injection: whether injection blocks are built and applied.
        num_classes: output width of the classification head.

    Raises:
        KeyError: if ``injection_type`` is not one of the accepted names.
    """

    def __init__(self,
                 cfgs,
                 channels,
                 out_channels,
                 embed_out_indice,
                 decode_out_indices=[1, 2, 3],
                 depth=4,
                 key_dim=16,
                 num_heads=8,
                 attn_ratio=2,
                 mlp_ratio=2,
                 c2t_stride=2,
                 droppath=0.,
                 injection_type="muli_sum",
                 injection=True,
                 num_classes=1000):
        super().__init__()
        self.channels = channels
        self.injection = injection
        self.embed_dim = sum(channels)
        self.decode_out_indices = decode_out_indices

        self.tpm = TokenPyramidModule(cfgs=cfgs, out_indices=embed_out_indice)
        self.ppa = PyramidPoolAgg(stride=c2t_stride)
        self.trans = Transformer(embed_dim=self.embed_dim,
                                 key_dim=key_dim,
                                 num_heads=num_heads,
                                 depth=depth,
                                 mlp_ratio=mlp_ratio,
                                 attn_ratio=attn_ratio,
                                 dropout=0.,
                                 attention_dropout=0.,
                                 droppath=droppath)

        self.sim = nn.LayerList()
        # The default `injection_type` ("muli_sum") was missing from this
        # table, so constructing the model with defaults raised KeyError.
        # Both spellings are accepted, and the previously supported keys
        # ("multi_sum", "multi_sim_cbr") are kept for existing configs.
        sim_block_dict = {
            "fuse_sum": FuseBlockSum,
            "fuse_multi": FuseBlockMulti,
            "multi_sum": InjectionMultiSum,
            "muli_sum": InjectionMultiSum,
            "multi_sum_cbr": InjectionMultiSumCBR,
            "muli_sum_cbr": InjectionMultiSumCBR,
            "multi_sim_cbr": InjectionMultiSumCBR,
        }
        try:
            sim_block = sim_block_dict[injection_type]
        except KeyError:
            raise KeyError(
                f"Unknown injection_type {injection_type!r}; "
                f"expected one of {sorted(sim_block_dict)}") from None

        if self.injection:
            for idx, (channel, out_channel) in enumerate(zip(channels, out_channels)):
                if idx in decode_out_indices:
                    self.sim.append(sim_block(channel, out_channel))
                else:
                    # Placeholder so that self.sim[i] lines up with scale i.
                    self.sim.append(Identity())

        # classifier
        self.avg_pool = nn.AdaptiveAvgPool2D(1)
        self.head = nn.Sequential(
            ('bn', nn.BatchNorm1D(self.embed_dim)),
            ('l', nn.Linear(self.embed_dim, num_classes)),
        )

    def forward_features(self, x):
        """Return the list of feature maps produced before the classifier.

        With injection enabled, the list holds one fused map per index in
        ``decode_out_indices``. Otherwise it holds the pyramid outputs
        followed by the transformer output.
        """
        outputs = self.tpm(x)
        out = self.trans(self.ppa(outputs))

        if not self.injection:
            outputs.append(out)
            return outputs

        # One global chunk per scale; self.channels is a list of section sizes.
        xx = out.split(self.channels, axis=1)
        results = []
        for i, global_semantics in enumerate(xx):
            if i in self.decode_out_indices:
                local_tokens = outputs[i]
                results.append(self.sim[i](local_tokens, global_semantics))
        return results

    def forward(self, x):
        """Return the classification logits for image batch ``x``."""
        features = self.forward_features(x)
        # NOTE(review): the head is built for `embed_dim` channels, while with
        # injection enabled the last feature comes from an injection block
        # built with `out_channels`. Confirm the configs keep these compatible.
        pooled = self.avg_pool(features[-1]).squeeze([-2, -1])
        return self.head(pooled)
def build_topformer(config):
    """Create a Topformer model from the ``MODEL`` section of ``config``.

    Args:
        config: config instance; the fields read are ``config.MODEL.CFGS``,
            ``CHANNELS``, ``OUT_CHANNELS``, ``EMBED_OUT_INDICE``,
            ``DECODE_OUT_INDICES``, ``DEPTH``, ``KEY_DIM``, ``NUM_HEADS``,
            ``ATTN_RATIO``, ``MLP_RATIO``, ``C2T_STRIDE``, ``DROPPATH``,
            ``INJECTION_TYPE``, ``INJECTION`` and ``NUM_CLASSES``.
    Returns:
        model: nn.Layer, TopFormer model
    """
    model_cfg = config.MODEL
    return Topformer(cfgs=model_cfg.CFGS,
                     channels=model_cfg.CHANNELS,
                     out_channels=model_cfg.OUT_CHANNELS,
                     embed_out_indice=model_cfg.EMBED_OUT_INDICE,
                     decode_out_indices=model_cfg.DECODE_OUT_INDICES,
                     depth=model_cfg.DEPTH,
                     key_dim=model_cfg.KEY_DIM,
                     num_heads=model_cfg.NUM_HEADS,
                     attn_ratio=model_cfg.ATTN_RATIO,
                     mlp_ratio=model_cfg.MLP_RATIO,
                     c2t_stride=model_cfg.C2T_STRIDE,
                     droppath=model_cfg.DROPPATH,
                     injection_type=model_cfg.INJECTION_TYPE,
                     injection=model_cfg.INJECTION,
                     num_classes=model_cfg.NUM_CLASSES)