update for tf2.4 (#908)
* update for tf2.4

* fix mixed precision with recompute gradient

* update README

* fix multi-GPU training

* update README

* fix LossScaleOptimizer bug

* disable steps_per_execution by default

* split all reduce
fsx950223 authored Dec 27, 2020
1 parent 53753bb commit 539ab65
Showing 13 changed files with 170 additions and 86 deletions.
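
Most of the Keras-path changes in this commit migrate from TF 2.3's `tf.keras.mixed_precision.experimental` API to the stable `tf.keras.mixed_precision` API that ships with TF 2.4. A minimal before/after sketch of that pattern (the optimizer and loss-scale values here are illustrative, not the repo's settings):

```python
import tensorflow as tf

# TF 2.3 (removed by this commit):
#   policy = tf.keras.mixed_precision.experimental.Policy('mixed_float16')
#   tf.keras.mixed_precision.experimental.set_policy(policy)

# TF 2.4 (used throughout the diff below):
policy = tf.keras.mixed_precision.Policy('mixed_float16')
tf.keras.mixed_precision.set_global_policy(policy)

# LossScaleOptimizer no longer takes a DynamicLossScale object; the initial
# scale is passed directly and dynamic scaling is the default behavior.
optimizer = tf.keras.optimizers.SGD(learning_rate=0.08, momentum=0.9)
optimizer = tf.keras.mixed_precision.LossScaleOptimizer(
    optimizer, initial_scale=2 ** 15)  # illustrative initial scale
```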
2 changes: 1 addition & 1 deletion efficientdet/README.md
@@ -369,7 +369,7 @@ For more instructions about training on TPUs, please refer to the following tuto

* EfficientNet tutorial: https://cloud.google.com/tpu/docs/tutorials/efficientnet

## 11. Reducing Memory Usage when Training EfficientDets on GPU. (The current approach doesn't support mirrored multi GPU or mixed-precision training)
## 11. Reducing Memory Usage when Training EfficientDets on GPU.

EfficientDets use a lot of GPU memory for a few reasons:

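
The heading edit drops the old caveat because the memory-saving path now also works with mirrored multi-GPU and mixed-precision training (see the "recompute gradient" and "multi-GPU" fixes in the commit message). The underlying technique is gradient checkpointing; below is a generic sketch with `tf.recompute_grad` on a plain Keras block, not the repo's exact flag or wiring (which isn't shown in this diff):

```python
import tensorflow as tf

dense_block = tf.keras.Sequential([
    tf.keras.layers.Dense(1024, activation='relu'),
    tf.keras.layers.Dense(1024, activation='relu'),
])
dense_block.build((None, 1024))  # create variables before wrapping

# Recompute the block's activations during the backward pass instead of
# storing them, trading extra compute for lower peak memory.
checkpointed_block = tf.recompute_grad(dense_block)

x = tf.random.normal([8, 1024])
with tf.GradientTape() as tape:
  loss = tf.reduce_mean(checkpointed_block(x))
grads = tape.gradient(loss, dense_block.trainable_variables)
```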
129 changes: 88 additions & 41 deletions efficientdet/dataloader.py
@@ -48,6 +48,14 @@ def __init__(self, image, output_size):
self._crop_offset_y = tf.constant(0)
self._crop_offset_x = tf.constant(0)

@property
def image(self):
return self._image

@image.setter
def image(self, image):
self._image = image

def normalize_image(self):
"""Normalize the image to zero mean and unit variance."""
# The image normalization is identical to Cloud TPU ResNet.
@@ -61,6 +69,7 @@ def normalize_image(self):
scale = tf.expand_dims(scale, axis=0)
scale = tf.expand_dims(scale, axis=0)
self._image /= scale
return self._image

def set_training_random_scale_factors(self,
scale_min,
@@ -126,6 +135,7 @@ def set_scale_factors_to_output_size(self):

def resize_and_crop_image(self, method=tf.image.ResizeMethod.BILINEAR):
"""Resize input image and crop it to the self._output dimension."""
dtype = self._image.dtype
scaled_image = tf.image.resize(
self._image, [self._scaled_height, self._scaled_width], method=method)
scaled_image = scaled_image[self._crop_offset_y:self._crop_offset_y +
@@ -135,7 +145,8 @@ def resize_and_crop_image(self, method=tf.image.ResizeMethod.BILINEAR):
output_image = tf.image.pad_to_bounding_box(scaled_image, 0, 0,
self._output_size[0],
self._output_size[1])
return output_image
self._image = tf.cast(output_image, dtype)
return self._image


class DetectionInputProcessor(InputProcessor):
@@ -245,6 +256,70 @@ def __init__(self,
self._max_instances_per_image = max_instances_per_image or 100
self._debug = debug

def _common_image_process(self, image, classes, boxes, data, params):
# Training time preprocessing.
if params['skip_crowd_during_training']:
indices = tf.where(tf.logical_not(data['groundtruth_is_crowd']))
classes = tf.gather_nd(classes, indices)
boxes = tf.gather_nd(boxes, indices)

if params.get('grid_mask', None):
from aug import gridmask # pylint: disable=g-import-not-at-top
image, boxes = gridmask.gridmask(image, boxes)

if params.get('autoaugment_policy', None):
from aug import autoaugment # pylint: disable=g-import-not-at-top
if params['autoaugment_policy'] == 'randaug':
image, boxes = autoaugment.distort_image_with_randaugment(
image, boxes, num_layers=1, magnitude=15)
else:
image, boxes = autoaugment.distort_image_with_autoaugment(
image, boxes, params['autoaugment_policy'])
return image, boxes, classes

def _resize_image_first(self, image, classes, boxes, data, params):
input_processor = DetectionInputProcessor(image, params['image_size'],
boxes, classes)
if self._is_training:
if params['input_rand_hflip']:
input_processor.random_horizontal_flip()

input_processor.set_training_random_scale_factors(
params['jitter_min'], params['jitter_max'],
params.get('target_size', None))
else:
input_processor.set_scale_factors_to_output_size()

image = input_processor.resize_and_crop_image()
boxes, classes = input_processor.resize_and_crop_boxes()

if self._is_training:
image, boxes, classes = self._common_image_process(image, classes, boxes, data, params)

input_processor.image = image
image = input_processor.normalize_image()
return image, boxes, classes, input_processor.image_scale_to_original

def _resize_image_last(self, image, classes, boxes, data, params):
if self._is_training:
image, boxes, classes = self._common_image_process(image, classes, boxes, data, params)

input_processor = DetectionInputProcessor(image, params['image_size'],
boxes, classes)
if self._is_training:
if params['input_rand_hflip']:
input_processor.random_horizontal_flip()

input_processor.set_training_random_scale_factors(
params['jitter_min'], params['jitter_max'],
params.get('target_size', None))
else:
input_processor.set_scale_factors_to_output_size()
input_processor.normalize_image()
image = input_processor.resize_and_crop_image()
boxes, classes = input_processor.resize_and_crop_boxes()
return image, boxes, classes, input_processor.image_scale_to_original

@tf.autograph.experimental.do_not_convert
def dataset_parser(self, value, example_decoder, anchor_labeler, params):
"""Parse data to a fixed dimension input image and learning targets.
@@ -293,41 +368,14 @@ def dataset_parser(self, value, example_decoder, anchor_labeler, params):
is_crowds = data['groundtruth_is_crowd']
image_masks = data.get('groundtruth_instance_masks', [])
classes = tf.reshape(tf.cast(classes, dtype=tf.float32), [-1, 1])

if self._is_training:
# Training time preprocessing.
if params['skip_crowd_during_training']:
indices = tf.where(tf.logical_not(data['groundtruth_is_crowd']))
classes = tf.gather_nd(classes, indices)
boxes = tf.gather_nd(boxes, indices)

if params.get('grid_mask', None):
from aug import gridmask # pylint: disable=g-import-not-at-top
image, boxes = gridmask.gridmask(image, boxes)

if params.get('autoaugment_policy', None):
from aug import autoaugment # pylint: disable=g-import-not-at-top
if params['autoaugment_policy'] == 'randaug':
image, boxes = autoaugment.distort_image_with_randaugment(
image, boxes, num_layers=1, magnitude=15)
else:
image, boxes = autoaugment.distort_image_with_autoaugment(
image, boxes, params['autoaugment_policy'])

input_processor = DetectionInputProcessor(image, params['image_size'],
boxes, classes)
input_processor.normalize_image()
if self._is_training:
if params['input_rand_hflip']:
input_processor.random_horizontal_flip()

input_processor.set_training_random_scale_factors(
params['jitter_min'], params['jitter_max'],
params.get('target_size', None))
else:
input_processor.set_scale_factors_to_output_size()
image = input_processor.resize_and_crop_image()
boxes, classes = input_processor.resize_and_crop_boxes()
source_area = tf.shape(image)[0] * tf.shape(image)[1]
target_size = utils.parse_image_size(params['image_size'])
target_area = target_size[0] * target_size[1]
# set condition in order to always process small
# first which could speed up pipeline
image, boxes, classes, image_scale = tf.cond(source_area > target_area,
lambda: self._resize_image_first(image, classes, boxes, data, params),
lambda: self._resize_image_last(image, classes, boxes, data, params))

# Assign anchors.
(cls_targets, box_targets,
@@ -338,7 +386,6 @@ def dataset_parser(self, value, example_decoder, anchor_labeler, params):
source_id = tf.strings.to_number(source_id)

# Pad groundtruth data for evaluation.
image_scale = input_processor.image_scale_to_original
boxes *= image_scale
is_crowds = tf.cast(is_crowds, dtype=tf.float32)
boxes = pad_to_fixed_size(boxes, -1, [self._max_instances_per_image, 4])
@@ -349,7 +396,7 @@ def dataset_parser(self, value, example_decoder, anchor_labeler, params):
[self._max_instances_per_image, 1])
if params['mixed_precision']:
dtype = (
tf.keras.mixed_precision.experimental.global_policy().compute_dtype)
tf.keras.mixed_precision.global_policy().compute_dtype)
image = tf.cast(image, dtype=dtype)
box_targets = tf.nest.map_structure(
lambda box_target: tf.cast(box_target, dtype=dtype), box_targets)
@@ -427,7 +474,7 @@ def _prefetch_dataset(filename):
return dataset

dataset = dataset.interleave(
_prefetch_dataset, num_parallel_calls=tf.data.experimental.AUTOTUNE)
_prefetch_dataset, num_parallel_calls=tf.data.AUTOTUNE)
dataset = dataset.with_options(self.dataset_options)
if self._is_training:
dataset = dataset.shuffle(64, seed=seed)
@@ -442,12 +489,12 @@ def _prefetch_dataset(filename):
anchor_labeler, params)
# pylint: enable=g-long-lambda
dataset = dataset.map(
map_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
map_fn, num_parallel_calls=tf.data.AUTOTUNE)
dataset = dataset.prefetch(batch_size)
dataset = dataset.batch(batch_size, drop_remainder=params['drop_remainder'])
dataset = dataset.map(
lambda *args: self.process_example(params, batch_size, *args))
dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
dataset = dataset.prefetch(tf.data.AUTOTUNE)
if self._use_fake_data:
# Turn this dataset into a semi-fake dataset which always loop at the
# first batch. This reduces variance in performance and is useful in
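
The dataloader change above splits preprocessing into `_resize_image_first` (resize, then augment) and `_resize_image_last` (augment, then resize) and picks between them with `tf.cond`, so the augmentations always run at the smaller of the source and target resolutions. A stripped-down sketch of just that dispatch, with stand-in `resize`/`augment` helpers (illustrative, not the repo's functions):

```python
import tensorflow as tf

def resize(image, size):
  return tf.image.resize(image, size)

def augment(image):
  # Stand-in for the real augmentations (random flip, AutoAugment, GridMask).
  return tf.image.random_flip_left_right(image)

def preprocess(image, target_size=(512, 512)):
  """Augment at whichever resolution is smaller, as in the new code path."""
  source_area = tf.shape(image)[0] * tf.shape(image)[1]
  target_area = target_size[0] * target_size[1]
  return tf.cond(
      source_area > target_area,
      lambda: augment(resize(image, target_size)),  # large source: resize first
      lambda: resize(augment(image), target_size))  # small source: resize last
```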
2 changes: 1 addition & 1 deletion efficientdet/det_model_fn.py
@@ -341,7 +341,7 @@ def model_fn(inputs):

precision = utils.get_precision(params['strategy'], params['mixed_precision'])
cls_outputs, box_outputs = utils.build_model_with_precision(
precision, model_fn, features, params['is_training_bn'])
precision, model_fn, features)

levels = cls_outputs.keys()
for level in levels:
2 changes: 1 addition & 1 deletion efficientdet/inference.py
@@ -159,7 +159,7 @@ def model_arch(feats, model_name=None, **kwargs):
model_arch = det_model_fn.get_model_arch(model_name)

cls_outputs, box_outputs = utils.build_model_with_precision(
precision, model_arch, inputs, False, model_name, **kwargs)
precision, model_arch, inputs, model_name, **kwargs)

if mixed_precision:
# Post-processing has multiple places with hard-coded float32.
4 changes: 3 additions & 1 deletion efficientdet/keras/efficientdet_keras.py
@@ -28,7 +28,9 @@
from keras import tfmot
from keras import util_keras
# pylint: disable=arguments-differ # fo keras layers.

utils.BatchNormalization = util_keras.get_batch_norm(tf.keras.layers.BatchNormalization)
utils.SyncBatchNormalization = util_keras.get_batch_norm(tf.keras.layers.experimental.SyncBatchNormalization)
utils.TpuBatchNormalization = util_keras.get_batch_norm(tf.keras.layers.experimental.SyncBatchNormalization)

def add_n(nodes):
"""A customized add_n to add up a list of tensors."""
6 changes: 3 additions & 3 deletions efficientdet/keras/infer.py
@@ -53,9 +53,9 @@ def main(_):
config.override(FLAGS.hparams)

# Use 'mixed_float16' if running on GPUs.
policy = tf.keras.mixed_precision.experimental.Policy('float32')
tf.keras.mixed_precision.experimental.set_policy(policy)
tf.config.experimental_run_functions_eagerly(FLAGS.debug)
policy = tf.keras.mixed_precision.Policy('float32')
tf.keras.mixed_precision.set_global_policy(policy)
tf.config.run_functions_eagerly(FLAGS.debug)

# Create and run the model.
model = efficientdet_keras.EfficientDetModel(config=config)
4 changes: 2 additions & 2 deletions efficientdet/keras/inference.py
@@ -179,8 +179,8 @@ def __init__(self,
mixed_precision = self.params.get('mixed_precision', None)
precision = utils.get_precision(
self.params.get('strategy', None), mixed_precision)
policy = tf.keras.mixed_precision.experimental.Policy(precision)
tf.keras.mixed_precision.experimental.set_policy(policy)
policy = tf.keras.mixed_precision.Policy(precision)
tf.keras.mixed_precision.set_global_policy(policy)

@property
def model(self):
8 changes: 4 additions & 4 deletions efficientdet/keras/train.py
@@ -78,7 +78,7 @@ def define_flags():
flags.DEFINE_integer('batch_size', 64, 'training batch size')
flags.DEFINE_integer('eval_samples', 5000, 'The number of samples for '
'evaluation.')
flags.DEFINE_integer('steps_per_execution', 200,
flags.DEFINE_integer('steps_per_execution', 1,
'Number of steps per training execution.')
flags.DEFINE_string(
'train_file_pattern', None,
@@ -163,7 +163,7 @@ def main(_):
tf.config.experimental.set_memory_growth(gpu, True)

if FLAGS.debug:
tf.config.experimental_run_functions_eagerly(True)
tf.config.run_functions_eagerly(True)
tf.debugging.set_log_device_placement(True)
os.environ['TF_DETERMINISTIC_OPS'] = '1'
tf.random.set_seed(FLAGS.tf_random_seed)
@@ -202,8 +202,8 @@ def main(_):
config.override(params, True)
# set mixed precision policy by keras api.
precision = utils.get_precision(config.strategy, config.mixed_precision)
policy = tf.keras.mixed_precision.experimental.Policy(precision)
tf.keras.mixed_precision.experimental.set_policy(policy)
policy = tf.keras.mixed_precision.Policy(precision)
tf.keras.mixed_precision.set_global_policy(policy)

def get_dataset(is_training, config):
file_pattern = (
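
train.py also drops the `steps_per_execution` default from 200 to 1 ("disable steps_per_execution by default" in the commit message). In TF 2.4 this flag is normally forwarded to `Model.compile`, where values above 1 batch several training steps into a single `tf.function` call; that amortizes host overhead but coarsens per-step callbacks and logging. A hedged sketch of the usual wiring (the repo's actual compile call is not part of this diff):

```python
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(10)])
model.compile(
    optimizer='sgd',
    loss='mse',
    # 1 = one training step per tf.function call (the new default here);
    # larger values run several steps per call to reduce host overhead.
    steps_per_execution=1)
```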
14 changes: 7 additions & 7 deletions efficientdet/keras/train_lib.py
@@ -310,10 +310,9 @@ def get_optimizer(params):
optimizer, average_decay=moving_average_decay, dynamic_decay=True)
precision = utils.get_precision(params['strategy'], params['mixed_precision'])
if precision == 'mixed_float16' and params['loss_scale']:
optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
optimizer = tf.keras.mixed_precision.LossScaleOptimizer(
optimizer,
loss_scale=tf.mixed_precision.experimental.DynamicLossScale(
params['loss_scale']))
initial_scale=params['loss_scale'])
return optimizer


@@ -777,17 +776,18 @@ def train_step(self, data):
loss_vals['reg_l2_loss'] = reg_l2_loss
total_loss += tf.cast(reg_l2_loss, loss_dtype)
if isinstance(self.optimizer,
tf.keras.mixed_precision.experimental.LossScaleOptimizer):
tf.keras.mixed_precision.LossScaleOptimizer):
scaled_loss = self.optimizer.get_scaled_loss(total_loss)
optimizer = self.optimizer.inner_optimizer
else:
scaled_loss = total_loss
optimizer = self.optimizer
loss_vals['loss'] = total_loss
loss_vals['learning_rate'] = self.optimizer.learning_rate(
self.optimizer.iterations)
loss_vals['learning_rate'] = optimizer.learning_rate(optimizer.iterations)
trainable_vars = self._freeze_vars()
scaled_gradients = tape.gradient(scaled_loss, trainable_vars)
if isinstance(self.optimizer,
tf.keras.mixed_precision.experimental.LossScaleOptimizer):
tf.keras.mixed_precision.LossScaleOptimizer):
gradients = self.optimizer.get_unscaled_gradients(scaled_gradients)
else:
gradients = scaled_gradients
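
train_lib.py now constructs the loss-scale wrapper with the stable API and reads the learning-rate schedule from the wrapped optimizer (`inner_optimizer`) rather than from the wrapper. The scale/unscale pattern used inside the custom `train_step` looks roughly like this (a sketch, not the full EfficientDet step):

```python
import tensorflow as tf

def train_step(model, optimizer, images, labels, loss_fn):
  scaled = isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer)
  with tf.GradientTape() as tape:
    loss = loss_fn(labels, model(images, training=True))
    # Scale the loss so float16 gradients do not underflow.
    scaled_loss = optimizer.get_scaled_loss(loss) if scaled else loss
  grads = tape.gradient(scaled_loss, model.trainable_variables)
  if scaled:
    grads = optimizer.get_unscaled_gradients(grads)
  optimizer.apply_gradients(zip(grads, model.trainable_variables))
  return loss
```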
7 changes: 7 additions & 0 deletions efficientdet/keras/util_keras.py
@@ -174,3 +174,10 @@ def fp16_to_fp32_nested(input_nested):
else:
return input_nested
return out_tensor_dict

def get_batch_norm(bn_class):
def _wrapper(*args, **kwargs):
if not kwargs.get('name', None):
kwargs['name'] = 'tpu_batch_normalization'
return bn_class(*args, **kwargs)
return _wrapper
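
The new `get_batch_norm` helper (and the aliasing added in efficientdet_keras.py) wraps a Keras batch-norm class so that unnamed instances default to `tpu_batch_normalization`, presumably to keep variable names compatible with existing checkpoints. A small usage sketch, restating the helper so it runs standalone:

```python
import tensorflow as tf

def get_batch_norm(bn_class):
  """Same idea as the helper above: inject a default layer name."""
  def _wrapper(*args, **kwargs):
    if not kwargs.get('name', None):
      kwargs['name'] = 'tpu_batch_normalization'
    return bn_class(*args, **kwargs)
  return _wrapper

BatchNormalization = get_batch_norm(tf.keras.layers.BatchNormalization)

print(BatchNormalization().name)              # -> 'tpu_batch_normalization'
print(BatchNormalization(name='my_bn').name)  # -> 'my_bn' (explicit name wins)
```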
39 changes: 38 additions & 1 deletion efficientdet/main.py
@@ -19,6 +19,43 @@
from absl import flags
from absl import logging
import numpy as np

from tensorflow.python.ops import custom_gradient # pylint:disable=g-direct-tensorflow-import
from tensorflow.python.framework import ops # pylint:disable=g-direct-tensorflow-import


def get_variable_by_name(var_name):
"""Given a variable name, retrieves a handle on the tensorflow Variable."""

global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)

def _filter_fn(item):
try:
return var_name == item.op.name
except AttributeError:
# Collection items without operation are ignored.
return False

candidate_vars = list(filter(_filter_fn, global_vars))

if len(candidate_vars) >= 1:
# Filter out non-trainable variables.
candidate_vars = [v for v in candidate_vars if v.trainable]
else:
raise ValueError("Unsuccessful at finding variable {}.".format(var_name))

if len(candidate_vars) == 1:
return candidate_vars[0]
elif len(candidate_vars) > 1:
raise ValueError(
"Unsuccessful at finding trainable variable {}. "
"Number of candidates: {}. "
"Candidates: {}".format(var_name, len(candidate_vars), candidate_vars))
else:
# The variable is not trainable.
return None

custom_gradient.get_variable_by_name = get_variable_by_name
import tensorflow.compat.v1 as tf

import dataloader
@@ -355,7 +392,7 @@ def run_train_and_eval(e):
if p.exitcode != 0:
return p.exitcode
else:
tf.compat.v1.reset_default_graph()
tf.reset_default_graph()
run_train_and_eval(e)

else:
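
main.py (the TF1/estimator path) patches `custom_gradient.get_variable_by_name` before importing `tensorflow.compat.v1`, so the gradient-recomputation path can resolve graph-mode variables from the GLOBAL_VARIABLES collection by op name, with non-trainable matches returning `None` instead of raising; this is part of the "fix mixed precision with recompute gradient" change. A tiny illustration of the lookup it performs (the variable name is illustrative):

```python
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

w = tf.get_variable('detnet/weights', shape=[3, 3])
# Same lookup the patched helper does: match on the variable's op name.
matches = [v for v in tf.global_variables() if v.op.name == 'detnet/weights']
assert matches and matches[0] is w
```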
