diff --git a/opennmt/training.py b/opennmt/training.py
index 6db89c06a..7fa63fb83 100644
--- a/opennmt/training.py
+++ b/opennmt/training.py
@@ -471,7 +471,15 @@ def _finalize_dataset(self, dataset):
         # We prefer not to use experimental_distribute_dataset here because it
         # sometimes fails to split the batches (noticed with tokens batch type).
         dataset_fn = dataset if callable(dataset) else lambda _: dataset
-        return self._strategy.experimental_distribute_datasets_from_function(dataset_fn)
+        # TODO: clean this API usage when TensorFlow requirement is updated to >=2.4.
+        distribute_fn = getattr(
+            self._strategy, "distribute_datasets_from_function", None
+        )
+        if distribute_fn is None:
+            distribute_fn = (
+                self._strategy.experimental_distribute_datasets_from_function
+            )
+        return distribute_fn(dataset_fn)
 
     def _forward(self, source, target, accum_steps=1, report_steps=None):
         per_replica_loss = self._strategy.run(
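
For context, the change applies a standard feature-detection pattern: prefer the stable `distribute_datasets_from_function` (available in TensorFlow >= 2.4) and fall back to the deprecated `experimental_distribute_datasets_from_function` on older releases. Below is a minimal, self-contained sketch of the same pattern; the two dummy strategy classes are hypothetical stand-ins for `tf.distribute.Strategy` on different TensorFlow versions and are not part of the patch:

```python
# Sketch of the getattr-based API fallback used in the diff.
# NewStrategy/OldStrategy are illustrative stand-ins, not real TF classes.


class NewStrategy:
    """Mimics a TF >= 2.4 strategy exposing the stable method."""

    def distribute_datasets_from_function(self, dataset_fn):
        return dataset_fn(None)


class OldStrategy:
    """Mimics a TF < 2.4 strategy exposing only the experimental method."""

    def experimental_distribute_datasets_from_function(self, dataset_fn):
        return dataset_fn(None)


def finalize_dataset(strategy, dataset):
    # Same fallback as the patch: try the stable method first, and only
    # use the experimental one when the stable name is absent.
    dataset_fn = dataset if callable(dataset) else lambda _: dataset
    distribute_fn = getattr(strategy, "distribute_datasets_from_function", None)
    if distribute_fn is None:
        distribute_fn = strategy.experimental_distribute_datasets_from_function
    return distribute_fn(dataset_fn)


print(finalize_dataset(NewStrategy(), [1, 2, 3]))  # resolved via the stable API
print(finalize_dataset(OldStrategy(), [1, 2, 3]))  # resolved via the experimental API
```

This keeps a single code path for both TensorFlow versions and, per the TODO in the diff, can be collapsed to a direct call once the TensorFlow requirement is raised to >= 2.4.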