-
Notifications
You must be signed in to change notification settings - Fork 314
/
edvr_model.py
71 lines (61 loc) · 2.69 KB
/
edvr_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import logging

import torch
from torch.nn.parallel import DistributedDataParallel

from basicsr.models.video_base_model import VideoBaseModel

# Named logger 'basicsr'; every log call in this module goes through it.
logger = logging.getLogger('basicsr')
class EDVRModel(VideoBaseModel):
    """EDVR Model.

    Paper: EDVR: Video Restoration with Enhanced Deformable Convolutional Networks.  # noqa: E501

    Adds two training options, read from ``opt['train']``:

    * ``dcn_lr_mul``: learning-rate multiplier for ``net_g`` parameters whose
      name contains ``'dcn'`` (default 1, meaning a single param group).
    * ``tsa_iter``: if set, only parameters whose name contains ``'fusion'``
      are trained while ``current_iter < tsa_iter``; from ``tsa_iter`` on all
      parameters are trained.
    """

    def __init__(self, opt):
        super(EDVRModel, self).__init__(opt)
        # Always define the attribute so optimize_parameters never hits an
        # AttributeError; only training configs can enable the TSA warm-up.
        self.train_tsa_iter = None
        if self.is_train:
            self.train_tsa_iter = opt['train'].get('tsa_iter')
        # True while the non-fusion parameters of net_g are frozen.
        self._tsa_frozen = False

    def setup_optimizers(self):
        """Create ``self.optimizer_g`` and append it to ``self.optimizers``.

        If ``dcn_lr_mul`` != 1, the parameters of ``net_g`` whose name contains
        ``'dcn'`` are put in a second param group. That group's lr is
        ``optim_g['lr'] * dcn_lr_mul``.

        Raises:
            NotImplementedError: if ``optim_g['type']`` is not ``'Adam'``.
        """
        train_opt = self.opt['train']
        dcn_lr_mul = train_opt.get('dcn_lr_mul', 1)
        logger.info(f'Multiple the learning rate for dcn with {dcn_lr_mul}.')
        if dcn_lr_mul == 1:
            optim_params = self.net_g.parameters()
        else:
            # Separate dcn params and normal params for different lr.
            normal_params = []
            dcn_params = []
            for name, param in self.net_g.named_parameters():
                if 'dcn' in name:
                    dcn_params.append(param)
                else:
                    normal_params.append(param)
            base_lr = train_opt['optim_g']['lr']
            optim_params = [
                {  # add normal params first
                    'params': normal_params,
                    'lr': base_lr
                },
                {
                    'params': dcn_params,
                    'lr': base_lr * dcn_lr_mul
                },
            ]
        # Pop 'type' from a copy: mutating the shared config would make a
        # second call (e.g. re-setup) fail with KeyError.
        optim_opt = dict(train_opt['optim_g'])
        optim_type = optim_opt.pop('type')
        if optim_type == 'Adam':
            self.optimizer_g = torch.optim.Adam(optim_params, **optim_opt)
        else:
            raise NotImplementedError(
                f'optimizer {optim_type} is not supperted yet.')
        self.optimizers.append(self.optimizer_g)

    def optimize_parameters(self, current_iter):
        """Run one optimization step, applying the optional TSA warm-up.

        While ``current_iter < train_tsa_iter``, every ``net_g`` parameter
        whose name lacks ``'fusion'`` has ``requires_grad`` set to False. Once
        ``current_iter`` reaches ``train_tsa_iter``, all parameters are
        trainable again.

        The freeze state is tracked explicitly, which fixes two problems:

        * The warm-up is also applied when training resumes inside the
          warm-up window, not only at iteration 1.
        * ``tsa_iter == 1`` no longer leaves the network frozen forever.
        """
        tsa_iter = self.train_tsa_iter
        if tsa_iter:
            in_tsa_phase = current_iter < tsa_iter
            if in_tsa_phase and not self._tsa_frozen:
                logger.info(
                    f'Only train TSA module for {tsa_iter} iters.')
                for name, param in self.net_g.named_parameters():
                    if 'fusion' not in name:
                        param.requires_grad = False
                self._tsa_frozen = True
            elif not in_tsa_phase and self._tsa_frozen:
                logger.warning('Train all the parameters.')
                for param in self.net_g.parameters():
                    param.requires_grad = True
                if isinstance(self.net_g, DistributedDataParallel):
                    logger.warning('Set net_g.find_unused_parameters = False.')
                    self.net_g.find_unused_parameters = False
                self._tsa_frozen = False
        # Use EDVRModel (not VideoBaseModel) as the super() anchor so an
        # override in VideoBaseModel is not silently skipped.
        super(EDVRModel, self).optimize_parameters(current_iter)