# forked from open-mmlab/mmsegmentation
# -
# Notifications
# You must be signed in to change notification settings - Fork 0
# /
# Copy path: ddrnet_23_in1k-pre_2xb6-120k_cityscapes-1024x1024.py
# 93 lines (87 loc) · 2.8 KB
# /
# ddrnet_23_in1k-pre_2xb6-120k_cityscapes-1024x1024.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
# Base config files, given as paths relative to this file. Going by the
# paths, they supply the Cityscapes 1024x1024 dataset settings and the
# default runtime settings.
_base_ = [
    '../_base_/datasets/cityscapes_1024x1024.py',
    '../_base_/default_runtime.py',
]
# Class weights passed to both OhemCrossEntropy losses of the decode head.
# Values are borrowed from
# https://github.com/openseg-group/OCNet.pytorch/issues/14  # noqa
# Licensed under the MIT License
# The list holds 19 entries, the same count as ``num_classes=19`` in the
# decode head.
class_weight = [
    0.8373, 0.918, 0.866, 1.0345, 1.0166,
    0.9969, 0.9754, 1.0489, 0.8786, 1.0023,
    0.9539, 0.9843, 1.1116, 0.9037, 1.0865,
    1.0955, 1.0865, 1.1529, 1.0507,
]
# URL of the weights handed to the backbone's ``Pretrained`` init_cfg.
# Written as two adjacent literals so that no line exceeds the length limit;
# the resulting string is a single URL.
checkpoint = (
    'https://download.openmmlab.com/mmsegmentation/v0.5/ddrnet/pretrain/'
    'ddrnet23-in1kpre_3rdparty-9ca29f62.pth')
# Size tuple forwarded to the data preprocessor as ``size``.
crop_size = (1024, 1024)

# Settings for ``SegDataPreProcessor``; referenced by ``model`` below.
data_preprocessor = dict(
    type='SegDataPreProcessor',
    size=crop_size,
    # Three-value mean/std lists, presumably per-channel normalization
    # statistics -- confirm against SegDataPreProcessor.
    mean=[123.675, 116.28, 103.53],
    std=[58.395, 57.12, 57.375],
    bgr_to_rgb=True,
    # Presumably the fill values used when padding images (pad_val) and
    # segmentation maps (seg_pad_val) -- confirm against SegDataPreProcessor.
    pad_val=0,
    seg_pad_val=255,
)
# Normalization layer settings shared by the backbone and the decode head.
norm_cfg = dict(type='SyncBN', requires_grad=True)

# Segmentor definition: an ``EncoderDecoder`` built from a ``DDRNet``
# backbone and a ``DDRHead`` decode head.
model = dict(
    type='EncoderDecoder',
    data_preprocessor=data_preprocessor,
    backbone=dict(
        type='DDRNet',
        in_channels=3,
        channels=64,
        ppm_channels=128,
        norm_cfg=norm_cfg,
        align_corners=False,
        # Initialize the backbone from the weights at ``checkpoint``.
        init_cfg=dict(type='Pretrained', checkpoint=checkpoint),
    ),
    decode_head=dict(
        type='DDRHead',
        # Four times the backbone's ``channels`` value of 64.
        in_channels=64 * 4,
        channels=128,
        dropout_ratio=0.,
        num_classes=19,
        align_corners=False,
        norm_cfg=norm_cfg,
        # Two OhemCrossEntropy losses that differ only in ``loss_weight``
        # (1.0 and 0.4). Which head output each entry is applied to is
        # decided inside DDRHead and is not visible from this file.
        loss_decode=[
            dict(
                type='OhemCrossEntropy',
                thres=0.9,
                min_kept=131072,
                class_weight=class_weight,
                loss_weight=1.0,
            ),
            dict(
                type='OhemCrossEntropy',
                thres=0.9,
                min_kept=131072,
                class_weight=class_weight,
                loss_weight=0.4,
            ),
        ],
    ),
    # model training and testing settings
    train_cfg=dict(),
    test_cfg=dict(mode='whole'),
)
# Only ``batch_size`` and ``num_workers`` are set here; the remaining
# dataloader fields presumably come from the base dataset config.
train_dataloader = dict(
    batch_size=6,
    num_workers=4,
)
# Total iteration count. It is reused further down as the scheduler's
# ``end``, the training loop's ``max_iters`` and (divided by 10) the
# validation and checkpoint intervals.
iters = 120_000

# optimizer
optimizer = dict(
    type='SGD',
    lr=0.01,
    momentum=0.9,
    weight_decay=0.0005,
)
# Wrapper around the optimizer above; ``clip_grad`` is left as None.
optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=optimizer,
    clip_grad=None,
)
# learning policy
# A single ``PolyLR`` entry configured per iteration (by_epoch=False),
# spanning iteration 0 up to ``iters``.
param_scheduler = [
    dict(
        type='PolyLR',
        eta_min=0,
        power=0.9,
        begin=0,
        end=iters,
        by_epoch=False,
    ),
]
# training schedule for 120k
# Iteration-based training loop capped at ``iters``, with ``val_interval``
# set to one tenth of that.
train_cfg = dict(
    type='IterBasedTrainLoop',
    max_iters=iters,
    val_interval=iters // 10,
)
# Validation and test loops use their loop types with no extra options.
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
# Hook settings, keyed by hook slot name.
default_hooks = dict(
    timer=dict(type='IterTimerHook'),
    logger=dict(
        type='LoggerHook',
        interval=50,
        log_metric_by_epoch=False,
    ),
    param_scheduler=dict(type='ParamSchedulerHook'),
    # Iteration-based checkpointing (by_epoch=False) with the same
    # ``iters // 10`` interval that ``train_cfg`` uses for validation.
    checkpoint=dict(
        type='CheckpointHook',
        by_epoch=False,
        interval=iters // 10,
    ),
    sampler_seed=dict(type='DistSamplerSeedHook'),
    visualization=dict(type='SegVisualizationHook'),
)
# Fixed seed value for this config.
randomness = dict(
    seed=304,
)