-
Notifications
You must be signed in to change notification settings - Fork 435
/
example_pspnet.py
161 lines (139 loc) · 6.22 KB
/
example_pspnet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
import torch
from torch import nn
import torch.nn.functional as F
from repvgg import get_RepVGG_func_by_name
# The PSPNet parts are from
# https://github.com/hszhao/semseg
class PPM(nn.Module):
    """Pyramid pooling module.

    Pools the input at several grid sizes, projects each pooled map with a
    1x1 convolution, resizes it back to the input's spatial size and
    concatenates everything with the input along the channel dimension.

    Args:
        in_dim: channel count of the incoming feature map.
        reduction_dim: channel count produced by every pooled branch.
        bins: iterable of output sizes given to ``nn.AdaptiveAvgPool2d``;
            one branch is created per entry.
        BatchNorm: callable building a normalization layer from a channel
            count.
    """

    def __init__(self, in_dim, reduction_dim, bins, BatchNorm):
        super(PPM, self).__init__()

        def make_branch(pool_size):
            # pool -> 1x1 conv (no bias, a norm layer follows) -> norm -> ReLU
            return nn.Sequential(
                nn.AdaptiveAvgPool2d(pool_size),
                nn.Conv2d(in_dim, reduction_dim, kernel_size=1, bias=False),
                BatchNorm(reduction_dim),
                nn.ReLU(inplace=True),
            )

        self.features = nn.ModuleList([make_branch(pool_size) for pool_size in bins])

    def forward(self, x):
        """Return ``x`` concatenated (dim 1) with every resized branch output."""
        target_hw = x.size()[2:]
        resized = [
            F.interpolate(branch(x), target_hw, mode='bilinear', align_corners=True)
            for branch in self.features
        ]
        return torch.cat([x] + resized, 1)
class PSPNet(nn.Module):
    """PSPNet segmentation network built on a RepVGG backbone.

    Args:
        backbone_name: name handed to ``get_RepVGG_func_by_name`` to select
            the backbone builder.
        backbone_file: checkpoint path read with ``torch.load`` when
            ``pretrained`` is true.
        deploy: forwarded to the RepVGG builder.
        bins: pooling sizes of the pyramid pooling module (only used when
            ``use_ppm`` is true). The backbone's final channel count must be
            divisible by ``len(bins)``.
        dropout: drop probability of the ``nn.Dropout2d`` layers in both heads.
        classes: number of output channels of both heads; must be > 1.
        zoom_factor: one of 1, 2, 4, 8; scales the output size computed in
            ``forward``.
        use_ppm: whether to insert the pyramid pooling module before the
            classifier.
        criterion: loss applied to both heads in training mode.
            NOTE: the default instance is created once, when the class is
            defined, and is shared by every model that does not pass its own.
        BatchNorm: callable building a normalization layer from a channel
            count.
        pretrained: whether to load ``backbone_file`` into the backbone.
    """

    def __init__(self,
                 backbone_name, backbone_file, deploy,
                 bins=(1, 2, 3, 6), dropout=0.1, classes=2,
                 zoom_factor=8, use_ppm=True, criterion=nn.CrossEntropyLoss(ignore_index=255), BatchNorm=nn.BatchNorm2d,
                 pretrained=True):
        super(PSPNet, self).__init__()
        assert classes > 1
        assert zoom_factor in [1, 2, 4, 8]
        self.zoom_factor = zoom_factor
        self.use_ppm = use_ppm
        self.criterion = criterion

        repvgg_fn = get_RepVGG_func_by_name(backbone_name)
        backbone = repvgg_fn(deploy)
        if pretrained:
            # Map storages to CPU so a checkpoint saved from a GPU also loads
            # on CPU-only hosts; load_state_dict copies the values into the
            # backbone's own parameters afterwards.
            # NOTE(security): torch.load unpickles the file -- only pass
            # checkpoints from a trusted source.
            checkpoint = torch.load(backbone_file, map_location='cpu')
            if 'state_dict' in checkpoint:
                checkpoint = checkpoint['state_dict']
            ckpt = {k.replace('module.', ''): v for k, v in checkpoint.items()}  # strip the names
            backbone.load_state_dict(ckpt)

        self.layer0, self.layer1, self.layer2, self.layer3, self.layer4 = backbone.stage0, backbone.stage1, backbone.stage2, backbone.stage3, backbone.stage4

        # The last two stages should have stride=1 for semantic segmentation
        # Note that the stride of 1x1 should be the same as the 3x3
        # Use dilation following the implementation of PSPNet
        secondlast_channel = self._dilate_stage(self.layer3, 2)
        last_channel = self._dilate_stage(self.layer4, 4)

        fea_dim = last_channel
        aux_in = secondlast_channel

        if use_ppm:
            # The PPM branches together must add exactly fea_dim channels so
            # that the classifier input below is 2 * fea_dim. Check against
            # the real backbone width, not a fixed ResNet constant.
            assert fea_dim % len(bins) == 0, \
                'backbone output channels must be divisible by len(bins)'
            self.ppm = PPM(fea_dim, fea_dim // len(bins), bins, BatchNorm)
            fea_dim *= 2

        self.cls = nn.Sequential(
            nn.Conv2d(fea_dim, 512, kernel_size=3, padding=1, bias=False),
            BatchNorm(512),
            nn.ReLU(inplace=True),
            nn.Dropout2d(p=dropout),
            nn.Conv2d(512, classes, kernel_size=1)
        )

        # nn.Module instances start in training mode, so at construction time
        # this branch is taken and the auxiliary head is created.
        if self.training:
            self.aux = nn.Sequential(
                nn.Conv2d(aux_in, 256, kernel_size=3, padding=1, bias=False),
                BatchNorm(256),
                nn.ReLU(inplace=True),
                nn.Dropout2d(p=dropout),
                nn.Conv2d(256, classes, kernel_size=1)
            )

    @staticmethod
    def _dilate_stage(stage, dilation):
        """Set stride 1 on a stage's convs and dilate its 3x3 / re-parameterized convs.

        Returns the ``out_channels`` of the last dilated conv found, or 0 when
        the stage contains none.
        """
        out_channels = 0
        for n, m in stage.named_modules():
            if not isinstance(m, nn.Conv2d):
                continue
            if 'rbr_dense' in n or 'rbr_reparam' in n:
                m.dilation = (dilation, dilation)
                m.padding = (dilation, dilation)
                m.stride = (1, 1)
                print('change dilation, padding, stride of ', n)
                out_channels = m.out_channels
            elif 'rbr_1x1' in n:
                m.stride = (1, 1)
                print('change stride of ', n)
        return out_channels

    def forward(self, x, y=None):
        """Run the network.

        Args:
            x: input batch; ``size(2) - 1`` and ``size(3) - 1`` must be
                multiples of 8.
            y: target passed to ``criterion``; only used in training mode.

        Returns:
            In training mode the tuple ``(x.max(1)[1], main_loss, aux_loss)``;
            otherwise the classifier output (resized to the computed
            ``(h, w)`` unless ``zoom_factor`` is 1).
        """
        x_size = x.size()
        assert (x_size[2]-1) % 8 == 0 and (x_size[3]-1) % 8 == 0
        h = int((x_size[2] - 1) / 8 * self.zoom_factor + 1)
        w = int((x_size[3] - 1) / 8 * self.zoom_factor + 1)

        x = self.layer0(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x_tmp = self.layer3(x)  # kept for the auxiliary head
        x = self.layer4(x_tmp)

        if self.use_ppm:
            x = self.ppm(x)
        x = self.cls(x)
        if self.zoom_factor != 1:
            x = F.interpolate(x, size=(h, w), mode='bilinear', align_corners=True)

        if self.training:
            aux = self.aux(x_tmp)
            if self.zoom_factor != 1:
                aux = F.interpolate(aux, size=(h, w), mode='bilinear', align_corners=True)
            main_loss = self.criterion(x, y)
            aux_loss = self.criterion(aux, y)
            return x.max(1)[1], main_loss, aux_loss
        else:
            return x
if __name__ == '__main__':
    # Example: build a RepVGG-backed PSPNet, convert it to its deploy form,
    # and check that the converted model produces the same output.

    # 1. Build the PSPNet with RepVGG backbone. Download the ImageNet-pretrained weight file and load it.
    model = PSPNet(backbone_name='RepVGG-A0', backbone_file='RepVGG-A0-train.pth', deploy=False, classes=19, pretrained=True)

    # 2. Train it
    # seg_train(model)

    # 3. Convert and check the equivalence
    sample = torch.rand(4, 3, 713, 713)
    model.eval()
    print(model)
    # Only outputs are compared, so skip autograd bookkeeping for the
    # forward passes on this large batch.
    with torch.no_grad():
        y_train = model(sample)
    for module in model.modules():
        if hasattr(module, 'switch_to_deploy'):
            module.switch_to_deploy()
    with torch.no_grad():
        y_deploy = model(sample)
    print('output is ', y_deploy.size())
    print('=================== The diff is')
    print(((y_deploy - y_train) ** 2).sum())

    # 4. Save the converted model
    torch.save(model.state_dict(), 'PSPNet-RepVGG-A0-deploy.pth')
    del model  # Or do whatever you want with it

    # 5. For inference, load the saved model. There is no need to load the ImageNet-pretrained weights again.
    deploy_model = PSPNet(backbone_name='RepVGG-A0', backbone_file=None, deploy=True, classes=19, pretrained=False)
    deploy_model.eval()
    deploy_model.load_state_dict(torch.load('PSPNet-RepVGG-A0-deploy.pth'))

    # 6. Check again or do whatever you want
    with torch.no_grad():
        y_deploy = deploy_model(sample)
    print('=================== The diff is')
    print(((y_deploy - y_train) ** 2).sum())