ffn.py
"""Implementation of FFN block in the style of Transformers."""
from functools import partial
from torch import nn
from src.models.sequence.base import SequenceModule
from src.models.nn import LinearActivation, DropoutNd
class FFN(SequenceModule):
    """Position-wise feed-forward block: Linear -> activation -> dropout -> Linear."""

    def __init__(
        self,
        d_input,
        expand=2,
        d_output=None,
        transposed=False,
        activation='gelu',
        initializer=None,
        dropout=0.0,
        tie_dropout=False,
    ):
        super().__init__()
        self.d_output = d_input if d_output is None else d_output
        self.transposed = transposed
        d_inner = int(expand * d_input)

        # First projection expands the model dimension and applies the activation.
        linear1 = LinearActivation(
            d_input, d_inner,
            transposed=transposed,
            activation=activation,
            initializer=initializer,
            activate=True,
        )

        # tie_dropout shares the dropout mask across the sequence dimension.
        dropout_cls = partial(DropoutNd, transposed=self.transposed) if tie_dropout else nn.Dropout
        # dropout_cls = nn.Dropout2d if self.transposed else nn.Dropout
        drop = dropout_cls(dropout) if dropout > 0.0 else nn.Identity()

        # Second projection maps back to the output dimension, with no activation.
        linear2 = LinearActivation(
            d_inner, self.d_output,
            transposed=transposed,
            activation=None,
            initializer=initializer,
            activate=False,
        )

        self.ff = nn.Sequential(
            linear1,
            drop,
            linear2,
        )
    def forward(self, x, *args, **kwargs):
        # The FFN is stateless, so no state is returned alongside the output.
        return self.ff(x), None
    def step(self, x, state, **kwargs):
        # x: (batch, d_input) -- a single timestep of the sequence
        if self.transposed:
            # self.ff expects (batch, d_input, seq_len), so add and remove a
            # singleton length dimension around the call.
            return self.ff(x.unsqueeze(-1)).squeeze(-1), state
        else:
            return self.ff(x), state
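
# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file): a minimal smoke test showing
# the expected input layouts for both the default and transposed modes, and
# the single-timestep `step` interface. Assumes torch is installed and that
# the relative imports above resolve (i.e. run from the repository root).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch

    batch, seq_len, d_input = 4, 16, 32

    # Default layout: (batch, seq_len, d_input)
    ffn = FFN(d_input, expand=2, dropout=0.1)
    y, _ = ffn(torch.randn(batch, seq_len, d_input))
    assert y.shape == (batch, seq_len, d_input)

    # Transposed layout: (batch, d_input, seq_len)
    ffn_t = FFN(d_input, transposed=True)
    x_t = torch.randn(batch, d_input, seq_len)
    y_t, _ = ffn_t(x_t)
    assert y_t.shape == (batch, d_input, seq_len)

    # Recurrent step interface: one timestep of shape (batch, d_input)
    y_step, state = ffn_t.step(x_t[..., 0], state=None)
    assert y_step.shape == (batch, d_input)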