forked from jpuigcerver/Laia
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlaia-train-ctc
executable file
·408 lines (373 loc) · 15 KB
/
laia-train-ctc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
#!/usr/bin/env th
require 'laia'
require 'optim'
-- Build every training component up front, so that each one can register its
-- own command-line options on the parser created at the end of this section.
local ctc_trainer = laia.CTCTrainer()
-- Two separate batchers: one is loaded with the training partition and the
-- other (optionally) with the validation partition.
local train_batcher = laia.RandomBatcher()
local valid_batcher = laia.RandomBatcher()
-- The distorter is only instantiated when laia.ImageDistorter is defined;
-- otherwise this local keeps laia.ImageDistorter's falsy value and the later
-- `if distorter then ...` guard skips it.
local distorter = laia.ImageDistorter and laia.ImageDistorter()
local weight_decay_regularizer = laia.WeightDecayRegularizer()
local adversarial_regularizer = laia.AdversarialRegularizer()
local epoch_summarizer = laia.EpochSummarizer()
local progress_table = laia.ProgressTable()
-- Command-line parser for this tool (the description is left empty).
local parser = laia.argparse(){
name = 'laia-train-ctc',
description = ''
}
-- Let every component add its own options to the parser. The second argument
-- passed to some registerOptions calls is forwarded as-is.
-- NOTE(review): the meaning of that `true` flag is defined by each component's
-- registerOptions implementation, which is not visible here -- confirm.
-- Register laia.Version options
laia.Version():registerOptions(parser)
-- Register laia.log options
laia.log.registerOptions(parser)
-- Register cudnn options, only if available
if cudnn then cudnn.registerOptions(parser, true) end
-- Register batcher options (only for train_batcher, valid_batcher will use the
-- same options).
train_batcher:registerOptions(parser)
-- CTC training options.
ctc_trainer:registerOptions(parser)
-- Register regularizers options.
weight_decay_regularizer:registerOptions(parser, true)
adversarial_regularizer:registerOptions(parser, true)
-- Register distorter options (only when a distorter was instantiated).
if distorter then distorter:registerOptions(parser, true) end
-- Epoch summarizer options
epoch_summarizer:registerOptions(parser, true)
-- Progress Table options
progress_table:registerOptions(parser, true)
-- Memory monitor options
laia.mem.registerOptions(parser, true)
-- train_ctc arguments
-- Positional arguments. The two validation arguments are declared with
-- :args('?'); the script later treats a missing valid_img/valid_txt as
-- "no validation partition".
parser:argument('checkpoint', 'Input model or checkpoint for training.')
parser:argument('symbols_table', 'Table mapping from symbols to integer IDs.')
parser:argument('train_img', 'List of training image files.')
parser:argument('train_txt', 'Table of training image transcripts.')
parser:argument('valid_img', 'List of validation image files.'):args('?')
parser:argument('valid_txt', 'Table of validation image transcripts.'):args('?')
-- Custom options
-- Each option is declared as (names, help text, default value, converter),
-- optionally followed by range constraints on the accepted values.

-- Reproducibility and device selection.
parser:option(
  '--seed -s', 'Seed for random numbers generation.',
  0x012345, laia.toint)
parser:option(
  '--gpu', 'If gpu>0, uses the specified GPU, otherwise uses the CPU.',
  1, laia.toint)
parser:option(
  '--continue_train', 'If true, continue training from the last state.',
  false, laia.toboolean)

-- Learning rate and its decay schedule.
parser:option(
  '--learning_rate', 'Initial learning rate.', 0.001, tonumber)
  :gt(0.0)
parser:option(
  '--learning_rate_decay',
  'Learning rate decay after each epoch (1.0 means no decay).', 1.0, tonumber)
  :ge(0.0):le(1.0)
-- The lower bound is inclusive: the default value (0.0, i.e. no floor on the
-- decayed learning rate) has to be an accepted value itself, so that it can
-- also be passed explicitly on the command line.
parser:option(
  '--learning_rate_decay_min',
  'Apply learning rate decay until it reaches this minimum value.', 0.0, tonumber)
  :ge(0.0)
parser:option(
  '--learning_rate_decay_after',
  'Start learning rate decay after this number of epochs.', 10, laia.toint)
  :ge(1)
parser:option(
  '--learning_rate_decay_period',
  'Apply learning rate decay every this number of epochs.', 1, laia.toint)
  :ge(1)
-- Optimizer smoothing factor, restricted to the open interval (0, 1).
parser:option(
  '--rmsprop_alpha', 'RMSProp smoothing parameter.', 0.95, tonumber)
  :argname('<alpha>'):gt(0.0):lt(1.0)

-- Model selection: the accepted values are the keys of the table passed as
-- the last argument.
parser:option(
  '--best_criterion', 'If not empty, use this criterion to choose the ' ..
    'best model (e.g. for early stopping).', 'valid_cer',
  {
    train_loss = 'train_loss', train_cer = 'train_cer',
    valid_loss = 'valid_loss', valid_cer = 'valid_cer'
  })
  :argname('<criterion>')

-- Stopping conditions.
parser:option(
  '--max_epochs', 'If n>0, training will continue for, at most, n epochs.',
  0, laia.toint)
  :argname('<n>')
parser:option(
  '--early_stop_epochs', 'If n>0, stop training after this number of ' ..
    'epochs without a significant improvement, according to ' ..
    '--best_criterion. If n=0, early stopping will not be used.',
  0, laia.toint)
  :argname('<n>')
parser:option(
  '--early_stop_threshold', 'Minimum relative improvement threshold used ' ..
    'for early stop. Relate improvement lower than this are not considered ' ..
    'relevant; e.g.: 0.05 will consider significant improvements =>5%.',
  0.0, tonumber)
  :argname('<t>'):ge(0.0)

-- Output files.
parser:option(
  '--checkpoint_save_interval', 'Save a checkpoint to disk on every n ' ..
    'epochs. Note: regardless of this, every time a better model is found, ' ..
    'a checkpoint is saved.', 50, laia.toint)
  :argname('<n>'):gt(0)
parser:option(
  '--checkpoint_output', 'Save checkpoints to this file. If not given, ' ..
    'the input checkpoint will be overwritten.', '')
  :argname('<file>')
parser:option(
  '--progress_table_output', 'Save the progress of training after each ' ..
    'epoch to this text file. Useful for plotting and monitoring.', '')
  :argname('<file>')

-- Batching.
parser:option(
  '--auto_width_factor', 'If true, sets the width factor for the batchers ' ..
    'automatically, from the size of the pooling layers.',
  false, laia.toboolean)
  :argname('<bool>')
-- Parse from command line
-- `opts` holds the parsed options and positional arguments and is read by the
-- rest of the script.
local opts = parser:parse()
-- Start memory monitor after parsing the options
laia.mem.startMonitor()
-- Initialize random seeds
laia.manualSeed(opts.seed)
-- If validation specified, both images and transcripts are required.
-- When only the image list is given, the validation partition is dropped
-- (valid_img is cleared) and a warning is logged.
if opts.valid_img and not opts.valid_txt then
laia.log.warn('Ignoring validation partition: a list of images was given, ' ..
'but not a list of transcripts.')
opts.valid_img = nil
end
-- Load the input checkpoint and take the model stored in its "last" slot;
-- abort if the checkpoint carries no model.
local checkpoint = laia.Checkpoint():load(opts.checkpoint)
local model = checkpoint:Last():getModel()
assert(model ~= nil, 'No model was found in the checkpoint file!')
-- Place the model on the device selected with --gpu.
if opts.gpu <= 0 then
  -- CPU mode. This should not be necessary, but just in case
  model = model:float()
else
  -- GPU mode: check that we have everything necessary before moving the model.
  assert(cutorch ~= nil, 'Package cutorch is required in order to use the GPU.')
  assert(nn ~= nil, 'Package nn is required in order to use the GPU.')
  cutorch.setDevice(opts.gpu)
  model = model:cuda()
  -- If cudnn_force_convert=true, force all possible layers to use cuDNN impl.
  if cudnn and cudnn.force_convert then
    laia.log.warn('Some layers in cuDNN are non-deterministic on the ' ..
                  'backward pass. If 100% reproducible experiments are ' ..
                  'required, use --cudnn_force_convert=false.')
    cudnn.convert(model, cudnn)
  end
end

-- Restore the Laia RNGState stored in the checkpoint. Notice that if a
-- RNGState is available in the checkpoint, it overrides the manual seed
-- given with the --seed option.
laia.setRNGState(checkpoint:getRNGState())
-- Prepare batchers
if opts.auto_width_factor then
  -- Derive the width factor from the model instead of the command line.
  local auto_factor = laia.getWidthFactor(model)
  train_batcher:setOptions({width_factor = auto_factor})
  laia.log.info('Batcher width factor was automatically set to %d',
                auto_factor)
end
train_batcher:load(opts.train_img, opts.train_txt, opts.symbols_table)

-- The validation batcher is only kept when both the validation images and
-- the validation transcripts were given; it reuses the training batcher's
-- options.
if not opts.valid_img or not opts.valid_txt then
  valid_batcher = nil
else
  valid_batcher:setOptions(train_batcher:getOptions())
  valid_batcher:load(opts.valid_img, opts.valid_txt, opts.symbols_table)
end

-- Prepare CTC trainer: wire every component together and start it.
ctc_trainer
  :setModel(model)
  :setTrainBatcher(train_batcher)
  :setValidBatcher(valid_batcher)
  :setDistorter(distorter)
  :setWeightRegularizer(weight_decay_regularizer)
  :setAdversarialRegularizer(adversarial_regularizer)
  :setOptimizer(optim.rmsprop)
  :start()
-- Initial training state: epoch counter and optimizer options.
local epoch = 0
local rmsprop_opts = {
  alpha = opts.rmsprop_alpha,
  learningRate = opts.learning_rate
}

-- Remember the configuration stored in the input checkpoint *before* it is
-- replaced with the current options below. Otherwise the --best_criterion
-- consistency check at the end of this section would compare the current
-- options against the configuration that was just written, instead of the
-- one used to create the checkpoint. Only the criterion is kept, so that the
-- check does not depend on whether setTrainConfig stores a copy or a
-- reference.
local previous_train_config = checkpoint:getTrainConfig()
local previous_best_criterion = nil
if previous_train_config ~= nil then
  previous_best_criterion = previous_train_config.best_criterion
end

if opts.continue_train then
  -- Continue training from the last epoch, reusing the optimizer state saved
  -- in the checkpoint, if any.
  epoch = checkpoint:Last():getEpoch()
  if checkpoint:getRMSPropState() then
    rmsprop_opts = checkpoint:getRMSPropState()
  end
  checkpoint:setTrainConfig(opts)
else
  -- Forget about the previous "best" model and results, if any.
  checkpoint:Best():setModel(nil)
  checkpoint:Best():setEpoch(0)
  checkpoint:Best():addSummary('train', nil)
  checkpoint:Best():addSummary('valid', nil)
  checkpoint:setTrainConfig(opts)
end

-- If no validation data was passed but --best_criterion is supposed to use
-- validation data, change the --best_criterion to use training data instead
-- and report an ERROR to the user.
if (not opts.valid_img or not opts.valid_txt) then
  local m = opts.best_criterion:match('^valid_(.+)$')
  if m ~= nil then
    local new_criterion = 'train_' .. m
    laia.log.error('You are trying to use --best_criterion=%s but no ' ..
                   'validation data was provided. Criterion changed to: %q.',
                   opts.best_criterion, new_criterion)
    opts.best_criterion = new_criterion
  end
end

-- Warn when the criterion in use differs from the one recorded in the input
-- checkpoint (only when the checkpoint carried a training configuration).
if previous_train_config ~= nil and
   opts.best_criterion ~= previous_best_criterion then
  laia.log.warn('Current --best_criterion does not match the one used to ' ..
                'create the input checkpoint.')
end
-- get_criterion_value(train_summary, valid_summary) returns the value used to
-- choose the "best" model, according to --best_criterion. It returns nil when
-- the summary it needs is missing.
local get_criterion_value
do
  -- Builds a getter that reads `field` from the validation summary (when
  -- `from_valid` is true) or from the training summary (otherwise).
  local function summary_field_getter(from_valid, field)
    return function(train_summary, valid_summary)
      local summary
      if from_valid then
        summary = valid_summary
      else
        summary = train_summary
      end
      if summary then
        return summary[field]
      end
      return nil
    end
  end

  local getters = {
    train_cer = summary_field_getter(false, 'cer'),
    train_loss = summary_field_getter(false, 'loss'),
    valid_cer = summary_field_getter(true, 'cer'),
    valid_loss = summary_field_getter(true, 'loss')
  }
  get_criterion_value = getters[opts.best_criterion]
end

-- Current best criterion value, or nil
local best_criterion_value = get_criterion_value(
  checkpoint:Best():getSummary('train'),
  checkpoint:Best():getSummary('valid'))
-- Early stopping bookkeeping: epoch and criterion value of the last
-- *significant* improvement seen so far.
local early_stop_last_significant_epoch = epoch
local early_stop_last_significant_value = best_criterion_value

-- Open progress table
if opts.progress_table_output ~= '' then
  -- The second argument is true only when training resumes from a previous
  -- epoch and the output file can already be opened for reading.
  -- NOTE(review): presumably this tells the table to continue the existing
  -- file instead of starting a new one -- confirm in laia.ProgressTable.
  local resume_existing_table = false
  if epoch ~= 0 then
    local probe = io.open(opts.progress_table_output, 'r')
    if probe ~= nil then
      -- Close the probe handle right away: it is only used to check that the
      -- file exists, and leaving it open would leak a file descriptor.
      probe:close()
      resume_existing_table = true
    end
  end
  progress_table:open(opts.progress_table_output, resume_existing_table)
end
-- Main training loop. Each iteration runs one training epoch and, when a
-- validation batcher is available, one validation epoch. The loop ends when
-- --max_epochs is reached (if n>0), when an exit is requested through
-- laia.SignalHandler, or when early stopping triggers.
while opts.max_epochs <= 0 or epoch < opts.max_epochs do
  if laia.SignalHandler.ExitRequested() then break end
  -- Epoch starts at 0, when the model is created
  epoch = epoch + 1

  -- Apply learning rate decay: only after --learning_rate_decay_after epochs,
  -- every --learning_rate_decay_period epochs, and never below
  -- --learning_rate_decay_min.
  if opts.learning_rate_decay < 1 and
     opts.learning_rate_decay_min < rmsprop_opts.learningRate and
     epoch > opts.learning_rate_decay_after and
     epoch % opts.learning_rate_decay_period == 0 then
    rmsprop_opts.learningRate = math.max(
      rmsprop_opts.learningRate * opts.learning_rate_decay,
      opts.learning_rate_decay_min)
    laia.log.info('Learning rate decay applied. New learning rate = %g',
                  rmsprop_opts.learningRate)
  end

  -- Show some memory usage messages before the training step
  --[[
  laia.log.info(
    'Epoch %d, CPU memory usage before train: current %.1fMB, max %.1fMB',
    epoch, laia.mem.getCurrentCPUMemory(), laia.mem.getMaxCPUMemory())
  laia.log.info(
    'Epoch %d, GPU memory usage before train: current %.1fMB, max %.1fMB',
    epoch, laia.mem.getCurrentGPUMemory(), laia.mem.getMaxGPUMemory())
  --]]

  -- Train
  local train_epoch_info = ctc_trainer:trainEpoch(rmsprop_opts)
  if laia.SignalHandler.ExitRequested() then break end
  local train_summary = epoch_summarizer:summarize(train_epoch_info)
  laia.log.info('Epoch %d, train summary: %s',
                epoch, laia.EpochSummarizer.ToString(train_summary))

  -- Show some memory usage messages before the validation step
  --[[
  laia.log.info(
    'Epoch %d, CPU memory usage before valid: current %.1fMB, max %.1fMB',
    epoch, laia.mem.getCurrentCPUMemory(), laia.mem.getMaxCPUMemory())
  laia.log.info(
    'Epoch %d, GPU memory usage before valid: current %.1fMB, max %.1fMB',
    epoch, laia.mem.getCurrentGPUMemory(), laia.mem.getMaxGPUMemory())
  --]]

  -- Valid (if possible)
  local valid_epoch_info, valid_summary = nil, nil
  if valid_batcher then
    valid_epoch_info = ctc_trainer:validEpoch()
    if laia.SignalHandler.ExitRequested() then break end
    valid_summary = epoch_summarizer:summarize(valid_epoch_info)
    laia.log.info('Epoch %d, valid summary: %s',
                  epoch, laia.EpochSummarizer.ToString(valid_summary))
  end

  -- Determine whether or not the new model is better than the previous ones
  -- (lower criterion values are better).
  local current_criterion_value = get_criterion_value(
    train_summary, valid_summary)
  if best_criterion_value == nil or
     current_criterion_value < best_criterion_value then
    laia.log.info('Epoch %d, new better model according to criterion %q: ' ..
                  '%f vs. %f (on epoch %d).',
                  epoch, opts.best_criterion,
                  current_criterion_value,
                  best_criterion_value or math.huge,
                  checkpoint:Best():getEpoch())
    best_criterion_value = current_criterion_value
    checkpoint:Best():setEpoch(epoch)
    checkpoint:Best():setModel(model)
    checkpoint:Best():addSummary('train', train_summary)
    checkpoint:Best():addSummary('valid', valid_summary)
  end

  -- Save checkpoint: every --checkpoint_save_interval or when a new better
  -- model is found.
  if epoch % opts.checkpoint_save_interval == 0 or
     checkpoint:Best():getEpoch() == epoch then
    checkpoint:Last():setEpoch(epoch)
    checkpoint:Last():setModel(model)
    checkpoint:Last():addSummary('train', train_summary)
    checkpoint:Last():addSummary('valid', valid_summary)
    checkpoint:setRMSPropState(rmsprop_opts)
    checkpoint:setRNGState(laia.getRNGState())
    -- Write to --checkpoint_output when it was given; otherwise overwrite the
    -- input checkpoint. The option is registered as '--checkpoint_output', so
    -- that is the key that must be read from opts.
    local checkpoint_filename = opts.checkpoint
    if opts.checkpoint_output ~= nil and opts.checkpoint_output ~= '' then
      checkpoint_filename = opts.checkpoint_output
    end
    laia.log.info('Epoch %d, saving checkpoint to %q.',
                  epoch, checkpoint_filename)
    checkpoint:save(checkpoint_filename)
  end

  -- Write progress table row
  if opts.progress_table_output ~= '' then
    progress_table:write(epoch, train_summary, valid_summary,
                         checkpoint:Best():getEpoch() == epoch)
  end

  -- Early stopping strategy: keep track of the last *significant* result,
  -- i.e. one whose relative improvement over the previous significant value
  -- is strictly greater than --early_stop_threshold.
  if early_stop_last_significant_value == nil or
     ((early_stop_last_significant_value - best_criterion_value) /
      early_stop_last_significant_value) > opts.early_stop_threshold then
    -- The threshold is a fraction (e.g. 0.05); the message prints it followed
    -- by a '%' sign, so it is scaled to a percentage here.
    laia.log.info(
      'Epoch %d, new significantly better model according to %q criterion ' ..
      'and relative threshold %.4f%%: %f vs. %f (on epoch %d).',
      epoch, opts.best_criterion, 100 * opts.early_stop_threshold,
      best_criterion_value,
      early_stop_last_significant_value or math.huge,
      early_stop_last_significant_epoch)
    early_stop_last_significant_epoch = epoch
    early_stop_last_significant_value = best_criterion_value
  end

  -- If the last significant result was achieved too long ago, stop.
  if opts.early_stop_epochs > 0 and
     epoch - early_stop_last_significant_epoch >= opts.early_stop_epochs then
    laia.log.info('Epoch %d, last epoch with a significant improvement on ' ..
                  '%q criterion was %d. Triggering early stop!',
                  epoch, opts.best_criterion, early_stop_last_significant_epoch)
    break
  end

  -- Garbage collection from time to time.
  if epoch % 50 == 0 then collectgarbage() end
end
-- Final report. The partition name ('train' or 'valid') is the prefix of the
-- criterion name, and selects which summary of the best model is printed.
local p = opts.best_criterion:match('^([^_]+)_.*$')
local best_epoch = checkpoint:Best():getEpoch()
local best_summary = nil
if best_epoch > 0 then
  -- The summary is only looked up when a best epoch was actually recorded.
  best_summary = checkpoint:Best():getSummary(p)
end

if best_summary ~= nil then
  -- If we actually did something
  laia.log.info('Finished training after %d epochs. According to %q ' ..
                'criterion, epoch %d was the best: %s', epoch,
                opts.best_criterion, best_epoch,
                laia.EpochSummarizer.ToString(best_summary))
else
  laia.log.error('Training stopped on epoch %d, but the model was not ' ..
                 'updated.', epoch)
end