-
Notifications
You must be signed in to change notification settings - Fork 24
/
Copy patheval.py
188 lines (150 loc) · 7.05 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Evaluation executable for detection models.
This executable is used to evaluate DetectionModels. There are two ways of
configuring the eval job.
The --run_mode flag selects which checkpoints in --checkpoint_dir are evaluated
(see the flag's help text for the available modes).
1) A single pipeline_pb2.TrainEvalPipelineConfig file may be specified instead.
In this mode, the --eval_training_data flag may be given to force the pipeline
to evaluate on training data instead.
Example usage:
./eval \
--logtostderr \
--checkpoint_dir=path/to/checkpoint_dir \
--eval_dir=path/to/eval_dir \
--pipeline_config_path=pipeline_config.pbtxt
2) Three configuration files may be provided: a model_pb2.DetectionModel
configuration file to define what type of DetectionModel is being evaluated, an
input_reader_pb2.InputReader file to specify what data the model is evaluating
and an eval_pb2.EvalConfig file to configure evaluation parameters.
Example usage:
./eval \
--logtostderr \
--checkpoint_dir=path/to/checkpoint_dir \
--eval_dir=path/to/eval_dir \
--eval_config_path=eval_config.pbtxt \
--model_config_path=model_config.pbtxt \
--input_config_path=eval_input_config.pbtxt
"""
import os
import sys
sys.path.append('./slim')
import shutil
import functools
import tensorflow as tf
from google.protobuf import text_format
import evaluator
from builders import input_reader_builder
from builders import model_builder
from protos import eval_pb2
from protos import input_reader_pb2
from protos import model_pb2
from protos import pipeline_pb2
from utils import label_map_util
tf.logging.set_verbosity(tf.logging.INFO)

# Command-line flags for the evaluation job.
flags = tf.app.flags
flags.DEFINE_boolean('eval_training_data', False,
                     'If training data should be evaluated for this job.')
flags.DEFINE_string('checkpoint_dir', '',
                    'Directory containing checkpoints to evaluate, typically '
                    'set to `train_dir` used in the training job.')
flags.DEFINE_string('eval_dir', '',
                    'Directory to write eval summaries to.')
flags.DEFINE_string('pipeline_config_path', '',
                    'Path to a pipeline_pb2.TrainEvalPipelineConfig config '
                    'file. If provided, other configs are ignored.')
flags.DEFINE_string('eval_config_path', '',
                    'Path to an eval_pb2.EvalConfig config file.')
flags.DEFINE_string('input_config_path', '',
                    'Path to an input_reader_pb2.InputReader config file.')
flags.DEFINE_string('model_config_path', '',
                    'Path to a model_pb2.DetectionModel config file.')
# Adjacent string literals are concatenated, so each fragment ends with a
# space to keep the sentences separated in --help output.
flags.DEFINE_string('run_mode', 'all',
                    'When run_mode is latest, it runs indefinitely and the '
                    'latest checkpoint is evaluated. '
                    'When run_mode is latest_once, the latest checkpoint is '
                    'evaluated and evaluation finishes. '
                    'When run_mode is all, all checkpoints are evaluated and '
                    'evaluation finishes.')
flags.DEFINE_boolean('save_detection_results', False,
                     'Whether or not to save detection results.')
flags.DEFINE_string('detection_results_name', '',
                    'Filename to a detection_results pickle file.')
# main() removes FLAGS.eval_dir (not the training dir) when this is set.
flags.DEFINE_boolean('clean_dir', False,
                     'Whether to clean (delete) the eval_dir before '
                     'evaluation.')
FLAGS = flags.FLAGS
def get_configs_from_pipeline_file():
  """Reads evaluation configuration from a pipeline_pb2.TrainEvalPipelineConfig.

  Reads the pipeline config from the file given by --pipeline_config_path.
  The eval config is always the pipeline's `eval_config`; --eval_training_data
  only changes which input reader (i.e. which data) is evaluated.

  Returns:
    model_config: a model_pb2.DetectionModel
    eval_config: a eval_pb2.EvalConfig
    input_config: a input_reader_pb2.InputReader
  """
  pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
  with tf.gfile.GFile(FLAGS.pipeline_config_path, 'r') as f:
    text_format.Merge(f.read(), pipeline_config)

  model_config = pipeline_config.model
  # Previously `train_config` was returned here when --eval_training_data was
  # set. That field is presumably a TrainConfig message (see pipeline.proto),
  # not the EvalConfig the evaluator expects, and the data was still read from
  # eval_input_reader. Keep the eval config and switch the data source instead.
  eval_config = pipeline_config.eval_config
  if FLAGS.eval_training_data:
    input_config = pipeline_config.train_input_reader
  else:
    input_config = pipeline_config.eval_input_reader
  return model_config, eval_config, input_config
def get_configs_from_multiple_files():
  """Builds the evaluation configuration from three separate text protos.

  The messages are parsed, in this order, from:
    --eval_config_path   -> eval_pb2.EvalConfig
    --model_config_path  -> model_pb2.DetectionModel
    --input_config_path  -> input_reader_pb2.InputReader

  Returns:
    A (model_config, eval_config, input_config) tuple.
  """
  def _parse_into(path, message):
    # Fill `message` from the text-format proto stored at `path`.
    with tf.gfile.GFile(path, 'r') as proto_file:
      text_format.Merge(proto_file.read(), message)
    return message

  eval_cfg = _parse_into(FLAGS.eval_config_path, eval_pb2.EvalConfig())
  model_cfg = _parse_into(FLAGS.model_config_path,
                          model_pb2.DetectionModel())
  input_cfg = _parse_into(FLAGS.input_config_path,
                          input_reader_pb2.InputReader())
  return model_cfg, eval_cfg, input_cfg
def main(unused_argv):
  """Builds the model and input pipeline from flags and runs evaluation.

  Args:
    unused_argv: arguments left over after flag parsing; ignored.

  Raises:
    ValueError: if --checkpoint_dir or --eval_dir is not set, or if the label
      map referenced by the input config contains no items.
  """
  # Explicit checks rather than `assert`, which is stripped under `python -O`.
  if not FLAGS.checkpoint_dir:
    raise ValueError('`checkpoint_dir` is missing.')
  if not FLAGS.eval_dir:
    raise ValueError('`eval_dir` is missing.')

  if FLAGS.clean_dir:
    # Remove the previous eval output; errors (e.g. a missing dir) are ignored.
    shutil.rmtree(FLAGS.eval_dir, ignore_errors=True)

  if FLAGS.pipeline_config_path:
    model_config, eval_config, input_config = get_configs_from_pipeline_file()
  else:
    model_config, eval_config, input_config = get_configs_from_multiple_files()

  # Default the results pickle into eval_dir when saving is requested but no
  # filename was given.
  if FLAGS.save_detection_results and not FLAGS.detection_results_name:
    FLAGS.detection_results_name = os.path.join(
        FLAGS.eval_dir, 'detection_results.pkl')

  model_fn = functools.partial(
      model_builder.build,
      model_config=model_config,
      is_training=False)

  create_input_dict_fn = functools.partial(
      input_reader_builder.build,
      input_config)

  label_map = label_map_util.load_labelmap(input_config.label_map_path)
  label_ids = [item.id for item in label_map.item]
  if not label_ids:
    # max() on an empty list would fail with an opaque message.
    raise ValueError('Label map %s contains no items.' %
                     input_config.label_map_path)
  max_num_classes = max(label_ids)
  categories = label_map_util.convert_label_map_to_categories(
      label_map, max_num_classes)

  evaluator.evaluate(create_input_dict_fn, model_fn, eval_config, input_config,
                     categories, FLAGS.checkpoint_dir, FLAGS.eval_dir,
                     FLAGS.run_mode, FLAGS.save_detection_results,
                     FLAGS.detection_results_name)
if __name__ == '__main__':
  # Pass the entry point explicitly; tf.app.run parses flags and calls it.
  tf.app.run(main=main)