-
Notifications
You must be signed in to change notification settings - Fork 17
/
StepLSTM.lua
279 lines (233 loc) · 11.1 KB
/
StepLSTM.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
-- StepLSTM is a step-wise module that can be used inside Recurrence to implement an LSTM.
-- That is, the StepLSTM efficiently implements a single LSTM time-step.
-- It's efficient because it doesn't use any internal modules; it calls BLAS directly.
-- StepLSTM is based on SeqLSTM.
-- Input : {input[t], hidden[t-1], cell[t-1]}
-- Output: {hidden[t], cell[t]}
local StepLSTM, parent = torch.class('nn.StepLSTM', 'nn.Module')

--- Build a single-step LSTM, or an LSTMP when `outputsize` is given.
-- @param inputsize  number of features in input[t]
-- @param hiddensize width of the cell state; each of the 4 gate blocks has this width
-- @param outputsize optional; when given, the hidden state (width hiddensize)
--   is projected to width outputsize through `weightO` (LSTMP). When omitted,
--   outputsize defaults to hiddensize (plain LSTM).
function StepLSTM:__init(inputsize, hiddensize, outputsize)
   parent.__init(self)
   if hiddensize and outputsize then
      -- implements LSTMP
      self.weightO = torch.Tensor(hiddensize, outputsize)
      -- gradient accumulators start at zero (like gradBias below) so that
      -- accumulating into them before any zeroGradParameters() is well-defined
      self.gradWeightO = torch.Tensor(hiddensize, outputsize):zero()
   else
      -- implements LSTM
      assert(inputsize and hiddensize and not outputsize)
      outputsize = hiddensize
   end
   self.inputsize, self.hiddensize, self.outputsize = inputsize, hiddensize, outputsize
   -- rows 1..inputsize map input[t] to the gates; the next outputsize rows map hidden[t-1]
   self.weight = torch.Tensor(inputsize + outputsize, 4 * hiddensize)
   self.gradWeight = torch.Tensor(inputsize + outputsize, 4 * hiddensize):zero()
   self.bias = torch.Tensor(4 * hiddensize)
   self.gradBias = torch.Tensor(4 * hiddensize):zero()
   self:reset()
   self.gates = torch.Tensor() -- batchsize x 4*hiddensize
   self.output = {torch.Tensor(), torch.Tensor()}
   self.gradInput = {torch.Tensor(), torch.Tensor(), torch.Tensor()}
   -- set this to true for variable length sequences that separate
   -- independent sequences with a step of zeros (a tensor of size D)
   self.maskzero = false
   self.v2 = true
end
--- (Re)initialize the learnable parameters.
-- Biases are zeroed, except the forget-gate block, which is set to 1.
-- Weights are drawn from N(0, std). When std is omitted, a fan-in based
-- default is used for each weight matrix.
-- @param std optional standard deviation for the weight initialization
-- @return self
function StepLSTM:reset(std)
   local hiddensize = self.hiddensize
   self.bias:zero()
   -- The forget gate occupies columns hiddensize+1 .. 2*hiddensize of the gate
   -- layout used by updateOutput/backward. Indexing this slice by outputsize
   -- would pick the wrong columns (or overrun the bias) for LSTMP whenever
   -- outputsize ~= hiddensize.
   self.bias[{{hiddensize + 1, 2 * hiddensize}}]:fill(1)
   self.weight:normal(0, std or (1.0 / math.sqrt(hiddensize + self.inputsize)))
   if self.weightO then
      self.weightO:normal(0, std or (1.0 / math.sqrt(self.outputsize + hiddensize)))
   end
   return self
end
--- Forward one LSTM time-step.
-- @param input {cur_x, prev_h, prev_c}. In the Lua path, cur_x is
--   batchsize x inputsize, prev_h is batchsize x outputsize and prev_c is
--   batchsize x hiddensize.
-- @return self.output = {next_h, next_c}. next_h is batchsize x outputsize
--   (the projected hidden state for LSTMP); next_c is batchsize x hiddensize.
-- self.gates (batchsize x 4*hiddensize) holds four column blocks of width
-- hiddensize: input gate, forget gate, output gate (sigmoid) and the input
-- transform (tanh). backward() reuses these activations.
function StepLSTM:updateOutput(input)
   -- a fresh forward invalidates whatever backward() computed previously
   self.recompute_backward = true
   local cur_x, prev_h, prev_c = input[1], input[2], input[3]
   local next_h, next_c = self.output[1], self.output[2]
   if cur_x.nn and cur_x.nn.StepLSTM_updateOutput and not self.forceLua then
      -- native implementation provided by the tensor type
      if self.weightO then -- LSTMP
         self.hidden = self.hidden or cur_x.new()
         cur_x.nn.StepLSTM_updateOutput(self.weight, self.bias, self.gates,
                                        cur_x, prev_h, prev_c,
                                        self.inputsize, self.hiddensize, self.outputsize,
                                        self.hidden, next_c, self.weightO, next_h)
      else -- LSTM
         cur_x.nn.StepLSTM_updateOutput(self.weight, self.bias, self.gates,
                                        cur_x, prev_h, prev_c,
                                        self.inputsize, self.hiddensize, self.outputsize,
                                        next_h, next_c)
      end
   else
      -- pure Lua fallback
      if self.weightO then -- LSTMP
         -- compute the un-projected hidden state into self.hidden;
         -- self.output[1] receives its projection at the end
         self.hidden = self.hidden or cur_x.new()
         next_h = self.hidden
      end
      assert(torch.isTensor(prev_h))
      assert(torch.isTensor(prev_c))
      local batchsize, inputsize, hiddensize = cur_x:size(1), cur_x:size(2), self.hiddensize
      assert(inputsize == self.inputsize)
      -- TODO use self.bias_view
      local bias_expand = self.bias:view(1, 4 * hiddensize):expand(batchsize, 4 * hiddensize)
      local Wx = self.weight:narrow(1, 1, inputsize)
      local Wh = self.weight:narrow(1, inputsize + 1, self.outputsize)
      next_h:resize(batchsize, hiddensize)
      next_c:resize(batchsize, hiddensize)
      local gates = self.gates
      -- Compare against the size *before* the resize. After resize the
      -- element count always equals batchsize*4*hiddensize, so testing
      -- gates:nElement() there could never be true.
      local nElement = gates:nElement()
      gates:resize(batchsize, 4 * hiddensize)
      if nElement ~= batchsize * 4 * hiddensize then
         gates:zero()
      end
      -- forward: gates = bias + cur_x * Wx + prev_h * Wh
      gates:addmm(bias_expand, cur_x, Wx)
      gates:addmm(prev_h, Wh)
      gates[{{}, {1, 3 * hiddensize}}]:sigmoid()
      gates[{{}, {3 * hiddensize + 1, 4 * hiddensize}}]:tanh()
      local input_gate = gates[{{}, {1, hiddensize}}]
      local forget_gate = gates[{{}, {hiddensize + 1, 2 * hiddensize}}]
      local output_gate = gates[{{}, {2 * hiddensize + 1, 3 * hiddensize}}]
      local input_transform = gates[{{}, {3 * hiddensize + 1, 4 * hiddensize}}]
      -- next_h is used as scratch for i*g before holding o*tanh(c)
      next_h:cmul(input_gate, input_transform)
      next_c:cmul(forget_gate, prev_c):add(next_h)
      next_h:tanh(next_c):cmul(output_gate)
      if self.weightO then -- LSTMP
         self.output[1]:resize(batchsize, self.outputsize)
         self.output[1]:mm(next_h, self.weightO)
      end
   end
   if self.maskzero and self.zeroMask ~= false then
      if self.v2 then
         assert(self.zeroMask ~= nil, torch.type(self).." expecting zeroMask tensor or false")
      else -- backwards compat
         self.zeroMask = nn.utils.getZeroMaskBatch(cur_x, self.zeroMask)
      end
      -- Zero masked outputs and gates. Use self.output[1] rather than the local
      -- next_h: in the Lua LSTMP path next_h aliases self.hidden, which left the
      -- returned (projected) output unmasked. For LSTMP also mask self.hidden.
      local masked = {self.output[1], next_c, self.gates}
      if self.weightO then
         masked[#masked + 1] = self.hidden
      end
      nn.utils.recursiveZeroMask(masked, self.zeroMask)
   end
   return self.output
end
--- Backward one LSTM time-step. In a single pass it computes gradInput and
-- accumulates into the parameter gradients (gradWeight, gradBias and, for
-- LSTMP, gradWeightO).
-- updateGradInput and accGradParameters both route here. The
-- recompute_backward flag (set by updateOutput, cleared below) makes them
-- skip a second call after the same forward.
-- @param input {cur_x, prev_h, prev_c}, as given to updateOutput
-- @param gradOutput {grad_next_h, grad_next_c}. With maskzero enabled these
--   are passed to nn.utils.recursiveZeroMask, which presumably zeroes the
--   masked rows in place (i.e. it mutates the caller's tensors) — confirm.
-- @param scale optional, defaults to 1.0; any other value fails the assert
-- @return self.gradInput = {grad_cur_x, grad_prev_h, grad_prev_c}
function StepLSTM:backward(input, gradOutput, scale)
self.recompute_backward = false
local cur_x, prev_h, prev_c = input[1], input[2], input[3]
local grad_next_h, grad_next_c = gradOutput[1], gradOutput[2]
local next_c = self.output[2]
local grad_cur_x, grad_prev_h, grad_prev_c = self.gradInput[1], self.gradInput[2], self.gradInput[3]
scale = scale or 1.0
assert(scale == 1.0, 'must have scale=1')
-- NOTE(review): torch.getBuffer appears to return a scratch tensor keyed by
-- ('StepLSTM', name), presumably shared across instances — confirm.
local grad_gates = torch.getBuffer('StepLSTM', 'grad_gates', self.gates) -- batchsize x 4*hiddensize
local grad_gates_sum = torch.getBuffer('StepLSTM', 'grad_gates_sum', self.gates) -- 1 x 4*hiddensize
if self.maskzero and self.zeroMask ~= false then
-- zero masked gradOutput
nn.utils.recursiveZeroMask({grad_next_h, grad_next_c}, self.zeroMask)
end
if cur_x.nn and cur_x.nn.StepLSTM_backward and not self.forceLua then
-- native implementation provided by the tensor type
if self.weightO then -- LSTMP
local grad_hidden = torch.getBuffer('StepLSTM', 'grad_hidden', self.hidden)
cur_x.nn.StepLSTM_backward(self.weight, self.gates,
self.gradWeight, self.gradBias, grad_gates, grad_gates_sum,
cur_x, prev_h, prev_c, next_c, grad_next_h, grad_next_c,
scale, self.inputsize, self.hiddensize, self.outputsize,
grad_cur_x, grad_prev_h, grad_prev_c,
self.weightO, self.hidden, self.gradWeightO, grad_hidden)
else -- LSTM
cur_x.nn.StepLSTM_backward(self.weight, self.gates,
self.gradWeight, self.gradBias, grad_gates, grad_gates_sum,
cur_x, prev_h, prev_c, next_c, grad_next_h, grad_next_c,
scale, self.inputsize, self.hiddensize, self.outputsize,
grad_cur_x, grad_prev_h, grad_prev_c)
end
else
-- pure Lua fallback
local batchsize, inputsize, hiddensize = cur_x:size(1), cur_x:size(2), self.hiddensize
assert(inputsize == self.inputsize)
if self.weightO then -- LSTMP
-- back through the projection output = hidden * weightO:
-- accumulate gradWeightO and replace grad_next_h with dL/d(hidden)
local grad_hidden = torch.getBuffer('StepLSTM', 'grad_hidden', self.hidden)
self.gradWeightO:addmm(scale, self.hidden:t(), grad_next_h)
grad_hidden:resize(batchsize, hiddensize)
grad_hidden:mm(grad_next_h, self.weightO:t())
grad_next_h = grad_hidden
end
grad_cur_x:resize(batchsize, inputsize)
grad_prev_h:resize(batchsize, self.outputsize)
grad_prev_c:resize(batchsize, hiddensize)
local Wx = self.weight:narrow(1,1,inputsize)
local Wh = self.weight:narrow(1,inputsize+1,self.outputsize)
local grad_Wx = self.gradWeight:narrow(1,1,inputsize)
local grad_Wh = self.gradWeight:narrow(1,inputsize+1,self.outputsize)
local grad_b = self.gradBias
local gates = self.gates
-- backward
-- gate activations saved by updateOutput (same column layout as there)
local input_gate = gates[{{}, {1, hiddensize}}]
local forget_gate = gates[{{}, {hiddensize + 1, 2 * hiddensize}}]
local output_gate = gates[{{}, {2 * hiddensize + 1, 3 * hiddensize}}]
local input_transform = gates[{{}, {3 * hiddensize + 1, 4 * hiddensize}}]
grad_gates:resize(batchsize, 4 * hiddensize)
local grad_input_gate = grad_gates[{{}, {1, hiddensize}}]
local grad_forget_gate = grad_gates[{{}, {hiddensize + 1, 2 * hiddensize}}]
local grad_output_gate = grad_gates[{{}, {2 * hiddensize + 1, 3 * hiddensize}}]
local grad_input_transform = grad_gates[{{}, {3 * hiddensize + 1, 4 * hiddensize}}]
-- we use grad_[input,forget,output]_gate as temporary buffers to compute grad_prev_c.
-- Here grad_prev_c first holds dL/dc[t] = grad_next_c + grad_next_h * o * (1 - tanh(c)^2);
-- it is multiplied by the forget gate only at the very end.
grad_input_gate:tanh(next_c)
grad_forget_gate:cmul(grad_input_gate, grad_input_gate)
grad_output_gate:fill(1):add(-1, grad_forget_gate):cmul(output_gate):cmul(grad_next_h)
grad_prev_c:add(grad_next_c, grad_output_gate)
-- we use above grad_input_gate to compute grad_output_gate
-- output gate pre-activation: grad_next_h * tanh(c) * o * (1 - o)
grad_output_gate:fill(1):add(-1, output_gate):cmul(output_gate):cmul(grad_input_gate):cmul(grad_next_h)
-- Use grad_input_gate as a temporary buffer for computing grad_input_transform
-- input transform pre-activation: dL/dc * i * (1 - g^2)
grad_input_gate:cmul(input_transform, input_transform)
grad_input_transform:fill(1):add(-1, grad_input_gate):cmul(input_gate):cmul(grad_prev_c)
-- We don't need any temporary storage for these so do them last
-- input gate: dL/dc * g * i * (1 - i); forget gate: dL/dc * c[t-1] * f * (1 - f)
grad_input_gate:fill(1):add(-1, input_gate):cmul(input_gate):cmul(input_transform):cmul(grad_prev_c)
grad_forget_gate:fill(1):add(-1, forget_gate):cmul(forget_gate):cmul(prev_c):cmul(grad_prev_c)
-- propagate gate gradients to the inputs and accumulate parameter gradients
grad_cur_x:mm(grad_gates, Wx:t())
grad_Wx:addmm(scale, cur_x:t(), grad_gates)
grad_Wh:addmm(scale, prev_h:t(), grad_gates)
grad_gates_sum:resize(1, 4 * hiddensize):sum(grad_gates, 1)
grad_b:add(scale, grad_gates_sum)
grad_prev_h:mm(grad_gates, Wh:t())
-- finally dL/dc[t-1] = dL/dc[t] * f
grad_prev_c:cmul(forget_gate)
end
return self.gradInput
end
--- Gradient of the loss w.r.t. {input[t], hidden[t-1], cell[t-1]}.
-- The real work happens in backward(), which is skipped when it already ran
-- for the current forward pass.
function StepLSTM:updateGradInput(input, gradOutput)
   local needsBackward = self.recompute_backward
   if not needsBackward then
      return self.gradInput
   end
   self:backward(input, gradOutput, 1.0)
   return self.gradInput
end
--- Accumulate parameter gradients for the current step.
-- Delegates to backward(); this is a no-op when updateGradInput already
-- triggered it since the last forward.
function StepLSTM:accGradParameters(input, gradOutput, scale)
   if not self.recompute_backward then
      return
   end
   self:backward(input, gradOutput, scale)
end
--- Release the memory held by intermediate buffers, outputs and gradInputs.
-- The output/gradInput tables themselves are kept (only their tensors are
-- emptied) because updateOutput/backward index into them directly.
-- @return self, so the call can be chained like other nn.Module:clearState()
function StepLSTM:clearState()
   self.gates:set()
   if self.hidden then
      -- un-projected hidden state, lazily allocated for LSTMP
      self.hidden:set()
   end
   self.output[1]:set(); self.output[2]:set()
   self.gradInput[1]:set(); self.gradInput[2]:set(); self.gradInput[3]:set()
   return self
end
--- Convert the module's tensors to another tensor type.
-- Buffers are emptied first so that no stale intermediate data gets converted.
function StepLSTM:type(tensortype, ...)
   local convert = parent.type
   self:clearState()
   return convert(self, tensortype, ...)
end
--- Learnable parameters and their gradients, in matching order:
-- weight, bias and, for LSTMP only, weightO.
function StepLSTM:parameters()
   local params = {self.weight, self.bias}
   local gradParams = {self.gradWeight, self.gradBias}
   -- both are nil for a plain LSTM, leaving the lists at two entries
   params[3] = self.weightO
   gradParams[3] = self.gradWeightO
   return params, gradParams
end
--- Enable zero-masking of masked rows in outputs, gates and gradOutputs.
-- @param v1 when truthy, use the legacy behaviour: updateOutput derives the
--   mask from the input via nn.utils.getZeroMaskBatch instead of expecting a
--   zeroMask to have been provided (presumably through setZeroMask).
-- @return self
function StepLSTM:maskZero(v1)
   self.v2 = not v1
   self.maskzero = true
   return self
end

-- receive the mask the same way nn.MaskZero does
StepLSTM.setZeroMask = nn.MaskZero.setZeroMask
--- Human-readable summary: "<typename>(in -> out)", or
-- "<typename>(in -> hidden -> out)" when a projection (LSTMP) is present.
function StepLSTM:__tostring__()
   local sizes
   if not self.weightO then
      sizes = string.format("(%d -> %d)", self.inputsize, self.outputsize)
   else
      sizes = string.format("(%d -> %d -> %d)", self.inputsize, self.hiddensize, self.outputsize)
   end
   return self.__typename .. sizes
end
-- for sharedClone: extend the parent's parameter / gradient field lists with
-- the LSTMP projection weight and its gradient
do
   local moses = require 'moses'
   local params = moses.clone(parent.dpnn_parameters)
   params[#params + 1] = 'weightO'
   StepLSTM.dpnn_parameters = params
   local gradParams = moses.clone(parent.dpnn_gradParameters)
   gradParams[#gradParams + 1] = 'gradWeightO'
   StepLSTM.dpnn_gradParameters = gradParams
end