import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.init import dirac_


def normalize(w):
    """Normalizes weight tensor over full filter."""
    return F.normalize(w.view(w.shape[0], -1)).view_as(w)
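
# Note: F.normalize defaults to p=2 along dim=1, so normalize() rescales each
# output filter w[i] (flattened) to unit L2 norm before reshaping it back.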


class DiracConv(nn.Module):
    """Mixin that holds the Dirac weight parametrization alpha * delta + beta * normalize(W)."""

    def init_params(self, out_channels):
        self.alpha = nn.Parameter(torch.Tensor(out_channels).fill_(1))
        self.beta = nn.Parameter(torch.Tensor(out_channels).fill_(0.1))
        # delta is a fixed (non-trainable) Kronecker delta filter with the same shape as the weight.
        self.register_buffer('delta', dirac_(self.weight.data.clone()))
        assert self.delta.shape == self.weight.shape
        # Broadcasting shape for the per-output-channel scales alpha and beta.
        self.v = (-1,) + (1,) * (self.weight.dim() - 1)

    def transform_weight(self):
        return self.alpha.view(*self.v) * self.delta + self.beta.view(*self.v) * normalize(self.weight)


class DiracConv1d(nn.Conv1d, DiracConv):
    r"""Dirac parametrized convolutional layer.

    Works the same way as `nn.Conv1d`, but has an additional weight parametrization
    :math:`\alpha\delta + \beta W`,
    where:
        :math:`\alpha` and :math:`\beta` are learnable parameters (one per output channel),
        :math:`\delta` is a tensor such that `F.conv1d(x, delta) = x`, i.e. a Kronecker delta,
        `W` is the weight tensor.

    It is the user's responsibility to set the correct padding. Only stride=1 is supported.
    """

    def __init__(self, in_channels, out_channels, kernel_size, padding=0, dilation=1, bias=True):
        super().__init__(in_channels, out_channels, kernel_size, stride=1, padding=padding, dilation=dilation, bias=bias)
        self.init_params(out_channels)

    def forward(self, input):
        return F.conv1d(input, self.transform_weight(), self.bias, self.stride, self.padding, self.dilation)


class DiracConv2d(nn.Conv2d, DiracConv):
    r"""Dirac parametrized convolutional layer.

    Works the same way as `nn.Conv2d`, but has an additional weight parametrization
    :math:`\alpha\delta + \beta W`,
    where:
        :math:`\alpha` and :math:`\beta` are learnable parameters (one per output channel),
        :math:`\delta` is a tensor such that `F.conv2d(x, delta) = x`, i.e. a Kronecker delta,
        `W` is the weight tensor.

    It is the user's responsibility to set the correct padding. Only stride=1 is supported.
    """

    def __init__(self, in_channels, out_channels, kernel_size, padding=0, dilation=1, bias=True):
        super().__init__(in_channels, out_channels, kernel_size, stride=1, padding=padding, dilation=dilation, bias=bias)
        self.init_params(out_channels)

    def forward(self, input):
        return F.conv2d(input, self.transform_weight(), self.bias, self.stride, self.padding, self.dilation)


class DiracConv3d(nn.Conv3d, DiracConv):
    r"""Dirac parametrized convolutional layer.

    Works the same way as `nn.Conv3d`, but has an additional weight parametrization
    :math:`\alpha\delta + \beta W`,
    where:
        :math:`\alpha` and :math:`\beta` are learnable parameters (one per output channel),
        :math:`\delta` is a tensor such that `F.conv3d(x, delta) = x`, i.e. a Kronecker delta,
        `W` is the weight tensor.

    It is the user's responsibility to set the correct padding. Only stride=1 is supported.
    """

    def __init__(self, in_channels, out_channels, kernel_size, padding=0, dilation=1, bias=True):
        super().__init__(in_channels, out_channels, kernel_size, stride=1, padding=padding, dilation=dilation, bias=bias)
        self.init_params(out_channels)

    def forward(self, input):
        return F.conv3d(input, self.transform_weight(), self.bias, self.stride, self.padding, self.dilation)
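

# --- Usage sketch (illustrative addition, not part of the original layers) ---
# A minimal check of what the docstrings describe, assuming kernel_size=3 with
# padding=1 so the spatial size is preserved: the delta buffer alone acts as an
# identity convolution, and at initialization (alpha=1, beta=0.1) the layer's
# output stays close to the identity mapping.
if __name__ == '__main__':
    torch.manual_seed(0)
    x = torch.randn(2, 8, 16, 16)
    conv = DiracConv2d(in_channels=8, out_channels=8, kernel_size=3, padding=1, bias=False)

    # F.conv2d(x, delta) == x, i.e. the Kronecker delta property from the docstrings.
    identity = F.conv2d(x, conv.delta, padding=1)
    print('delta acts as identity:', torch.allclose(identity, x))

    # The full layer at init: identity plus a small (beta=0.1) perturbation of x.
    y = conv(x)
    print('output shape preserved:', y.shape == x.shape)
    print('max |y - x| at init:', (y - x).abs().max().item())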