# mixout.py (forked from rabeehk/vibert)
# Code adapted from an implementation of mixout (https://arxiv.org/abs/1909.11299).
import math
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
from torch.nn import Parameter
from torch.autograd.function import InplaceFunction


class Mixout(InplaceFunction):
    # target: a weight tensor that is mixed with the input tensor.
    # The forward method returns
    # [(1 - Bernoulli(1 - p) mask) * target + (Bernoulli(1 - p) mask) * input - p * target] / (1 - p),
    # where p is the mix probability of mixout.
    # The backward method returns the gradient of the forward method.
    # Dropout is equivalent to the case of target=None.
    # This is a modified version of the dropout code in PyTorch.
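    #
    # A quick sanity check of the formula (illustrative numbers, not from the
    # original code): take p = 0.5. For a coordinate whose mask is 0 (the weight
    # is "mixed out"), the output is (target - p * target) / (1 - p) = target,
    # i.e. the pretrained value is restored. Where the mask is 1, the output is
    # (input - p * target) / (1 - p), so that E[output] = input, analogous to
    # the rescaling used in inverted dropout.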

    @staticmethod
    def _make_noise(input):
        # An uninitialized tensor with the same shape, dtype, and device as input.
        return torch.empty_like(input)

    @classmethod
    def forward(cls, ctx, input, target=None, p=0.0, training=False, inplace=False):
        if p < 0 or p > 1:
            raise ValueError("A mix probability of mixout has to be between 0 and 1, but got {}".format(p))
        if target is not None and input.size() != target.size():
            raise ValueError(
                "A target tensor size must match the input tensor size {},"
                " but got {}".format(input.size(), target.size())
            )
        ctx.p = p
        ctx.training = training

        if ctx.p == 0 or not ctx.training:
            return input

        if target is None:
            target = cls._make_noise(input)
            target.fill_(0)
        target = target.to(input.device)

        if inplace:
            ctx.mark_dirty(input)
            output = input
        else:
            output = input.clone()

        ctx.noise = cls._make_noise(input)
        if len(ctx.noise.size()) == 1:
            ctx.noise.bernoulli_(1 - ctx.p)
        else:
            # Sample one Bernoulli(1 - p) mask for the first row and tile it across
            # the first dimension, so every output row shares the same column mask.
            ctx.noise[0].bernoulli_(1 - ctx.p)
            ctx.noise = ctx.noise[0].repeat(input.size()[0], 1)
        ctx.noise = ctx.noise.expand_as(input)

        if ctx.p == 1:
            output = target
        else:
            output = ((1 - ctx.noise) * target + ctx.noise * output - ctx.p * target) / (1 - ctx.p)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        if ctx.p > 0 and ctx.training:
            return grad_output * ctx.noise, None, None, None, None
        else:
            return grad_output, None, None, None, None


def mixout(input, target=None, p=0.0, training=False, inplace=False):
    return Mixout.apply(input, target, p, training, inplace)


class MixLinear(torch.nn.Module):
    __constants__ = ["bias", "in_features", "out_features"]

    # If target is None, nn.Sequential(nn.Linear(m, n), MixLinear(m', n', p))
    # is equivalent to nn.Sequential(nn.Linear(m, n), nn.Dropout(p), nn.Linear(m', n')).
    # If you want to change a dropout layer to a mixout layer,
    # replace the nn.Linear right after nn.Dropout(p) with a MixLinear using the same p.
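    #
    # For example (illustrative shapes and probability, not taken from this
    # repository), a fine-tuning setup might build
    #     MixLinear(768, 768, bias=True, target=pretrained_linear.weight, p=0.9)
    # where pretrained_linear is the nn.Linear being replaced; at evaluation time
    # the layer behaves exactly like nn.Linear with its current weight.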

    def __init__(self, in_features, out_features, bias=True, target=None, p=0.0):
        super(MixLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.Tensor(out_features, in_features))
        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter("bias", None)
        self.reset_parameters()
        self.target = target
        self.p = p

    def reset_parameters(self):
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)

    def forward(self, input):
        # Mix the current weight with the target weight before applying the affine map.
        return F.linear(input, mixout(self.weight, self.target, self.p, self.training), self.bias)

    def extra_repr(self):
        layer_type = "drop" if self.target is None else "mix"
        return "{}={}, in_features={}, out_features={}, bias={}".format(
            layer_type + "out", self.p, self.in_features, self.out_features, self.bias is not None
        )
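

# ----------------------------------------------------------------------------
# Usage sketch (not part of the original module): how MixLinear is typically
# used for fine-tuning, with the pretrained weight passed as `target` so that
# mixout stochastically reverts weights toward their pretrained values. The
# layer sizes, the probability p=0.7, and the `pretrained` layer below are
# illustrative assumptions, not values taken from this repository.
# ----------------------------------------------------------------------------
if __name__ == "__main__":
    pretrained = nn.Linear(16, 8)  # stands in for a layer of a pretrained model
    mix_layer = MixLinear(16, 8, bias=True,
                          target=pretrained.weight.detach().clone(), p=0.7)
    # Start fine-tuning from the pretrained parameters themselves.
    mix_layer.weight.data.copy_(pretrained.weight.data)
    mix_layer.bias.data.copy_(pretrained.bias.data)

    x = torch.randn(4, 16)
    mix_layer.train()
    print(mix_layer(x).shape)  # torch.Size([4, 8]); weights are randomly mixed with target
    mix_layer.eval()
    print(mix_layer(x).shape)  # same shape; mixout is a no-op at evaluation time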