absorbe_bn.py (forked from brianchmiel/TransformCodingInference)
import torch
import torch.nn as nn


def remove_bn_params(bn_module):
    # Strip the BN statistics and affine parameters once they have been folded
    # into the preceding layer, leaving an empty shell module behind.
    bn_module.register_buffer('running_mean', None)
    bn_module.register_buffer('running_var', None)
    bn_module.register_parameter('weight', None)
    bn_module.register_parameter('bias', None)


def init_bn_params(bn_module):
    # Reset the BN layer to an identity transform (zero mean, unit variance,
    # unit scale, zero shift) so it can stay in the graph after absorption.
    bn_module.running_mean.fill_(0)
    bn_module.running_var.fill_(1)
    if bn_module.affine:
        bn_module.weight.fill_(1)
        bn_module.bias.fill_(0)


def absorb_bn(module, bn_module, remove_bn=True, verbose=False):
    # Fold the batch-norm transform y = gamma * (x - mu) / sqrt(var + eps) + beta
    # into the weights and bias of the preceding layer:
    #     W' = gamma * W / sqrt(var + eps)
    #     b' = gamma * (b - mu) / sqrt(var + eps) + beta
    with torch.no_grad():
        w = module.weight
        if module.bias is None:
            zeros = torch.zeros(w.size(0), dtype=w.dtype, device=w.device)
            module.register_parameter('bias', nn.Parameter(zeros))
        b = module.bias

        # Broadcast shape for per-output-channel factors: (out, 1, 1, 1) for
        # Conv2d weights, (out, 1) for Linear weights.
        view_shape = (w.size(0),) + (1,) * (w.dim() - 1)

        if bn_module.running_mean is not None:
            b.add_(-bn_module.running_mean)
        if bn_module.running_var is not None:
            invstd = bn_module.running_var.clone().add_(bn_module.eps).pow_(-0.5)
            w.mul_(invstd.view(view_shape))
            b.mul_(invstd)

        if bn_module.weight is not None:
            w.mul_(bn_module.weight.view(view_shape))
            b.mul_(bn_module.weight)
        if bn_module.bias is not None:
            b.add_(bn_module.bias)

        if remove_bn:
            remove_bn_params(bn_module)
        else:
            init_bn_params(bn_module)

        if verbose:
            print('BN module %s was absorbed into layer %s' % (bn_module, module))


def is_bn(m):
    return isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d))


def is_absorbing(m):
    # Layer types whose weights can absorb a following batch-norm.
    return isinstance(m, (nn.Conv2d, nn.Linear))


def search_absorbe_bn(model, prev=None, remove_bn=False, verbose=False):
    # Walk the module tree; whenever a BatchNorm layer directly follows a
    # Conv2d/Linear sibling, fold it into that layer and mark it as absorbed.
    with torch.no_grad():
        for m in model.children():
            if is_bn(m):
                if is_absorbing(prev):
                    absorb_bn(prev, m, remove_bn=remove_bn, verbose=verbose)
                    m.absorbed = True
                else:
                    m.absorbed = False
            search_absorbe_bn(m, remove_bn=remove_bn, verbose=verbose)
            prev = m
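

# Illustrative sanity check (not part of the original file): build a small
# Conv2d + BatchNorm2d pair with made-up statistics, fold the BN into the
# convolution with absorb_bn, and confirm that the folded convolution alone
# reproduces the original conv + BN output in eval mode. The layer sizes,
# random statistics and seed below are assumptions chosen only for this sketch;
# on a full network one would call search_absorbe_bn(model) after model.eval().
if __name__ == '__main__':
    torch.manual_seed(0)

    conv = nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False).eval()
    bn = nn.BatchNorm2d(8).eval()

    # Non-trivial BN statistics and affine parameters so the check is not vacuous.
    with torch.no_grad():
        bn.running_mean.uniform_(-1, 1)
        bn.running_var.uniform_(0.5, 2.0)
        bn.weight.uniform_(0.5, 1.5)
        bn.bias.uniform_(-0.5, 0.5)

    x = torch.randn(2, 3, 16, 16)
    with torch.no_grad():
        reference = bn(conv(x))

    absorb_bn(conv, bn, remove_bn=True, verbose=True)

    with torch.no_grad():
        folded = conv(x)

    print('max abs difference: %.2e' % (folded - reference).abs().max().item())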