diff --git a/efficientdet/README.md b/efficientdet/README.md index 5f313931e..e81823e6f 100644 --- a/efficientdet/README.md +++ b/efficientdet/README.md @@ -238,7 +238,7 @@ You can run inference for a video and show the results online: // Run eval. !python main.py --mode=eval \ --model_name=${MODEL} --model_dir=${CKPT_PATH} \ - --validation_file_pattern=tfrecord/val* \ + --val_file_pattern=tfrecord/val* \ --val_json_file=annotations/instances_val2017.json You can also run eval on test-dev set with the following command: @@ -259,7 +259,7 @@ You can also run eval on test-dev set with the following command: # Also, test-dev has 20288 images rather than val 5000 images. !python main.py --mode=eval \ --model_name=${MODEL} --model_dir=${CKPT_PATH} \ - --validation_file_pattern=tfrecord/testdev* \ + --val_file_pattern=tfrecord/testdev* \ --testdev_dir='testdev_output' --eval_samples=20288 # Now you can submit testdev_output/detections_test-dev2017_test_results.json to # coco server: https://competitions.codalab.org/competitions/20794#participate @@ -288,8 +288,8 @@ Create a config file for the PASCAL VOC dataset called voc_config.yaml and put t Finetune needs to use --ckpt rather than --backbone_ckpt. !python main.py --mode=train_and_eval \ - --training_file_pattern=tfrecord/pascal*.tfrecord \ - --validation_file_pattern=tfrecord/pascal*.tfrecord \ + --train_file_pattern=tfrecord/pascal*.tfrecord \ + --val_file_pattern=tfrecord/pascal*.tfrecord \ --model_name=efficientdet-d0 \ --model_dir=/tmp/efficientdet-d0-finetune \ --ckpt=efficientdet-d0 \ @@ -326,8 +326,8 @@ Download efficientdet coco checkpoint. Finetune needs to use --ckpt rather than --backbone_ckpt. python main.py --mode=train \ - --training_file_pattern=tfrecord/pascal*.tfrecord \ - --validation_file_pattern=tfrecord/pascal*.tfrecord \ + --train_file_pattern=tfrecord/pascal*.tfrecord \ + --val_file_pattern=tfrecord/pascal*.tfrecord \ --model_name=efficientdet-d0 \ --model_dir=/tmp/efficientdet-d0-finetune \ --ckpt=efficientdet-d0 \ @@ -358,7 +358,7 @@ To train this model on Cloud TPU, you will need: Then train the model: !export PYTHONPATH="$PYTHONPATH:/path/to/models" - !python main.py --tpu=TPU_NAME --training_file_pattern=DATA_DIR/*.tfrecord --model_dir=MODEL_DIR --strategy=tpu + !python main.py --tpu=TPU_NAME --train_file_pattern=DATA_DIR/*.tfrecord --model_dir=MODEL_DIR --strategy=tpu # TPU_NAME is the name of the TPU node, the same name that appears when you run gcloud compute tpus list, or ctpu ls. # MODEL_DIR is a GCS location (a URL starting with gs:// where both the GCE VM and the associated Cloud TPU have write access. diff --git a/efficientdet/dataloader_test.py b/efficientdet/dataloader_test.py index 8f5f6c51b..066ff9d0f 100644 --- a/efficientdet/dataloader_test.py +++ b/efficientdet/dataloader_test.py @@ -13,63 +13,19 @@ # limitations under the License. # ============================================================================== """Data loader and processing test cases.""" -import os -import tempfile + import tensorflow as tf import dataloader import hparams_config -from dataset import tfrecord_util +from brain_automl.efficientdet import test_util + from keras import anchors from object_detection import tf_example_decoder class DataloaderTest(tf.test.TestCase): - def _make_fake_tfrecord(self): - tfrecord_path = os.path.join(tempfile.mkdtemp(), 'test.tfrecords') - writer = tf.io.TFRecordWriter(tfrecord_path) - encoded_jpg = tf.io.encode_jpeg(tf.ones([512, 512, 3], dtype=tf.uint8)) - example = tf.train.Example( - features=tf.train.Features( - feature={ - 'image/height': - tfrecord_util.int64_feature(512), - 'image/width': - tfrecord_util.int64_feature(512), - 'image/filename': - tfrecord_util.bytes_feature('test_file_name.jpg'.encode( - 'utf8')), - 'image/source_id': - tfrecord_util.bytes_feature('123456'.encode('utf8')), - 'image/key/sha256': - tfrecord_util.bytes_feature('qwdqwfw12345'.encode('utf8')), - 'image/encoded': - tfrecord_util.bytes_feature(encoded_jpg.numpy()), - 'image/format': - tfrecord_util.bytes_feature('jpeg'.encode('utf8')), - 'image/object/bbox/xmin': - tfrecord_util.float_list_feature([0.1]), - 'image/object/bbox/xmax': - tfrecord_util.float_list_feature([0.1]), - 'image/object/bbox/ymin': - tfrecord_util.float_list_feature([0.2]), - 'image/object/bbox/ymax': - tfrecord_util.float_list_feature([0.2]), - 'image/object/class/text': - tfrecord_util.bytes_list_feature(['test'.encode('utf8')]), - 'image/object/class/label': - tfrecord_util.int64_list_feature([1]), - 'image/object/difficult': - tfrecord_util.int64_list_feature([]), - 'image/object/truncated': - tfrecord_util.int64_list_feature([]), - 'image/object/view': - tfrecord_util.bytes_list_feature([]), - })) - writer.write(example.SerializeToString()) - return tfrecord_path - def test_parser(self): tf.random.set_seed(111111) params = hparams_config.get_detection_config('efficientdet-d0').as_dict() @@ -81,7 +37,7 @@ def test_parser(self): anchor_labeler = anchors.AnchorLabeler(input_anchors, params['num_classes']) example_decoder = tf_example_decoder.TfExampleDecoder( regenerate_source_id=params['regenerate_source_id']) - tfrecord_path = self._make_fake_tfrecord() + tfrecord_path = test_util.make_fake_tfrecord(self.get_temp_dir()) dataset = tf.data.TFRecordDataset([tfrecord_path]) value = next(iter(dataset)) reader = dataloader.InputReader(tfrecord_path, True) diff --git a/efficientdet/dataset/create_pascal_tfrecord.py b/efficientdet/dataset/create_pascal_tfrecord.py index d7000a1e6..22e27364a 100644 --- a/efficientdet/dataset/create_pascal_tfrecord.py +++ b/efficientdet/dataset/create_pascal_tfrecord.py @@ -33,19 +33,6 @@ from dataset import tfrecord_util -flags.DEFINE_string('data_dir', '', 'Root directory to raw PASCAL VOC dataset.') -flags.DEFINE_string('set', 'train', 'Convert training set, validation set or ' - 'merged set.') -flags.DEFINE_string('annotations_dir', 'Annotations', - '(Relative) path to annotations directory.') -flags.DEFINE_string('year', 'VOC2007', 'Desired challenge year.') -flags.DEFINE_string('output_path', '', 'Path to output TFRecord and json.') -flags.DEFINE_string('label_map_json_path', None, - 'Path to label map json file with a dictionary.') -flags.DEFINE_boolean('ignore_difficult_instances', False, 'Whether to ignore ' - 'difficult instances') -flags.DEFINE_integer('num_shards', 100, 'Number of shards for output file.') -flags.DEFINE_integer('num_images', None, 'Max number of imags to process.') FLAGS = flags.FLAGS SETS = ['train', 'val', 'trainval', 'test'] @@ -79,6 +66,24 @@ GLOBAL_ANN_ID = 0 # global annotation id. +def define_flags(): + """Define the flags.""" + flags.DEFINE_string('data_dir', '', + 'Root directory to raw PASCAL VOC dataset.') + flags.DEFINE_string('set', 'train', 'Convert training set, validation set or ' + 'merged set.') + flags.DEFINE_string('annotations_dir', 'Annotations', + '(Relative) path to annotations directory.') + flags.DEFINE_string('year', 'VOC2007', 'Desired challenge year.') + flags.DEFINE_string('output_path', '', 'Path to output TFRecord and json.') + flags.DEFINE_string('label_map_json_path', None, + 'Path to label map json file with a dictionary.') + flags.DEFINE_boolean('ignore_difficult_instances', False, 'Whether to ignore ' + 'difficult instances') + flags.DEFINE_integer('num_shards', 100, 'Number of shards for output file.') + flags.DEFINE_integer('num_images', None, 'Max number of imags to process.') + + def get_image_id(filename): """Convert a string to a integer.""" # Warning: this function is highly specific to pascal filename!! @@ -101,10 +106,9 @@ def get_ann_id(): def dict_to_tf_example(data, - dataset_directory, + images_dir, label_map_dict, ignore_difficult_instances=False, - image_subdirectory='JPEGImages', ann_json_dict=None): """Convert XML derived dict to tf.Example proto. @@ -114,12 +118,10 @@ def dict_to_tf_example(data, Args: data: dict holding PASCAL XML fields for a single image (obtained by running tfrecord_util.recursive_parse_xml_to_dict) - dataset_directory: Path to root directory holding PASCAL dataset + images_dir: Path to the directory holding raw images. label_map_dict: A map from string label names to integers ids. ignore_difficult_instances: Whether to skip difficult instances in the dataset (default: False). - image_subdirectory: String specifying subdirectory within the PASCAL dataset - directory holding the actual image data. ann_json_dict: annotation json dictionary. Returns: @@ -128,8 +130,7 @@ def dict_to_tf_example(data, Raises: ValueError: if the image pointed to by data['filename'] is not a valid JPEG """ - img_path = os.path.join(data['folder'], image_subdirectory, data['filename']) - full_path = os.path.join(dataset_directory, img_path) + full_path = os.path.join(images_dir, data['filename']) with tf.io.gfile.GFile(full_path, 'rb') as fid: encoded_jpg = fid.read() encoded_jpg_io = io.BytesIO(encoded_jpg) @@ -297,9 +298,10 @@ def main(_): xml = etree.fromstring(xml_str) data = tfrecord_util.recursive_parse_xml_to_dict(xml)['annotation'] + img_dir = os.path.join(FLAGS.data_dir, data['folder'], 'JPEGImages') tf_example = dict_to_tf_example( data, - FLAGS.data_dir, + img_dir, label_map_dict, FLAGS.ignore_difficult_instances, ann_json_dict=ann_json_dict) @@ -316,4 +318,5 @@ def main(_): if __name__ == '__main__': + define_flags() app.run(main) diff --git a/efficientdet/dataset/create_pascal_tfrecord_test.py b/efficientdet/dataset/create_pascal_tfrecord_test.py index 4b51b0b95..8a9fd3b7a 100644 --- a/efficientdet/dataset/create_pascal_tfrecord_test.py +++ b/efficientdet/dataset/create_pascal_tfrecord_test.py @@ -73,8 +73,9 @@ def test_dict_to_tf_example(self): 'notperson': 2, } - example = create_pascal_tfrecord.dict_to_tf_example( - data, self.get_temp_dir(), label_map_dict, image_subdirectory='') + example = create_pascal_tfrecord.dict_to_tf_example(data, + self.get_temp_dir(), + label_map_dict) self._assertProtoEqual( example.features.feature['image/height'].int64_list.value, [256]) self._assertProtoEqual( diff --git a/efficientdet/dataset/tfrecord_util.py b/efficientdet/dataset/tfrecord_util.py index 173918550..3ac91f04a 100644 --- a/efficientdet/dataset/tfrecord_util.py +++ b/efficientdet/dataset/tfrecord_util.py @@ -72,7 +72,7 @@ def recursive_parse_xml_to_dict(xml): Python dictionary holding XML contents. """ if not xml: - return {xml.tag: xml.text} + return {xml.tag: xml.text if xml.text else ''} result = {} for child in xml: child_result = recursive_parse_xml_to_dict(child) diff --git a/efficientdet/hparams_config.py b/efficientdet/hparams_config.py index 56b695321..73b7ec694 100644 --- a/efficientdet/hparams_config.py +++ b/efficientdet/hparams_config.py @@ -204,7 +204,7 @@ def default_detection_configs(): h.max_level = 7 h.num_scales = 3 # ratio w/h: 2.0 means w=1.4, h=0.7. Can be computed with k-mean per dataset. - h.aspect_ratios = [1.0, 2.0, 0.5] #[[0.7, 1.4], [1.0, 1.0], [1.4, 0.7]] + h.aspect_ratios = [1.0, 2.0, 0.5] # [[0.7, 1.4], [1.0, 1.0], [1.4, 0.7]] h.anchor_scale = 4.0 # is batchnorm training mode h.is_training_bn = True diff --git a/efficientdet/keras/README.md b/efficientdet/keras/README.md index 07f6786f5..c69d46600 100644 --- a/efficientdet/keras/README.md +++ b/efficientdet/keras/README.md @@ -245,7 +245,7 @@ Create a config file for the PASCAL VOC dataset called voc_config.yaml and put t Finetune needs to use --pretrained_ckpt. !python train.py - --training_file_pattern=tfrecord/pascal*.tfrecord \ + --train_file_pattern=tfrecord/pascal*.tfrecord \ --val_file_pattern=tfrecord/pascal*.tfrecord \ --val_file_pattern=tfrecord/*.json \ --model_name=efficientdet-d0 \ @@ -273,7 +273,7 @@ To train this model on Cloud TPU, you will need: Then train the model: !export PYTHONPATH="$PYTHONPATH:/path/to/models" - !python train.py --tpu=TPU_NAME --training_file_pattern=DATA_DIR/*.tfrecord --model_dir=MODEL_DIR --strategy=tpu + !python train.py --tpu=TPU_NAME --train_file_pattern=DATA_DIR/*.tfrecord --model_dir=MODEL_DIR --strategy=tpu # TPU_NAME is the name of the TPU node, the same name that appears when you run gcloud compute tpus list, or ctpu ls. # MODEL_DIR is a GCS location (a URL starting with gs:// where both the GCE VM and the associated Cloud TPU have write access. diff --git a/efficientdet/keras/efficientdet_keras.py b/efficientdet/keras/efficientdet_keras.py index 271d08c76..e0779959e 100644 --- a/efficientdet/keras/efficientdet_keras.py +++ b/efficientdet/keras/efficientdet_keras.py @@ -30,6 +30,17 @@ # pylint: disable=arguments-differ # fo keras layers. +def add_n(nodes): + """A customized add_n to add up a list of tensors.""" + # tf.add_n is not supported by EdgeTPU, while tf.reduce_sum is not supported + # by GPU and runs slow on EdgeTPU because of the 5-dimension op. + with tf.name_scope('add_n'): + new_node = nodes[0] + for n in nodes[1:]: + new_node = new_node + n + return new_node + + class FNode(tf.keras.layers.Layer): """A Keras Layer implementing BiFPN Node.""" @@ -89,12 +100,12 @@ def fuse_features(self, nodes): for var in self.vars: var = tf.cast(var, dtype=dtype) edge_weights.append(var) - weights_sum = tf.add_n(edge_weights) + weights_sum = add_n(edge_weights) nodes = [ nodes[i] * edge_weights[i] / (weights_sum + 0.0001) for i in range(len(nodes)) ] - new_node = tf.add_n(nodes) + new_node = add_n(nodes) elif self.weight_method == 'channel_attn': edge_weights = [] for var in self.vars: @@ -109,14 +120,14 @@ def fuse_features(self, nodes): var = tf.cast(var, dtype=dtype) edge_weights.append(var) - weights_sum = tf.add_n(edge_weights) + weights_sum = add_n(edge_weights) nodes = [ nodes[i] * edge_weights[i] / (weights_sum + 0.0001) for i in range(len(nodes)) ] - new_node = tf.add_n(nodes) + new_node = add_n(nodes) elif self.weight_method == 'sum': - new_node = tf.reduce_sum(nodes, axis=0) + new_node = add_n(nodes) else: raise ValueError('unknown weight_method %s' % self.weight_method) @@ -289,9 +300,9 @@ def _pool2d(self, inputs, height, width, target_height, target_width): def _upsample2d(self, inputs, target_height, target_width): return tf.cast( - tf.image.resize( - tf.cast(inputs, tf.float32), [target_height, target_width], - method=self.upsampling_type), inputs.dtype) + tf.compat.v1.image.resize_nearest_neighbor( + tf.cast(inputs, tf.float32), [target_height, target_width]), + inputs.dtype) def _maybe_apply_1x1(self, feat, training, num_channels): """Apply 1x1 conv to change layer width if necessary.""" @@ -349,6 +360,7 @@ def __init__(self, data_format='channels_last', grad_checkpoint=False, name='class_net', + feature_only=False, **kwargs): """Initialize the ClassNet. @@ -367,6 +379,8 @@ def __init__(self, data_format: string of 'channel_first' or 'channels_last'. grad_checkpoint: bool, If true, apply grad checkpoint for saving memory. name: the name of this layerl. + feature_only: build the base feature network only (excluding final class + head). **kwargs: other parameters. """ @@ -386,6 +400,7 @@ def __init__(self, self.conv_ops = [] self.bns = [] self.grad_checkpoint = grad_checkpoint + self.feature_only = feature_only if separable_conv: conv2d_layer = functools.partial( tf.keras.layers.SeparableConv2D, @@ -454,8 +469,10 @@ def call(self, inputs, training, **kwargs): image = inputs[level_id] for i in range(self.repeats): image = self._conv_bn_act(image, i, level_id, training) - - class_outputs.append(self.classes(image)) + if self.feature_only: + class_outputs.append(image) + else: + class_outputs.append(self.classes(image)) return class_outputs @@ -477,6 +494,7 @@ def __init__(self, data_format='channels_last', grad_checkpoint=False, name='box_net', + feature_only=False, **kwargs): """Initialize BoxNet. @@ -494,6 +512,8 @@ def __init__(self, data_format: string of 'channel_first' or 'channels_last'. grad_checkpoint: bool, If true, apply grad checkpoint for saving memory. name: Name of the layer. + feature_only: build the base feature network only (excluding box class + head). **kwargs: other parameters. """ @@ -511,6 +531,7 @@ def __init__(self, self.strategy = strategy self.data_format = data_format self.grad_checkpoint = grad_checkpoint + self.feature_only = feature_only self.conv_ops = [] self.bns = [] @@ -603,7 +624,11 @@ def call(self, inputs, training): image = inputs[level_id] for i in range(self.repeats): image = self._conv_bn_act(image, i, level_id, training) - box_outputs.append(self.boxes(image)) + + if self.feature_only: + box_outputs.append(image) + else: + box_outputs.append(self.boxes(image)) return box_outputs @@ -752,7 +777,11 @@ def _call(feats): class EfficientDetNet(tf.keras.Model): """EfficientDet keras network without pre/post-processing.""" - def __init__(self, model_name=None, config=None, name=''): + def __init__(self, + model_name=None, + config=None, + name='', + feature_only=False): """Initialize model.""" super().__init__(name=name) @@ -816,7 +845,8 @@ def __init__(self, model_name=None, config=None, name=''): survival_prob=config.survival_prob, strategy=config.strategy, grad_checkpoint=config.grad_checkpoint, - data_format=config.data_format) + data_format=config.data_format, + feature_only=feature_only) self.box_net = BoxNet( num_anchors=num_anchors, @@ -830,7 +860,8 @@ def __init__(self, model_name=None, config=None, name=''): survival_prob=config.survival_prob, strategy=config.strategy, grad_checkpoint=config.grad_checkpoint, - data_format=config.data_format) + data_format=config.data_format, + feature_only=feature_only) if head == 'segmentation': self.seg_head = SegmentationHead( diff --git a/efficientdet/keras/eval_tflite.py b/efficientdet/keras/eval_tflite.py new file mode 100644 index 000000000..f1d4cf3e0 --- /dev/null +++ b/efficientdet/keras/eval_tflite.py @@ -0,0 +1,143 @@ +# Copyright 2020 Google Research. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Eval libraries. Used for TFLite model without post-processing.""" +from absl import app +from absl import flags +from absl import logging +import numpy as np +import tensorflow as tf + +import coco_metric +import dataloader +import hparams_config +import utils + +from keras import anchors +from keras import label_util +from keras import postprocess + +flags.DEFINE_integer('eval_samples', None, 'Number of eval samples.') +flags.DEFINE_string('val_file_pattern', None, + 'Glob for eval tfrecords, e.g. coco/val-*.tfrecord.') +flags.DEFINE_string('val_json_file', None, + 'Groudtruth, e.g. annotations/instances_val2017.json.') +flags.DEFINE_string('model_name', 'efficientdet-d0', 'Model name to use.') +flags.DEFINE_string('tflite_path', None, 'Path to TFLite model.') +flags.DEFINE_string('hparams', '', 'Comma separated k=v pairs or a yaml file') +FLAGS = flags.FLAGS + +DEFAULT_SCALE, DEFAULT_ZERO_POINT = 0, 0 + + +class LiteRunner(object): + """Runs inference with TF Lite model.""" + + def __init__(self, tflite_model_path): + """Initializes Lite runner with tflite model file.""" + self.interpreter = tf.lite.Interpreter(tflite_model_path) + self.interpreter.allocate_tensors() + # Get input and output tensors. + self.input_details = self.interpreter.get_input_details() + self.output_details = self.interpreter.get_output_details() + + def run(self, image): + """Runs inference with Lite model.""" + interpreter = self.interpreter + input_details = self.input_details + output_details = self.output_details + + input_detail = input_details[0] + if input_detail['quantization'] != (DEFAULT_SCALE, DEFAULT_ZERO_POINT): + scale, zero_point = input_detail['quantization'] + image = image / scale + zero_point + image = np.array(image, dtype=input_detail['dtype']) + interpreter.set_tensor(input_detail['index'], image) + interpreter.invoke() + + def get_output(idx): + output_detail = output_details[idx] + output_tensor = interpreter.get_tensor(output_detail['index']) + if output_detail['quantization'] != (DEFAULT_SCALE, DEFAULT_ZERO_POINT): + # Dequantize the output + scale, zero_point = output_detail['quantization'] + output_tensor = output_tensor.astype(np.float32) + output_tensor = (output_tensor - zero_point) * scale + return output_tensor + + num_boxes = int(len(output_details) / 2) + cls_outputs, box_outputs = [], [] + for i in range(num_boxes): + cls_outputs.append(get_output(i)) + box_outputs.append(get_output(i + num_boxes)) + return cls_outputs, box_outputs + + +def main(_): + config = hparams_config.get_efficientdet_config(FLAGS.model_name) + config.override(FLAGS.hparams) + config.val_json_file = FLAGS.val_json_file + config.nms_configs.max_nms_inputs = anchors.MAX_DETECTION_POINTS + config.drop_remainder = False # eval all examples w/o drop. + config.image_size = utils.parse_image_size(config['image_size']) + + # Evaluator for AP calculation. + label_map = label_util.get_label_map(config.label_map) + evaluator = coco_metric.EvaluationMetric( + filename=config.val_json_file, label_map=label_map) + + # dataset + batch_size = 1 + ds = dataloader.InputReader( + FLAGS.val_file_pattern, + is_training=False, + max_instances_per_image=config.max_instances_per_image)( + config, batch_size=batch_size) + eval_samples = FLAGS.eval_samples + if eval_samples: + ds = ds.take((eval_samples + batch_size - 1) // batch_size) + + # Network + lite_runner = LiteRunner(FLAGS.tflite_path) + eval_samples = FLAGS.eval_samples or 5000 + pbar = tf.keras.utils.Progbar((eval_samples + batch_size - 1) // batch_size) + for i, (images, labels) in enumerate(ds): + cls_outputs, box_outputs = lite_runner.run(images) + detections = postprocess.generate_detections(config, cls_outputs, + box_outputs, + labels['image_scales'], + labels['source_ids']) + detections = postprocess.transform_detections(detections) + evaluator.update_state(labels['groundtruth_data'].numpy(), + detections.numpy()) + pbar.update(i) + + # compute the final eval results. + metrics = evaluator.result() + metric_dict = {} + for i, name in enumerate(evaluator.metric_names): + metric_dict[name] = metrics[i] + + if label_map: + for i, cid in enumerate(sorted(label_map.keys())): + name = 'AP_/%s' % label_map[cid] + metric_dict[name] = metrics[i + len(evaluator.metric_names)] + print(FLAGS.model_name, metric_dict) + + +if __name__ == '__main__': + flags.mark_flag_as_required('val_file_pattern') + flags.mark_flag_as_required('tflite_path') + logging.set_verbosity(logging.WARNING) + app.run(main) diff --git a/efficientdet/keras/inference.py b/efficientdet/keras/inference.py index 5f8bc7729..c830c921a 100644 --- a/efficientdet/keras/inference.py +++ b/efficientdet/keras/inference.py @@ -21,6 +21,7 @@ import numpy as np import tensorflow as tf +import dataloader import hparams_config import utils from keras import efficientdet_keras @@ -85,13 +86,15 @@ def __call__(self, imgs): class ExportModel(tf.Module): """Model to be exported as SavedModel/TFLite format.""" - def __init__(self, model): + def __init__(self, model, pre_mode='infer'): super().__init__() self.model = model + self.pre_mode = pre_mode @tf.function def __call__(self, imgs): - return self.model(imgs, training=False, post_mode='global') + return self.model( + imgs, training=False, pre_mode=self.pre_mode, post_mode='global') class ServingDriver: @@ -300,15 +303,21 @@ def freeze(self, func): _, graphdef = convert_variables_to_constants_v2_as_graph(func) return graphdef - def _get_model_and_spec(self): + def _get_model_and_spec(self, tflite=None): """Get model instance and export spec.""" - if self.only_network: + if self.only_network or tflite: image_size = utils.parse_image_size(self.params['image_size']) spec = tf.TensorSpec( shape=[self.batch_size, *image_size, 3], dtype=tf.float32, name='images') - export_model = ExportNetwork(self.model) + if self.only_network: + export_model = ExportNetwork(self.model) + else: + # If export tflite, we should remove preprocessing since TFLite doesn't + # support dynamic shape. + logging.info('Export model without preprocessing.') + export_model = ExportModel(self.model, pre_mode=None) return export_model, spec else: spec = tf.TensorSpec( @@ -319,15 +328,20 @@ def _get_model_and_spec(self): def export(self, output_dir: Text = None, tensorrt: Text = None, - tflite: Text = None): + tflite: Text = None, + file_pattern: Text = None, + num_calibration_steps: int = 2000): """Export a saved model, frozen graph, and potential tflite/tensorrt model. Args: output_dir: the output folder for saved model. tensorrt: If not None, must be {'FP32', 'FP16', 'INT8'}. tflite: Type for post-training quantization. + file_pattern: Glob for tfrecords, e.g. coco/val-*.tfrecord. + num_calibration_steps: Number of post-training quantization calibration + steps to run. """ - export_model, input_spec = self._get_model_and_spec() + export_model, input_spec = self._get_model_and_spec(tflite) image_size = utils.parse_image_size(self.params['image_size']) if output_dir: tf.saved_model.save( @@ -356,18 +370,34 @@ def export(self, converter.optimizations = [tf.lite.Optimize.DEFAULT] converter.target_spec.supported_types = [tf.float16] elif tflite == 'INT8': - num_calibration_steps = 10 + if file_pattern: + config = hparams_config.get_efficientdet_config(self.model_name) + config.override(self.params) + ds = dataloader.InputReader( + file_pattern, + is_training=False, + max_instances_per_image=config.max_instances_per_image)( + config, batch_size=self.batch_size) + + def representative_dataset_gen(): + for image, _ in ds.take(num_calibration_steps): + yield [image] + else: # Used for debugging, can remove later. + logging.warn('Use real representative dataset instead of fake ones.') + num_calibration_steps = 10 + def representative_dataset_gen(): # rewrite this for real data. + for _ in range(num_calibration_steps): + yield [tf.ones(shape, dtype=input_spec.dtype)] - def representative_dataset_gen(): # rewrite this for real data. - for _ in range(num_calibration_steps): - yield [tf.ones(shape, dtype=input_spec.dtype)] converter.representative_dataset = representative_dataset_gen converter.optimizations = [tf.lite.Optimize.DEFAULT] converter.inference_input_type = tf.uint8 converter.inference_output_type = tf.uint8 - converter.target_spec.supported_ops = [ - tf.lite.OpsSet.TFLITE_BUILTINS_INT8 - ] + supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] + if not self.only_network: + supported_ops.append(tf.lite.OpsSet.TFLITE_BUILTINS) + converter.target_spec.supported_ops = supported_ops + else: raise ValueError(f'Invalid tflite {tflite}: must be FP32, FP16, INT8.') diff --git a/efficientdet/keras/inference_test.py b/efficientdet/keras/inference_test.py index 66bd42c87..47edce955 100644 --- a/efficientdet/keras/inference_test.py +++ b/efficientdet/keras/inference_test.py @@ -17,6 +17,7 @@ import tempfile from absl import logging import tensorflow as tf +from brain_automl.efficientdet import test_util from keras import efficientdet_keras from keras import inference @@ -39,10 +40,10 @@ def test_export(self): driver.load(saved_model_path) driver.load(os.path.join(saved_model_path, 'efficientdet-d0_frozen.pb')) - def test_export_tflite(self): + def test_export_tflite_only_network(self): saved_model_path = os.path.join(self.tmp_path, 'saved_model') driver = inference.ServingDriver( - 'efficientdet-d0', self.tmp_path, only_network=True) + 'efficientdet-lite0', self.tmp_path, only_network=True) driver.export(saved_model_path, tflite='FP32') self.assertTrue( tf.io.gfile.exists(os.path.join(saved_model_path, 'fp32.tflite'))) @@ -50,6 +51,23 @@ def test_export_tflite(self): driver.export(saved_model_path, tflite='FP16') self.assertTrue( tf.io.gfile.exists(os.path.join(saved_model_path, 'fp16.tflite'))) + tf.io.gfile.rmtree(saved_model_path) + tfrecord_path = test_util.make_fake_tfrecord(self.get_temp_dir()) + driver.export( + saved_model_path, + tflite='INT8', + file_pattern=[tfrecord_path], + num_calibration_steps=1) + self.assertTrue( + tf.io.gfile.exists(os.path.join(saved_model_path, 'int8.tflite'))) + + def test_export_tflite_with_post_processing(self): + saved_model_path = os.path.join(self.tmp_path, 'saved_model') + driver = inference.ServingDriver( + 'efficientdet-lite0', self.tmp_path, only_network=False) + driver.export(saved_model_path, tflite='FP32') + self.assertTrue( + tf.io.gfile.exists(os.path.join(saved_model_path, 'fp32.tflite'))) def test_inference(self): driver = inference.ServingDriver('efficientdet-d0', self.tmp_path) diff --git a/efficientdet/keras/inspector.py b/efficientdet/keras/inspector.py index 515019260..2d19b14f5 100644 --- a/efficientdet/keras/inspector.py +++ b/efficientdet/keras/inspector.py @@ -55,6 +55,11 @@ # For saved model. flags.DEFINE_string('saved_model_dir', None, 'Folder path for saved model.') flags.DEFINE_string('tflite', None, 'tflite type: {FP32, FP16, INT8}.') +flags.DEFINE_string('file_pattern', None, + 'Glob for tfrecords, e.g. coco/val-*.tfrecord.') +flags.DEFINE_integer( + 'num_calibration_steps', 2000, + 'Number of post-training quantization calibration steps to run.') flags.DEFINE_bool('debug', False, 'Debug mode.') flags.DEFINE_bool('only_network', False, 'Model only contains network') FLAGS = flags.FLAGS @@ -86,7 +91,8 @@ def main(_): model_dir = FLAGS.saved_model_dir if tf.io.gfile.exists(model_dir): tf.io.gfile.rmtree(model_dir) - driver.export(model_dir, FLAGS.tensorrt, FLAGS.tflite) + driver.export(model_dir, FLAGS.tensorrt, FLAGS.tflite, FLAGS.file_pattern, + FLAGS.num_calibration_steps) print('Model are exported to %s' % model_dir) elif FLAGS.mode == 'infer': image_file = tf.io.read_file(FLAGS.input_image) @@ -178,5 +184,5 @@ def main(_): if __name__ == '__main__': - logging.set_verbosity(logging.ERROR) + logging.set_verbosity(logging.INFO) app.run(main) diff --git a/efficientdet/keras/train.py b/efficientdet/keras/train.py index 6acd51b40..101788b1c 100644 --- a/efficientdet/keras/train.py +++ b/efficientdet/keras/train.py @@ -26,86 +26,91 @@ from keras import train_lib from keras import util_keras +FLAGS = flags.FLAGS -# Cloud TPU Cluster Resolvers -flags.DEFINE_string( - 'tpu', - default=None, - help='The Cloud TPU to use for training. This should be either the name ' - 'used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 ' - 'url.') -flags.DEFINE_string( - 'gcp_project', - default=None, - help='Project name for the Cloud TPU-enabled project. If not specified, we ' - 'will attempt to automatically detect the GCE project from metadata.') -flags.DEFINE_string( - 'tpu_zone', - default=None, - help='GCE zone where the Cloud TPU is located in. If not specified, we ' - 'will attempt to automatically detect the GCE project from metadata.') - -# Model specific paramenters -flags.DEFINE_string( - 'eval_master', - default='', - help='GRPC URL of the eval master. Set to an appropriate value when running' - ' on CPU/GPU') -flags.DEFINE_string('eval_name', default=None, help='Eval job name') -flags.DEFINE_enum('strategy', None, ['tpu', 'gpus', ''], - 'Training: gpus for multi-gpu, if None, use TF default.') - -flags.DEFINE_integer( - 'num_cores', default=8, help='Number of TPU cores for training') - -flags.DEFINE_bool('use_fake_data', False, 'Use fake input.') -flags.DEFINE_bool( - 'use_xla', False, - 'Use XLA even if strategy is not tpu. If strategy is tpu, always use XLA, ' - 'and this flag has no effect.') -flags.DEFINE_string('model_dir', None, 'Location of model_dir') - -flags.DEFINE_string('pretrained_ckpt', None, - 'Start training from this EfficientDet checkpoint.') - -flags.DEFINE_string( - 'hparams', '', 'Comma separated k=v pairs of hyperparameters or a module' - ' containing attributes to use as hyperparameters.') -flags.DEFINE_integer('batch_size', 64, 'training batch size') -flags.DEFINE_integer('eval_samples', 5000, 'The number of samples for ' - 'evaluation.') -flags.DEFINE_integer('steps_per_execution', 1000, - 'Number of steps per training execution.') -flags.DEFINE_string( - 'training_file_pattern', None, - 'Glob for training data files (e.g., COCO train - minival set)') -flags.DEFINE_string('val_file_pattern', None, - 'Glob for evaluation tfrecords (e.g., COCO val2017 set)') -flags.DEFINE_string( - 'val_json_file', None, - 'COCO validation JSON containing golden bounding boxes. If None, use the ' - 'ground truth from the dataloader. Ignored if testdev_dir is not None.') - -flags.DEFINE_string('mode', 'traineval', 'job mode: train, traineval.') -flags.DEFINE_integer('num_examples_per_epoch', 120000, - 'Number of examples in one epoch') -flags.DEFINE_integer('num_epochs', None, 'Number of epochs for training') -flags.DEFINE_string('model_name', 'efficientdet-d0', 'Model name.') -flags.DEFINE_bool('debug', False, 'Enable debug mode') -flags.DEFINE_integer( - 'tf_random_seed', 111111, - 'Fixed random seed for deterministic execution across runs for debugging.') -flags.DEFINE_bool('profile', False, 'Enable profile mode') -FLAGS = flags.FLAGS +def define_flags(): + """Define the flags.""" + # Cloud TPU Cluster Resolvers + flags.DEFINE_string( + 'tpu', + default=None, + help='The Cloud TPU to use for training. This should be either the name ' + 'used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 ' + 'url.') + flags.DEFINE_string( + 'gcp_project', + default=None, + help='Project name for the Cloud TPU-enabled project. If not specified, ' + 'we will attempt to automatically detect the GCE project from metadata.') + flags.DEFINE_string( + 'tpu_zone', + default=None, + help='GCE zone where the Cloud TPU is located in. If not specified, we ' + 'will attempt to automatically detect the GCE project from metadata.') + + # Model specific paramenters + flags.DEFINE_string( + 'eval_master', + default='', + help='GRPC URL of the eval master. Set to an appropriate value when ' + 'running on CPU/GPU') + flags.DEFINE_string('eval_name', default=None, help='Eval job name') + flags.DEFINE_enum('strategy', None, ['tpu', 'gpus', ''], + 'Training: gpus for multi-gpu, if None, use TF default.') + + flags.DEFINE_integer( + 'num_cores', default=8, help='Number of TPU cores for training') + + flags.DEFINE_bool('use_fake_data', False, 'Use fake input.') + flags.DEFINE_bool( + 'use_xla', False, + 'Use XLA even if strategy is not tpu. If strategy is tpu, always use XLA,' + ' and this flag has no effect.') + flags.DEFINE_string('model_dir', None, 'Location of model_dir') + + flags.DEFINE_string('pretrained_ckpt', None, + 'Start training from this EfficientDet checkpoint.') + + flags.DEFINE_string( + 'hparams', '', 'Comma separated k=v pairs of hyperparameters or a module' + ' containing attributes to use as hyperparameters.') + flags.DEFINE_integer('batch_size', 64, 'training batch size') + flags.DEFINE_integer('eval_samples', 5000, 'The number of samples for ' + 'evaluation.') + flags.DEFINE_integer('steps_per_execution', 200, + 'Number of steps per training execution.') + flags.DEFINE_string( + 'train_file_pattern', None, + 'Glob for train data files (e.g., COCO train - minival set)') + flags.DEFINE_string('val_file_pattern', None, + 'Glob for evaluation tfrecords (e.g., COCO val2017 set)') + flags.DEFINE_string( + 'val_json_file', None, + 'COCO validation JSON containing golden bounding boxes. If None, use the ' + 'ground truth from the dataloader. Ignored if testdev_dir is not None.') + flags.DEFINE_string('mode', 'traineval', 'job mode: train, traineval.') + flags.DEFINE_string( + 'hub_module_url', None, 'TF-Hub path/url to EfficientDet module.' + 'If specified, pretrained_ckpt flag should not be used.') + flags.DEFINE_integer('num_examples_per_epoch', 120000, + 'Number of examples in one epoch') + flags.DEFINE_integer('num_epochs', None, 'Number of epochs for training') + flags.DEFINE_string('model_name', 'efficientdet-d0', 'Model name.') + flags.DEFINE_bool('debug', False, 'Enable debug mode') + flags.DEFINE_integer( + 'tf_random_seed', 111111, + 'Fixed random seed for deterministic execution across runs for debugging.' + ) + flags.DEFINE_bool('profile', False, 'Enable profile mode') -def setup_model(config): + +def setup_model(model, config): """Build and compile model.""" - model = train_lib.EfficientDetNetTrain(config=config) model.build((None, *config.image_size, 3)) model.compile( - steps_per_execution=FLAGS.steps_per_execution, + steps_per_execution=config.steps_per_execution, optimizer=train_lib.get_optimizer(config.as_dict()), loss={ train_lib.BoxLoss.__name__: @@ -129,9 +134,8 @@ def setup_model(config): reduction=tf.keras.losses.Reduction.NONE), tf.keras.losses.SparseCategoricalCrossentropy.__name__: tf.keras.losses.SparseCategoricalCrossentropy( - from_logits=True, - reduction=tf.keras.losses.Reduction.NONE)} - ) + from_logits=True, reduction=tf.keras.losses.Reduction.NONE) + }) return model @@ -203,7 +207,7 @@ def main(_): def get_dataset(is_training, config): file_pattern = ( - FLAGS.training_file_pattern + FLAGS.train_file_pattern if is_training else FLAGS.val_file_pattern) if not file_pattern: raise ValueError('No matching files.') @@ -219,22 +223,57 @@ def get_dataset(is_training, config): with ds_strategy.scope(): if config.model_optimizations: tfmot.set_config(config.model_optimizations.as_dict()) - model = setup_model(config) - if FLAGS.pretrained_ckpt: + if FLAGS.hub_module_url: + model = train_lib.EfficientDetNetTrainHub( + config=config, hub_module_url=FLAGS.hub_module_url) + else: + model = train_lib.EfficientDetNetTrain(config=config) + model = setup_model(model, config) + if FLAGS.pretrained_ckpt and not FLAGS.hub_module_url: ckpt_path = tf.train.latest_checkpoint(FLAGS.pretrained_ckpt) util_keras.restore_ckpt(model, ckpt_path, config.moving_average_decay) init_experimental(config) - val_dataset = get_dataset(False, config) if 'eval' in FLAGS.mode else None - model.fit( - get_dataset(True, config), - epochs=config.num_epochs, - steps_per_epoch=steps_per_epoch, - callbacks=train_lib.get_callbacks(config.as_dict(), val_dataset), - validation_data=val_dataset, - validation_steps=(FLAGS.eval_samples // FLAGS.batch_size)) - model.save_weights(os.path.join(FLAGS.model_dir, 'ckpt-final')) + if 'train' in FLAGS.mode: + val_dataset = get_dataset(False, config) if 'eval' in FLAGS.mode else None + model.fit( + get_dataset(True, config), + epochs=config.num_epochs, + steps_per_epoch=steps_per_epoch, + callbacks=train_lib.get_callbacks(config.as_dict(), val_dataset), + validation_data=val_dataset, + validation_steps=(FLAGS.eval_samples // FLAGS.batch_size)) + else: + # Continuous eval. + for ckpt in tf.train.checkpoints_iterator( + FLAGS.model_dir, min_interval_secs=180): + logging.info('Starting to evaluate.') + # Terminate eval job when final checkpoint is reached. + try: + current_epoch = int(os.path.basename(ckpt).split('-')[1]) + except IndexError: + current_epoch = 0 + + val_dataset = get_dataset(False, config) + logging.info('start loading model.') + model.load_weights(tf.train.latest_checkpoint(FLAGS.model_dir)) + logging.info('finish loading model.') + coco_eval = train_lib.COCOCallback(val_dataset, 1) + coco_eval.set_model(model) + eval_results = coco_eval.on_epoch_end(current_epoch) + logging.info('eval results for %s: %s', ckpt, eval_results) + + try: + utils.archive_ckpt(eval_results, eval_results['AP'], ckpt) + except tf.errors.NotFoundError: + # Checkpoint might be not already deleted by the time eval finished. + logging.info('Checkpoint %s no longer exists, skipping.', ckpt) + + if current_epoch >= config.num_epochs or not current_epoch: + logging.info('Eval epoch %d / %d', current_epoch, config.num_epochs) + break if __name__ == '__main__': + define_flags() logging.set_verbosity(logging.INFO) app.run(main) diff --git a/efficientdet/keras/train_lib.py b/efficientdet/keras/train_lib.py index 80cc144a2..6f7d2fb5c 100644 --- a/efficientdet/keras/train_lib.py +++ b/efficientdet/keras/train_lib.py @@ -13,6 +13,7 @@ # limitations under the License. # ============================================================================== """Training related libraries.""" +import functools import math import os import re @@ -22,6 +23,7 @@ import tensorflow as tf from tensorflow_addons.callbacks import AverageModelCheckpoint +import tensorflow_hub as hub import coco_metric import inference @@ -356,9 +358,12 @@ def on_epoch_end(self, epoch, logs=None): for (images, labels) in dataset: strategy.run(self._get_detections, (images, labels)) metrics = self.evaluator.result() + eval_results = {} with self.file_writer.as_default(), tf.summary.record_if(True): for i, name in enumerate(self.evaluator.metric_names): tf.summary.scalar(name, metrics[i], step=epoch) + eval_results[name] = metrics[i] + return eval_results class DisplayCallback(tf.keras.callbacks.Callback): @@ -407,14 +412,14 @@ def get_callbacks(params, val_dataset=None): """Get callbacks for given params.""" if params['moving_average_decay']: avg_callback = AverageModelCheckpoint( - filepath=os.path.join(params['model_dir'], 'ema_ckpt'), + filepath=os.path.join(params['model_dir'], 'emackpt-{epoch:d}'), verbose=1, save_weights_only=True, update_weights=False) callbacks = [avg_callback] else: ckpt_callback = tf.keras.callbacks.ModelCheckpoint( - os.path.join(params['model_dir'], 'ckpt'), + os.path.join(params['model_dir'], 'ckpt-{epoch:d}'), verbose=1, save_weights_only=True) callbacks = [ckpt_callback] @@ -436,7 +441,8 @@ def get_callbacks(params, val_dataset=None): params.get('sample_image', None), params['model_dir'], params['img_summary_steps']) callbacks.append(display_callback) - if params.get('map_freq', None) and val_dataset: + if (params.get('map_freq', None) and val_dataset and + params['strategy'] != 'tpu'): coco_callback = COCOCallback(val_dataset, params['map_freq']) callbacks.append(coco_callback) return callbacks @@ -843,3 +849,61 @@ def test_step(self, data): loss_vals['reg_l2_loss'] = reg_l2_loss loss_vals['loss'] = total_loss + tf.cast(reg_l2_loss, loss_dtype) return loss_vals + + +class EfficientDetNetTrainHub(EfficientDetNetTrain): + """EfficientDetNetTrain for Hub module.""" + + def __init__(self, config, hub_module_url, name=''): + super(efficientdet_keras.EfficientDetNet, self).__init__(name=name) + self.config = config + self.hub_module_url = hub_module_url + self.base_model = hub.KerasLayer(hub_module_url, trainable=True) + + # class/box output prediction network. + num_anchors = len(config.aspect_ratios) * config.num_scales + + if config.separable_conv: + conv2d_layer = functools.partial( + tf.keras.layers.SeparableConv2D, depth_multiplier=1) + else: + conv2d_layer = tf.keras.layers.Conv2D + self.classes = conv2d_layer( + config.num_classes * num_anchors, + kernel_size=3, + bias_initializer=tf.constant_initializer(-np.log((1 - 0.01) / 0.01)), + padding='same', + name='class_net/class-predict') + + if config.separable_conv: + self.boxes = tf.keras.layers.SeparableConv2D( + filters=4 * num_anchors, + depth_multiplier=1, + pointwise_initializer=tf.initializers.variance_scaling(), + depthwise_initializer=tf.initializers.variance_scaling(), + data_format=config.data_format, + kernel_size=3, + activation=None, + bias_initializer=tf.zeros_initializer(), + padding='same', + name='box_net/box-predict') + else: + self.boxes = tf.keras.layers.Conv2D( + filters=4 * num_anchors, + kernel_initializer=tf.random_normal_initializer(stddev=0.01), + data_format=config.data_format, + kernel_size=3, + activation=None, + bias_initializer=tf.zeros_initializer(), + padding='same', + name='box_net/box-predict') + + log_dir = os.path.join(self.config.model_dir, 'train_images') + self.summary_writer = tf.summary.create_file_writer(log_dir) + + def call(self, inputs, training): + cls_outputs, box_outputs = self.base_model(inputs, training=training) + for i in range(self.config.max_level - self.config.min_level + 1): + cls_outputs[i] = self.classes(cls_outputs[i]) + box_outputs[i] = self.boxes(box_outputs[i]) + return (cls_outputs, box_outputs) diff --git a/efficientdet/keras/util_keras.py b/efficientdet/keras/util_keras.py index fd8ea26ed..7c7483a56 100644 --- a/efficientdet/keras/util_keras.py +++ b/efficientdet/keras/util_keras.py @@ -132,9 +132,11 @@ def restore_ckpt(model, var_dict[v.name.split(':')[0]] = v # try to load graph-based checkpoint with ema support, # else load checkpoint via keras.load_weights which doesn't support ema. - for key, var in var_dict.items(): + for i, (key, var) in enumerate(var_dict.items()): try: var.assign(tf.train.load_variable(ckpt_path_or_file, key)) + if i < 10: + logging.info('Init %s from %s (%s)', var.name, key, ckpt_path_or_file) except tf.errors.NotFoundError as e: if skip_mismatch: logging.warning('Not found %s in %s', key, ckpt_path_or_file) diff --git a/efficientdet/main.py b/efficientdet/main.py index 230ef6dc9..81899104c 100644 --- a/efficientdet/main.py +++ b/efficientdet/main.py @@ -79,9 +79,9 @@ flags.DEFINE_integer('save_checkpoints_steps', 100, 'Number of iterations per checkpoint save') flags.DEFINE_string( - 'training_file_pattern', None, + 'train_file_pattern', None, 'Glob for training data files (e.g., COCO train - minival set)') -flags.DEFINE_string('validation_file_pattern', None, +flags.DEFINE_string('val_file_pattern', None, 'Glob for evaluation tfrecords (e.g., COCO val2017 set)') flags.DEFINE_string( 'val_json_file', None, @@ -95,7 +95,7 @@ flags.DEFINE_string('mode', 'train', 'Mode to run: train or eval (default: train)') flags.DEFINE_string('model_name', 'efficientdet-d1', 'Model name.') -flags.DEFINE_bool('eval_after_training', False, 'Run one eval after the ' +flags.DEFINE_bool('eval_after_train', False, 'Run one eval after the ' 'training finishes.') flags.DEFINE_bool('profile', False, 'Profile training performance.') flags.DEFINE_integer( @@ -131,11 +131,11 @@ def main(_): # Check data path if FLAGS.mode in ('train', 'train_and_eval'): - if FLAGS.training_file_pattern is None: - raise RuntimeError('Must specify --training_file_pattern for train.') + if FLAGS.train_file_pattern is None: + raise RuntimeError('Must specify --train_file_pattern for train.') if FLAGS.mode in ('eval', 'train_and_eval'): - if FLAGS.validation_file_pattern is None: - raise RuntimeError('Must specify --validation_file_pattern for eval.') + if FLAGS.val_file_pattern is None: + raise RuntimeError('Must specify --val_file_pattern for eval.') # Parse and override hparams config = hparams_config.get_detection_config(FLAGS.model_name) @@ -235,12 +235,12 @@ def _can_partition(spatial_dim): tf.io.gfile.GFile(config_file, 'w').write(str(config)) train_input_fn = dataloader.InputReader( - FLAGS.training_file_pattern, + FLAGS.train_file_pattern, is_training=True, use_fake_data=FLAGS.use_fake_data, max_instances_per_image=max_instances_per_image) eval_input_fn = dataloader.InputReader( - FLAGS.validation_file_pattern, + FLAGS.val_file_pattern, is_training=False, use_fake_data=FLAGS.use_fake_data, max_instances_per_image=max_instances_per_image) @@ -295,7 +295,7 @@ def get_estimator(global_batch_size): # start train/eval flow. if FLAGS.mode == 'train': train_est.train(input_fn=train_input_fn, max_steps=train_steps) - if FLAGS.eval_after_training: + if FLAGS.eval_after_train: eval_est.evaluate(input_fn=eval_input_fn, steps=eval_steps) elif FLAGS.mode == 'eval': diff --git a/efficientdet/test_util.py b/efficientdet/test_util.py new file mode 100644 index 000000000..55787074c --- /dev/null +++ b/efficientdet/test_util.py @@ -0,0 +1,65 @@ +# Copyright 2020 Google Research. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +r"""Test utilities.""" +import os + +import tensorflow as tf +from dataset import tfrecord_util + + +def make_fake_tfrecord(temp_dir): + """Makes fake TFRecord to test input.""" + tfrecord_path = os.path.join(temp_dir, 'test.tfrecords') + writer = tf.io.TFRecordWriter(tfrecord_path) + encoded_jpg = tf.io.encode_jpeg(tf.ones([512, 512, 3], dtype=tf.uint8)) + example = tf.train.Example( + features=tf.train.Features( + feature={ + 'image/height': + tfrecord_util.int64_feature(512), + 'image/width': + tfrecord_util.int64_feature(512), + 'image/filename': + tfrecord_util.bytes_feature('test_file_name.jpg'.encode( + 'utf8')), + 'image/source_id': + tfrecord_util.bytes_feature('123456'.encode('utf8')), + 'image/key/sha256': + tfrecord_util.bytes_feature('qwdqwfw12345'.encode('utf8')), + 'image/encoded': + tfrecord_util.bytes_feature(encoded_jpg.numpy()), + 'image/format': + tfrecord_util.bytes_feature('jpeg'.encode('utf8')), + 'image/object/bbox/xmin': + tfrecord_util.float_list_feature([0.1]), + 'image/object/bbox/xmax': + tfrecord_util.float_list_feature([0.1]), + 'image/object/bbox/ymin': + tfrecord_util.float_list_feature([0.2]), + 'image/object/bbox/ymax': + tfrecord_util.float_list_feature([0.2]), + 'image/object/class/text': + tfrecord_util.bytes_list_feature(['test'.encode('utf8')]), + 'image/object/class/label': + tfrecord_util.int64_list_feature([1]), + 'image/object/difficult': + tfrecord_util.int64_list_feature([]), + 'image/object/truncated': + tfrecord_util.int64_list_feature([]), + 'image/object/view': + tfrecord_util.bytes_list_feature([]), + })) + writer.write(example.SerializeToString()) + return tfrecord_path diff --git a/efficientdet/tutorial.ipynb b/efficientdet/tutorial.ipynb index 2b5bf16da..598b71627 100644 --- a/efficientdet/tutorial.ipynb +++ b/efficientdet/tutorial.ipynb @@ -1230,7 +1230,7 @@ "# Evalute on validation set (takes about 10 mins for efficientdet-d0)\n", "!python main.py --mode=eval \\\n", " --model_name={MODEL} --model_dir={ckpt_path} \\\n", - " --validation_file_pattern=tfrecord/val* \\\n", + " --val_file_pattern=tfrecord/val* \\\n", " --val_json_file=annotations/instances_val2017.json" ], "execution_count": 0, @@ -1715,7 +1715,7 @@ " !mkdir testdev_output\n", " !python main.py --mode=eval \\\n", " --model_name={MODEL} --model_dir={ckpt_path} \\\n", - " --validation_file_pattern=tfrecord/testdev* \\\n", + " --val_file_pattern=tfrecord/testdev* \\\n", " --eval_batch_size=8 --eval_samples=20288 \\\n", " --testdev_dir='testdev_output'\n", " !rm -rf test2017 # delete images to release disk space.\n", @@ -1909,8 +1909,8 @@ "# key option: use --backbone_ckpt rather than --ckpt.\n", "# Don't use ema since we only train a few steps.\n", "!python main.py --mode=train_and_eval \\\n", - " --training_file_pattern=tfrecord/{file_pattern} \\\n", - " --validation_file_pattern=tfrecord/{file_pattern} \\\n", + " --train_file_pattern=tfrecord/{file_pattern} \\\n", + " --val_file_pattern=tfrecord/{file_pattern} \\\n", " --model_name={MODEL} \\\n", " --model_dir=/tmp/model_dir/{MODEL}-scratch \\\n", " --backbone_ckpt={backbone_name} \\\n", @@ -3044,8 +3044,8 @@ "!mkdir /tmp/model_dir/\n", "# key option: use --ckpt rather than --backbone_ckpt.\n", "!python main.py --mode=train_and_eval \\\n", - " --training_file_pattern=tfrecord/{file_pattern} \\\n", - " --validation_file_pattern=tfrecord/{file_pattern} \\\n", + " --train_file_pattern=tfrecord/{file_pattern} \\\n", + " --val_file_pattern=tfrecord/{file_pattern} \\\n", " --model_name={MODEL} \\\n", " --model_dir=/tmp/model_dir/{MODEL}-finetune \\\n", " --ckpt={MODEL} \\\n", diff --git a/efficientdet/utils.py b/efficientdet/utils.py index aca3f2d72..bfd17e19e 100644 --- a/efficientdet/utils.py +++ b/efficientdet/utils.py @@ -113,7 +113,7 @@ def get_ckpt_var_map(ckpt_path, ckpt_scope, var_scope, skip_mismatch=None): if tf.distribute.get_replica_context(): replica_id = tf.get_static_value( - tf.distribute.get_replica_context().replica_id_in_sync_group) + tf.distribute.get_replica_context().replica_id_in_sync_group) else: replica_id = 0 @@ -125,11 +125,11 @@ def get_ckpt_var_map(ckpt_path, ckpt_scope, var_scope, skip_mismatch=None): if not var_op_name.startswith(var_scope): logging.info('skip {} -- does not match scope {}'.format( - var_op_name, var_scope)) + var_op_name, var_scope)) ckpt_var = ckpt_scope + var_op_name[len(var_scope):] if (ckpt_var not in ckpt_var_names and - var_op_name.endswith('/ExponentialMovingAverage')): + var_op_name.endswith('/ExponentialMovingAverage')): ckpt_var = ckpt_scope + var_op_name[:-len('/ExponentialMovingAverage')] if ckpt_var not in ckpt_var_names: @@ -137,17 +137,18 @@ def get_ckpt_var_map(ckpt_path, ckpt_scope, var_scope, skip_mismatch=None): # Skip optimizer variables. continue if skip_mismatch: - logging.info('skip {} ({}) -- not in ckpt'.format(var_op_name, ckpt_var)) + logging.info('skip {} ({}) -- not in ckpt'.format( + var_op_name, ckpt_var)) continue raise ValueError('{} is not in ckpt {}'.format(v.op, ckpt_path)) if v.shape != ckpt_var_name_to_shape[ckpt_var]: if skip_mismatch: logging.info('skip {} ({} vs {}) -- shape mismatch'.format( - var_op_name, v.shape, ckpt_var_name_to_shape[ckpt_var])) + var_op_name, v.shape, ckpt_var_name_to_shape[ckpt_var])) continue raise ValueError('shape mismatch {} ({} vs {})'.format( - var_op_name, v.shape, ckpt_var_name_to_shape[ckpt_var])) + var_op_name, v.shape, ckpt_var_name_to_shape[ckpt_var])) if i < 5: # Log the first few elements for sanity check.