diff --git a/efficientdet/keras/README.md b/efficientdet/keras/README.md
index e984fac42..0c05f94f1 100644
--- a/efficientdet/keras/README.md
+++ b/efficientdet/keras/README.md
@@ -3,16 +3,6 @@
[1] Mingxing Tan, Ruoming Pang, Quoc V. Le. EfficientDet: Scalable and Efficient Object Detection. CVPR 2020.
Arxiv link: https://arxiv.org/abs/1911.09070
-Updates:
-
- - **Jul20: Added keras/TF2 and new SOTA D7x: 55.1mAP with 153ms**
- - Apr22: Sped up end-to-end latency: D0 has up to >200 FPS throughput on Tesla V100.
- * A great collaboration with [@fsx950223](https://github.com/fsx950223).
- - Apr1: Updated results for test-dev and added EfficientDet-D7.
- - Mar26: Fixed a few bugs and updated all checkpoints/results.
- - Mar24: Added tutorial with visualization and coco eval.
- - Mar 13: Released the initial code and models.
-
**Quick start tutorial: [tutorial.ipynb](tutorial.ipynb)**
**Quick install dependencies: ```pip install -r requirements.txt```**
@@ -25,7 +15,7 @@ EfficientDets are a family of object detection models, which achieve state-of-th
EfficientDets are developed based on the advanced backbone, a new BiFPN, and a new scaling technique:
-
+
* **Backbone**: we employ [EfficientNets](https://arxiv.org/abs/1905.11946) as our backbone networks.
@@ -38,10 +28,10 @@ Our model family starts from EfficientDet-D0, which has comparable accuracy as [
@@ -56,15 +46,15 @@ We have provided a list of EfficientDet checkpoints and results as follows:
| Model | APtest | AP50 | AP75 |APS | APM | APL | APval | | #params | #FLOPs |
|---------- |------ |------ |------ | -------- | ------| ------| ------ |------ |------ | :------: |
-| EfficientDet-D0 ([h5](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/efficientdet-d0.h5), [ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/efficientdet-d0.tar.gz), [val](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/val/d0_coco_val.txt), [test-dev](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/testdev/d0_coco_test-dev2017.txt)) | 34.6 | 53.0 | 37.1 | 12.4 | 39.0 | 52.7 | 34.3 | | 3.9M | 2.54B |
-| EfficientDet-D1 ([h5](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/efficientdet-d1.h5), [ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/efficientdet-d1.tar.gz), [val](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/val/d1_coco_val.txt), [test-dev](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/testdev/d1_coco_test-dev2017.txt)) | 40.5 | 59.1 | 43.7 | 18.3 | 45.0 | 57.5 | 40.2 | | 6.6M | 6.10B |
-| EfficientDet-D2 ([h5](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/efficientdet-d2.h5), [ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/efficientdet-d2.tar.gz), [val](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/val/d2_coco_val.txt), [test-dev](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/testdev/d2_coco_test-dev2017.txt)) | 43.0 | 62.3 | 46.2 | 22.5 | 47.0 | 58.4 | 42.5 | | 8.1M | 11.0B |
-| EfficientDet-D3 ([h5](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/efficientdet-d3.h5), [ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/efficientdet-d3.tar.gz), [val](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/val/d3_coco_val.txt), [test-dev](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/testdev/d3_coco_test-dev2017.txt)) | 47.5 | 66.2 | 51.5 | 27.9 | 51.4 | 62.0 | 47.2 | | 12.0M | 24.9B |
-| EfficientDet-D4 ([h5](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/efficientdet-d4.h5), [ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/efficientdet-d4.tar.gz), [val](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/val/d4_coco_val.txt), [test-dev](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/testdev/d4_coco_test-dev2017.txt)) | 49.7 | 68.4 | 53.9 | 30.7 | 53.2 | 63.2 | 49.3 | | 20.7M | 55.2B |
-| EfficientDet-D5 ([h5](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/efficientdet-d5.h5), [ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/efficientdet-d5.tar.gz), [val](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/val/d5_coco_val.txt), [test-dev](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/testdev/d5_coco_test-dev2017.txt)) | 51.5 | 70.5 | 56.1 | 33.9 | 54.7 | 64.1 | 51.2 | | 33.7M | 130B |
-| EfficientDet-D6 ([h5](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/efficientdet-d6.h5), [ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/efficientdet-d6.tar.gz), [val](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/val/d6_coco_val.txt), [test-dev](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/testdev/d6_coco_test-dev2017.txt)) | 52.6 | 71.5 | 57.2 | 34.9 | 56.0 | 65.4 | 52.1 | | 51.9M | 226B |
-| EfficientDet-D7 ([h5](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/efficientdet-d7.h5), [ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/efficientdet-d7.tar.gz), [val](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/val/d7_coco_val.txt), [test-dev](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/testdev/d7_coco_test-dev2017.txt)) | 53.7 | 72.4 | 58.4 | 35.8 | 57.0 | 66.3 | 53.4 | | 51.9M | 325B |
-| EfficientDet-D7x ([h5](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/efficientdet-d7x.h5), [ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/efficientdet-d7x.tar.gz), [val](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/val/d7x_coco_val.txt), [test-dev](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/testdev/d7x_coco_test-dev2017.txt)) | 55.1 | 74.3 | 59.9 | 37.2 | 57.9 | 68.0 | 54.4 | | 77.0M | 410B |
+| EfficientDet-D0 ([ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/efficientdet-d0.tar.gz), [val](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/val/d0_coco_val.txt), [test-dev](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/testdev/d0_coco_test-dev2017.txt)) | 34.6 | 53.0 | 37.1 | 12.4 | 39.0 | 52.7 | 34.3 | | 3.9M | 2.54B |
+| EfficientDet-D1 ([ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/efficientdet-d1.tar.gz), [val](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/val/d1_coco_val.txt), [test-dev](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/testdev/d1_coco_test-dev2017.txt)) | 40.5 | 59.1 | 43.7 | 18.3 | 45.0 | 57.5 | 40.2 | | 6.6M | 6.10B |
+| EfficientDet-D2 ([ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/efficientdet-d2.tar.gz), [val](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/val/d2_coco_val.txt), [test-dev](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/testdev/d2_coco_test-dev2017.txt)) | 43.0 | 62.3 | 46.2 | 22.5 | 47.0 | 58.4 | 42.5 | | 8.1M | 11.0B |
+| EfficientDet-D3 ([ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/efficientdet-d3.tar.gz), [val](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/val/d3_coco_val.txt), [test-dev](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/testdev/d3_coco_test-dev2017.txt)) | 47.5 | 66.2 | 51.5 | 27.9 | 51.4 | 62.0 | 47.2 | | 12.0M | 24.9B |
+| EfficientDet-D4 ([ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/efficientdet-d4.tar.gz), [val](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/val/d4_coco_val.txt), [test-dev](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/testdev/d4_coco_test-dev2017.txt)) | 49.7 | 68.4 | 53.9 | 30.7 | 53.2 | 63.2 | 49.3 | | 20.7M | 55.2B |
+| EfficientDet-D5 ([ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/efficientdet-d5.tar.gz), [val](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/val/d5_coco_val.txt), [test-dev](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/testdev/d5_coco_test-dev2017.txt)) | 51.5 | 70.5 | 56.1 | 33.9 | 54.7 | 64.1 | 51.2 | | 33.7M | 130B |
+| EfficientDet-D6 ([ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/efficientdet-d6.tar.gz), [val](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/val/d6_coco_val.txt), [test-dev](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/testdev/d6_coco_test-dev2017.txt)) | 52.6 | 71.5 | 57.2 | 34.9 | 56.0 | 65.4 | 52.1 | | 51.9M | 226B |
+| EfficientDet-D7 ([ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/efficientdet-d7.tar.gz), [val](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/val/d7_coco_val.txt), [test-dev](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/testdev/d7_coco_test-dev2017.txt)) | 53.7 | 72.4 | 58.4 | 35.8 | 57.0 | 66.3 | 53.4 | | 51.9M | 325B |
+| EfficientDet-D7x ([ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/efficientdet-d7x.tar.gz), [val](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/val/d7x_coco_val.txt), [test-dev](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/testdev/d7x_coco_test-dev2017.txt)) | 55.1 | 74.3 | 59.9 | 37.2 | 57.9 | 68.0 | 54.4 | | 77.0M | 410B |
val denotes validation results, test-dev denotes test-dev2017 results. APval is for validation accuracy, all other AP results in the table are for COCO test-dev2017. All accuracy numbers are for single-model single-scale without ensemble or test-time augmentation. EfficientDet-D0 to D6 are trained for 300 epochs and D7/D7x are trained for 600 epochs.
@@ -73,11 +63,11 @@ In addition, the following table includes a list of models trained with fixed 64
| Model | mAP | Latency |
| ------ | ------ | ------ |
-| D2(640) [h5](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco640/efficientdet-d2-640.h5), [ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco640/efficientdet-d2-640.tar.gz) | 41.7 | 14.8ms |
-| D3(640) [h5](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco640/efficientdet-d3-640.h5), [ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco640/efficientdet-d3-640.tar.gz) | 44.0 | 18.7ms |
-| D4(640) [h5](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco640/efficientdet-d4-640.h5), [ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco640/efficientdet-d4-640.tar.gz) | 45.7 | 21.7ms |
-| D5(640 [h5](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco640/efficientdet-d5-640.h5), [ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco640/efficientdet-d5-640.tar.gz) | 46.6 | 26.6ms |
-| D6(640) [h5](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco640/efficientdet-d6-640.h5), [ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco640/efficientdet-d6-640.tar.gz) | 47.9 | 33.8ms |
+| D2(640) [ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco640/efficientdet-d2-640.tar.gz) | 41.7 | 14.8ms |
+| D3(640) [ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco640/efficientdet-d3-640.tar.gz) | 44.0 | 18.7ms |
+| D4(640) [ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco640/efficientdet-d4-640.tar.gz) | 45.7 | 21.7ms |
+| D5(640) [ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco640/efficientdet-d5-640.tar.gz) | 46.6 | 26.6ms |
+| D6(640) [ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco640/efficientdet-d6-640.tar.gz) | 47.9 | 33.8ms |
@@ -96,10 +86,9 @@ Then you will get:
- saved model under `savedmodeldir/`
- frozen graph with name `savedmodeldir/efficientdet-d0_frozen.pb`
- TensorRT saved model under `savedmodeldir/tensorrt_fp32/`
- - tflite file with name `efficientdet-d0.tflite`
+ - tflite file with name `savedmodeldir/fp32.tflite`
Notably,
- --tflite_path only works after 2.3.0-dev20200521 ,
--model_dir=xx/archive is the folder for exporting the best model.
@@ -154,12 +143,12 @@ latency and throughput are:
# Step2: inference image.
!python inspector.py --mode=infer \
- --model_name=efficientdet-d0 --model_dir=efficientdet-d0 \
+ --model_name=efficientdet-d0 --saved_model_dir=/tmp/saved_model \
--hparams="image_size=1920x1280" \
--input_image=img.png --output_image_dir=/tmp/
-Alternatively, if you want to do inference using frozen graph instead of saved model, you can run
+If you want to do inference using frozen graph, you can run
# Step 1 is the same as before.
# Step 2: do inference with frozen graph.
@@ -168,18 +157,28 @@ Alternatively, if you want to do inference using frozen graph instead of saved m
--saved_model_dir=/tmp/saved_model/efficientdet-d0_frozen.pb \
--input_image=img.png --output_image_dir=/tmp/
+If you want to do inference using tflite, you can run
+
+ # Step 1 is the same as before.
+ # Step 2: do inference with frozen graph.
+ !python inspector.py --mode=infer \
+ --model_name=efficientdet-d0 \
+ --saved_model_dir=/tmp/saved_model/fp32.tflite \
+ --input_image=img.png --output_image_dir=/tmp/
+
Lastly, if you only have one image and just want to run a quick test, you can also run the following command (it is slow because it needs to construct the graph from scratch):
# Run inference for a single image.
- !python inspector.py --mode=infer --model_name=$MODEL \
+ !python inspector.py --mode=infer \
+ --model_name=efficientdet-d0 --model_dir=$CKPT_PATH \
--hparams="image_size=1920x1280" \
- --model_dir=$CKPT_PATH --input_image=img.png --output_image_dir=/tmp
+ --input_image=img.png --output_image_dir=/tmp/
# you can visualize the output /tmp/0.jpg
Here is an example of EfficientDet-D0 visualization: more on [tutorial](tutorial.ipynb)
-
+
## 6. Inference for videos.
@@ -243,14 +242,15 @@ Create a config file for the PASCAL VOC dataset called voc_config.yaml and put t
var_freeze_expr: '(efficientnet|fpn_cells|resample_p6)'
label_map: {1: aeroplane, 2: bicycle, 3: bird, 4: boat, 5: bottle, 6: bus, 7: car, 8: cat, 9: chair, 10: cow, 11: diningtable, 12: dog, 13: horse, 14: motorbike, 15: person, 16: pottedplant, 17: sheep, 18: sofa, 19: train, 20: tvmonitor}
-Finetune needs to use --ckpt rather than --backbone_ckpt.
+Finetune needs to use --pretrained_ckpt.
!python train.py
--training_file_pattern=tfrecord/pascal*.tfrecord \
--val_file_pattern=tfrecord/pascal*.tfrecord \
+ --val_file_pattern=tfrecord/*.json \
--model_name=efficientdet-d0 \
--model_dir=/tmp/efficientdet-d0-finetune \
- --ckpt=efficientdet-d0 \
+ --pretrained_ckpt=efficientdet-d0 \
--batch_size=64 \
--eval_samples=1024 \
--num_examples_per_epoch=5717 --num_epochs=50 \
@@ -258,52 +258,9 @@ Finetune needs to use --ckpt rather than --backbone_ckpt.
If you want to continue to train the model, simply re-run the above command because the `num_epochs` is a maximum number of epochs. For example, to reproduce the result of efficientdet-d0, set `--num_epochs=300` then run the command multiple times until the training is finished.
-If you want to do inference for custom data, you can run
-
- # Setting hparams-flag is needed sometimes.
- !python inspector.py --mode=infer \
- --model_name=efficientdet-d0 --model_dir=efficientdet-d0 \
- --hparams=voc_config.yaml \
- --input_image=img.png --output_image_dir=/tmp/
-
-You should check more details of runmode which is written in caption-4.
-
## 9. Train on multi GPUs.
-Create a config file for the PASCAL VOC dataset called voc_config.yaml and put this in it.
-
- num_classes: 21
- var_freeze_expr: '(efficientnet|fpn_cells|resample_p6)'
- label_map: {1: aeroplane, 2: bicycle, 3: bird, 4: boat, 5: bottle, 6: bus, 7: car, 8: cat, 9: chair, 10: cow, 11: diningtable, 12: dog, 13: horse, 14: motorbike, 15: person, 16: pottedplant, 17: sheep, 18: sofa, 19: train, 20: tvmonitor}
-
-Download efficientdet coco checkpoint.
-
- !wget https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco/efficientdet-d0.tar.gz
- !tar xf efficientdet-d0.tar.gz
-
-Finetune needs to use --ckpt rather than --backbone_ckpt.
-
- python train.py \
- --training_file_pattern=tfrecord/pascal*.tfrecord \
- --val_file_pattern=tfrecord/pascal*.tfrecord \
- --model_name=efficientdet-d0 \
- --model_dir=/tmp/efficientdet-d0-finetune \
- --ckpt=efficientdet-d0 \
- --batch_size=64 \
- --eval_samples=1024 \
- --num_examples_per_epoch=5717 --num_epochs=50 \
- --hparams=voc_config.yaml \
- --strategy=gpus
-
-If you want to do inference for custom data, you can run
-
- # Setting hparams-flag is needed sometimes.
- !python inspector.py --mode=infer \
- --model_name=efficientdet-d0 --model_dir=efficientdet-d0 \
- --hparams=voc_config.yaml \
- --input_image=img.png --output_image_dir=/tmp/
-
-You should check more details of runmode which is written in caption-4.
+Just add ```--strategy=gpus```
## 10. Training EfficientDets on TPUs.
@@ -335,7 +292,7 @@ EfficientDets use a lot of GPU memory for a few reasons:
* Large internal activations for backbone: our backbone uses a relatively large expansion ratio (6), causing the large expanded activations.
* Deep BiFPN: our BiFPN has multiple top-down and bottom-up paths, which leads to a lot of intermediate memory usage during training.
-To train this model on GPU with low memory there is an experimental option gradient_checkpointing.
+To train this model on GPU with low memory there is an experimental option grad_checkpoint.
Check these links for a high-level idea of what gradient checkpointing is doing:
1. https://medium.com/tensorflow/fitting-larger-networks-into-memory-583e3c758ff9
diff --git a/efficientdet/keras/eval.py b/efficientdet/keras/eval.py
index 7153df6d0..79871a1a4 100644
--- a/efficientdet/keras/eval.py
+++ b/efficientdet/keras/eval.py
@@ -76,6 +76,7 @@ def main(_):
model.build((None, *config.image_size, 3))
util_keras.restore_ckpt(model,
tf.train.latest_checkpoint(FLAGS.model_dir),
+ config.moving_average_decay,
skip_mismatch=False)
@tf.function
def model_fn(images, labels):
diff --git a/efficientdet/keras/inference.py b/efficientdet/keras/inference.py
index 6b9af12b1..cc377e3c1 100644
--- a/efficientdet/keras/inference.py
+++ b/efficientdet/keras/inference.py
@@ -202,7 +202,9 @@ def build(self, params_override=None):
self.model = efficientdet_keras.EfficientDetModel(config=config)
image_size = utils.parse_image_size(params['image_size'])
self.model.build((self.batch_size, *image_size, 3))
- util_keras.restore_ckpt(self.model, self.ckpt_path, skip_mismatch=False)
+ util_keras.restore_ckpt(self.model, self.ckpt_path,
+ self.params['moving_average_decay'],
+ skip_mismatch=False)
def visualize(self, image, boxes, classes, scores, **kwargs):
"""Visualize prediction on image."""
diff --git a/efficientdet/keras/train.py b/efficientdet/keras/train.py
index 6fcd59054..049640f46 100644
--- a/efficientdet/keras/train.py
+++ b/efficientdet/keras/train.py
@@ -220,7 +220,7 @@ def get_dataset(is_training, config):
model = setup_model(config)
if FLAGS.pretrained_ckpt:
ckpt_path = tf.train.latest_checkpoint(FLAGS.pretrained_ckpt)
- util_keras.restore_ckpt(model, ckpt_path)
+ util_keras.restore_ckpt(model, ckpt_path, config.moving_average_decay)
init_experimental(config)
val_dataset = get_dataset(False, config).repeat()
model.fit(
diff --git a/efficientdet/keras/train_lib.py b/efficientdet/keras/train_lib.py
index 54f2f3a88..f55fde007 100644
--- a/efficientdet/keras/train_lib.py
+++ b/efficientdet/keras/train_lib.py
@@ -398,13 +398,13 @@ def _draw_inference(self, step):
def get_callbacks(params, val_dataset):
"""Get callbacks for given params."""
- if False:
+ if params['moving_average_decay']:
from tensorflow_addons.callbacks import AverageModelCheckpoint
avg_callback = AverageModelCheckpoint(
filepath=os.path.join(params['model_dir'], 'ckpt'),
verbose=1,
save_weights_only=True,
- update_weights=True)
+ update_weights=False)
callbacks = [avg_callback]
else:
ckpt_callback = tf.keras.callbacks.ModelCheckpoint(
diff --git a/efficientdet/keras/util_keras.py b/efficientdet/keras/util_keras.py
index 5cde4b45f..79cd2b411 100644
--- a/efficientdet/keras/util_keras.py
+++ b/efficientdet/keras/util_keras.py
@@ -93,7 +93,7 @@ def average_name(ema, var):
var.name.split(':')[0] + '/' + ema.name, mark_as_used=False)
-def restore_ckpt(model, ckpt_path_or_file, ema_decay=0., skip_mismatch=True):
+def restore_ckpt(model, ckpt_path_or_file, ema_decay=0.9998, skip_mismatch=True):
"""Restore variables from a given checkpoint.
Args: