-
Notifications
You must be signed in to change notification settings - Fork 50
/
Copy pathtrain.py
40 lines (35 loc) · 1.4 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
'''
This script trains the neural net on the train and test sets created
by create_data_splits.py, using the auxiliary files created by
prepare_data.py and the manually written prototxt files
(see prototxt_generation_instructions.md).
'''
import gflags
from gflags import FLAGS
from flags import set_gflags
from os.path import join, dirname, abspath, exists
from os import system
# Command-line flags read via FLAGS after set_gflags() is called.

# Required: selects the aux/<dataset> directory holding the prototxt files.
gflags.DEFINE_string(
    name='dataset',
    default=None,
    help='The name of the dataset on which you are training, e.g. nin-clean')
gflags.MarkFlagAsRequired('dataset')

# When set, run `caffe time` on the model instead of `caffe train`.
gflags.DEFINE_boolean(
    name='time',
    default=False,
    help='Set to true if you are interested in dissecting the runtime')

# Optional .solverstate file to resume an interrupted training run from.
gflags.DEFINE_string(
    name='snapshot',
    default=None,
    help='If training got interrupted, resume from this snapshot. '
         'This snapshot is a .solverstate file.')

# Optional .caffemodel whose weights are used as the starting point.
gflags.DEFINE_string(
    name='pretrained_caffemodel',
    default=None,
    help='The path to a .caffemodel you want to finetune')

# Absolute directory containing this script; all other paths hang off it.
ROOT = dirname(abspath(__file__))
def build_caffe_command(root, dataset, time_only=False, snapshot=None,
                        pretrained_caffemodel=None):
    """Return the caffe.bin invocation as an argument list.

    Args:
        root: Directory containing the `caffe` checkout and the `aux` folder.
        dataset: Dataset name; prototxt files are read from root/aux/<dataset>.
        time_only: If true, build a `caffe time` call on train_val.prototxt
            instead of a `caffe train` call on solver.prototxt.
        snapshot: Optional .solverstate path, passed as --snapshot.
        pretrained_caffemodel: Optional .caffemodel path, passed as --weights.

    Returns:
        A list of strings suitable for subprocess (no shell involved), so
        paths containing spaces or shell metacharacters stay intact.
    """
    aux_dir = join(root, 'aux', dataset)
    args = [join(root, 'caffe/.build_release/tools/caffe.bin')]
    if time_only:
        args += ['time', '--model=' + join(aux_dir, 'train_val.prototxt')]
    else:
        args += ['train', '--solver=' + join(aux_dir, 'solver.prototxt')]
    if snapshot is not None:
        args.append('--snapshot=' + snapshot)
    if pretrained_caffemodel is not None:
        args.append('--weights=' + pretrained_caffemodel)
    return args


if __name__ == '__main__':
    import subprocess
    import sys

    set_gflags()
    caffe_args = build_caffe_command(ROOT, FLAGS.dataset, FLAGS.time,
                                     FLAGS.snapshot,
                                     FLAGS.pretrained_caffemodel)
    # Echo the command in the same space-separated form as before for logs.
    print(' '.join(caffe_args))
    # Run caffe directly (no shell) and propagate its exit status, so that a
    # failed training run is visible to callers and CI instead of exiting 0.
    sys.exit(subprocess.call(caffe_args))