forked from rbgirshick/py-faster-rcnn
-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_net.py
executable file
·149 lines (121 loc) · 5.06 KB
/
train_net.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
#!/usr/bin/env python
# --------------------------------------------------------
# Fast R-CNN
# Copyright (c) 2015 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ross Girshick
# --------------------------------------------------------
"""Train a Fast R-CNN network on a region of interest database."""
import argparse
import caffe
import os
import pprint
import sys
import tools._init_paths
import datasets.imdb
import numpy as np
from datasets.factory import get_imdb
from fast_rcnn.config import cfg, cfg_from_file, cfg_from_list, get_output_dir
from fast_rcnn.train import get_training_roidb, train_net
class ARGS:
    """Hard-coded replacement for the parsed command-line arguments.

    An instance exposes the attributes that the ``__main__`` section of this
    script reads (``gpu_id``, ``max_iters``, ``pretrained_model``,
    ``solver``, ``imdb_name``, ``cfg_file``, ``set_cfgs`` and
    ``randomize``). The solver, dataset name and config file are derived
    from a single experiment ``name``.
    """

    def __init__(self, name, gpu_id=None, max_iters=None, pretrained_model=None):
        """Fill in every argument and verify that the referenced files exist.

        Args:
            name: experiment name; substituted into the solver path, the
                imdb name (``'<name>_train'``) and the config file path.
            gpu_id: GPU device id; 0 when None.
            max_iters: number of training iterations; 70000 when None.
            pretrained_model: path of the weights to start from; when None,
                ``data/imagenet_models/VGG16.v2.caffemodel`` under the
                directory of this script is used.

        Raises:
            AssertionError: if the solver, config file or pretrained model
                path does not exist.
        """
        # Directory of this script; default paths are built relative to it.
        self.__root = os.path.dirname(__file__)

        # Arguments the caller may override.
        self.gpu_id = 0 if gpu_id is None else gpu_id
        self.max_iters = 70000 if max_iters is None else max_iters
        if pretrained_model is None:
            pretrained_model = os.path.join(
                self.__root, 'data/imagenet_models/VGG16.v2.caffemodel')
        self.pretrained_model = pretrained_model

        # Arguments derived from the experiment name.
        self.solver = os.path.join(
            self.__root,
            'models/{}/VGG16/faster_rcnn_end2end/solver.prototxt'.format(name))
        self.imdb_name = '{}_train'.format(name)
        self.cfg_file = os.path.join(
            self.__root, 'experiments/cfgs/{}_end2end.yml'.format(name))

        # Fixed arguments. The original code assigned ``randomize = True``
        # and then immediately overwrote it with False; only the effective
        # value is kept.
        self.set_cfgs = None
        self.randomize = False

        # Fail early if any required file is missing.
        self.check_paths()

    def check_paths(self):
        """Verify that the solver, config file and pretrained model exist.

        Raises:
            AssertionError: for the first path (in the order solver, config
                file, pretrained model) that does not exist. The exception
                is raised explicitly rather than through an ``assert``
                statement so the check is not removed under ``python -O``.
        """
        for path in (self.solver, self.cfg_file, self.pretrained_model):
            if not os.path.exists(path):
                raise AssertionError("Path not found: '{}'".format(path))
def parse_args():
    """Build the command-line parser for this script and parse ``sys.argv``.

    When the script is started without any argument, the usage text is
    printed and the process exits with status 1.

    Returns:
        The ``argparse.Namespace`` with the attributes ``gpu_id``,
        ``solver``, ``max_iters``, ``pretrained_model``, ``cfg_file``,
        ``imdb_name``, ``randomize`` and ``set_cfgs``.
    """
    cli = argparse.ArgumentParser(description='Train a Fast R-CNN network')

    # (flag, keyword arguments) pairs; registered in this order so the
    # generated help text lists the options in the same sequence.
    option_specs = [
        ('--gpu', dict(dest='gpu_id', type=int, default=0,
                       help='GPU device id to use [0]')),
        ('--solver', dict(dest='solver', type=str, default=None,
                          help='solver prototxt')),
        ('--iters', dict(dest='max_iters', type=int, default=40000,
                         help='number of iterations to train')),
        ('--weights', dict(dest='pretrained_model', type=str, default=None,
                           help='initialize with pretrained model weights')),
        ('--cfg', dict(dest='cfg_file', type=str, default=None,
                       help='optional config file')),
        ('--imdb', dict(dest='imdb_name', type=str,
                        default='voc_2007_trainval',
                        help='dataset to train on')),
        ('--rand', dict(dest='randomize', action='store_true',
                        help='randomize (do not use a fixed seed)')),
        ('--set', dict(dest='set_cfgs', nargs=argparse.REMAINDER,
                       default=None, help='set config keys')),
    ]
    for flag, options in option_specs:
        cli.add_argument(flag, **options)

    # No arguments at all: show the usage text and exit with status 1.
    if len(sys.argv) == 1:
        cli.print_help()
        sys.exit(1)

    return cli.parse_args()
def combined_roidb(imdb_names):
    """Load the training roidb for one or several datasets.

    Args:
        imdb_names: a dataset name, or several names joined with ``'+'``.

    Returns:
        A tuple ``(imdb, roidb)``. ``roidb`` is the roidb of the first
        dataset, extended in place with the roidbs of the remaining ones.
        ``imdb`` is ``get_imdb(imdb_names)`` for a single dataset, and a
        ``datasets.imdb.imdb`` constructed from the full ``imdb_names``
        string when several datasets were requested.
    """
    def _load_training_roidb(dataset_name):
        # Look up the dataset, select the configured proposal method and
        # build its training roidb.
        dataset = get_imdb(dataset_name)
        print('Loaded dataset `{:s}` for training'.format(dataset.name))
        dataset.set_proposal_method(cfg.TRAIN.PROPOSAL_METHOD)
        print('Set proposal method: {:s}'.format(cfg.TRAIN.PROPOSAL_METHOD))
        return get_training_roidb(dataset)

    per_dataset = [_load_training_roidb(dataset_name)
                   for dataset_name in imdb_names.split('+')]

    # Append every additional roidb to the first one (no-op for one dataset).
    merged = per_dataset[0]
    for extra in per_dataset[1:]:
        merged.extend(extra)

    if len(per_dataset) > 1:
        imdb = datasets.imdb.imdb(imdb_names)
    else:
        imdb = get_imdb(imdb_names)
    return imdb, merged
if __name__ == '__main__':
    # Arguments are hard-coded through ARGS here; parse_args() is not called.
    # Passing pretrained_model=None instead makes ARGS fall back to its
    # default VGG16 ImageNet weights.
    # NOTE(review): a .solverstate file is handed over as pretrained_model --
    # confirm that train_net accepts a solver state here.
    resume_from = '/home/ylxie/Space/work/py-faster-rcnn2/output/pls45/train/pls45_iter_10000.solverstate'
    args = ARGS('pls45', gpu_id=0, pretrained_model=resume_from)

    # Apply configuration overrides: first the file, then explicit keys.
    if args.cfg_file is not None:
        cfg_from_file(args.cfg_file)
    if args.set_cfgs is not None:
        cfg_from_list(args.set_cfgs)

    print(cfg.TRAIN.BG_THRESH_LO)
    cfg.GPU_ID = args.gpu_id
    # Uncomment for frequent snapshots while debugging:
    # cfg.TRAIN.SNAPSHOT_ITERS = 20

    print('Using config:')
    pprint.pprint(cfg)

    if not args.randomize:
        print("set random")
        # Fix the random seeds (numpy and caffe) for reproducibility.
        np.random.seed(cfg.RNG_SEED)
        caffe.set_random_seed(cfg.RNG_SEED)

    # Set up caffe on the requested GPU.
    caffe.set_mode_gpu()
    caffe.set_device(args.gpu_id)

    imdb, roidb = combined_roidb(args.imdb_name)
    print('{:d} roidb entries'.format(len(roidb)))

    output_dir = get_output_dir(imdb)
    print('Output will be saved to `{:s}`'.format(output_dir))

    train_net(args.solver, roidb, output_dir,
              pretrained_model=args.pretrained_model,
              max_iters=args.max_iters)