diff --git a/src/fairseq2/optim/dynamic_loss_scaler.py b/src/fairseq2/optim/dynamic_loss_scaler.py
index cc8c628e3..71703cad5 100644
--- a/src/fairseq2/optim/dynamic_loss_scaler.py
+++ b/src/fairseq2/optim/dynamic_loss_scaler.py
@@ -32,7 +32,7 @@ class DynamicLossScaler:
     _optimizer: Optimizer
     _scale_window: int
     _min_scale: float
-    _is_enabled: bool
+    _enabled: bool
 
     # compat: consolidate into `GradScaler` once we cease support for PT2.2
     _grad_scaler: GradScaler | ShardedGradScaler
@@ -125,7 +125,7 @@ def __init__(
         self._optimizer = optimizer
         self._scale_window = scale_window
         self._min_scale = min_scale
-        self._is_enabled = enabled
+        self._enabled = enabled
 
     def state_dict(self) -> dict[str, Any]:
         return {"grad_scaler": self._grad_scaler.state_dict()}
@@ -194,7 +194,7 @@ def _are_close(a: float, b: float) -> bool:
 
     def unscale_gradients_(self) -> None:
         """Unscale the associated optimizer's gradients by the current scale."""
-        if not supports_manual_gradient_scaling(self._optimizer):
+        if self._enabled and not supports_manual_gradient_scaling(self._optimizer):
             raise RuntimeError(
                 "`optimizer` must support manual gradient scaling via `torch.cuda.amp.GradScaler`, but supports only implicit scaling in its step function (i.e. `_step_supports_amp_scaling == True`)."
             )
@@ -212,7 +212,7 @@ def get_scale(self) -> float:
     @property
     def is_enabled(self) -> bool:
         """``True`` if the loss scaling is enabled."""
-        return self._is_enabled
+        return self._enabled
 
 
 @final
diff --git a/src/fairseq2/recipes/evaluator.py b/src/fairseq2/recipes/evaluator.py
index 37e0fde06..2d5f85a19 100644
--- a/src/fairseq2/recipes/evaluator.py
+++ b/src/fairseq2/recipes/evaluator.py
@@ -8,6 +8,7 @@
 
 from abc import ABC, abstractmethod
 from collections.abc import Sequence
+from contextlib import AbstractContextManager, nullcontext
 from itertools import count
 from pathlib import Path
 from typing import Generic, TypeVar, final
@@ -29,7 +30,7 @@
 )
 from fairseq2.recipes.common_metrics import extend_batch_metrics
 from fairseq2.recipes.utils.cli import create_rich_progress
-from fairseq2.typing import CPU
+from fairseq2.typing import CPU, DataType
 from fairseq2.utils.profiler import Stopwatch
 from fairseq2.utils.rng import RngBag
 
@@ -104,6 +105,8 @@ class Evaluator(Generic[BatchT]):
     _root_gang: Gang
     _dp_gang: Gang
     _tp_gang: Gang
+    _dtype: DataType
+    _amp: bool
     _metric_recorders: list[MetricRecorder]
     _seed: int
     _wall_watch: Stopwatch
@@ -118,6 +121,8 @@ def __init__(
         wall_watch: Stopwatch,
         dp_gang: Gang | None = None,
         tp_gang: Gang | None = None,
+        dtype: DataType = torch.float32,
+        amp: bool = False,
         tb_dir: Path | None = None,
         metrics_dir: Path | None = None,
         seed: int = 2,
@@ -135,6 +140,10 @@ def __init__(
             The data parallel gang. If ``None``, ``root_gang`` will be used.
         :param tp_gang:
             The tensor parallel gang. Only required for tensor parallel models.
+        :param dtype:
+            The data type of the model.
+        :param amp:
+            If ``True``, enables ``torch.amp``.
         :param tb_dir:
             The TensorBoard log directory to dump metrics.
         :param metrics_dir:
@@ -168,6 +177,10 @@ def __init__(
                     f"The coordinator process of `root_gang` (i.e. rank 0) must be rank 0 in `dp_gang` and `tp_gang`, but is {self._dp_gang.rank} and {self._tp_gang.rank} instead."
                 )
 
+        self._dtype = dtype
+
+        self._amp = amp
+
         if root_gang.rank == 0:
             self._metric_recorders = [LogMetricRecorder(log)]
 
@@ -239,12 +252,19 @@ def _evaluate_unit(
                 break
 
             for batch in batches:
-                unit(batch)
+                with self._maybe_autocast():
+                    unit(batch)
 
                 num_effective_batches += 1
 
         self._publish_metrics(unit, num_effective_batches, watch.get_elapsed_time())
 
+    def _maybe_autocast(self) -> AbstractContextManager[None]:
+        if self._dtype == torch.float32 or not self._amp:
+            return nullcontext()
+
+        return torch.autocast(device_type=self._dp_gang.device.type, dtype=self._dtype)
+
     def _publish_metrics(
         self, unit: EvalUnit[BatchT], num_batches: int, elapsed_time: float
     ) -> None:
diff --git a/src/fairseq2/recipes/generator.py b/src/fairseq2/recipes/generator.py
index af1500561..5270a52ff 100644
--- a/src/fairseq2/recipes/generator.py
+++ b/src/fairseq2/recipes/generator.py
@@ -7,6 +7,7 @@
 from __future__ import annotations
 
 from abc import ABC, abstractmethod
+from contextlib import AbstractContextManager, nullcontext
 from itertools import count
 from pathlib import Path
 from typing import Generic, TypeVar, final
@@ -27,7 +28,7 @@
 )
 from fairseq2.recipes.common_metrics import extend_batch_metrics
 from fairseq2.recipes.utils.cli import create_rich_progress
-from fairseq2.typing import CPU
+from fairseq2.typing import CPU, DataType
 from fairseq2.utils.profiler import Stopwatch
 from fairseq2.utils.rng import RngBag
 
@@ -79,6 +80,8 @@ class Generator(Generic[BatchT]):
     _root_gang: Gang
     _dp_gang: Gang
     _tp_gang: Gang
+    _dtype: DataType
+    _amp: bool
     _metric_recorders: list[MetricRecorder]
     _seed: int
     _wall_watch: Stopwatch
@@ -93,6 +96,8 @@ def __init__(
         wall_watch: Stopwatch,
         dp_gang: Gang | None = None,
         tp_gang: Gang | None = None,
+        dtype: DataType = torch.float32,
+        amp: bool = False,
         metrics_dir: Path | None = None,
         seed: int = 2,
     ) -> None:
@@ -109,6 +114,10 @@ def __init__(
             The data parallel gang. If ``None``, ``gang`` will be used.
         :param tp_gang:
             The tensor parallel gang. Only required for tensor parallel models.
+        :param dtype:
+            The data type of the model.
+        :param amp:
+            If ``True``, enables ``torch.amp``.
         :param metrics_dir:
             The directory to dump metrics.
         :param seed:
@@ -135,6 +144,10 @@ def __init__(
                     f"The coordinator process of `root_gang` (i.e. rank 0) must be rank 0 in `dp_gang` and `tp_gang`, but is {self._dp_gang.rank} and {self._tp_gang.rank} instead."
                 )
 
+        self._dtype = dtype
+
+        self._amp = amp
+
         if root_gang.rank == 0:
             self._metric_recorders = [LogMetricRecorder(log)]
 
@@ -194,12 +207,19 @@ def _do_run(self) -> None:
                 break
 
             for batch in batches:
-                self._unit(batch)
+                with self._maybe_autocast():
+                    self._unit(batch)
 
                 num_effective_batches += 1
 
         self._publish_metrics(num_effective_batches, watch.get_elapsed_time())
 
+    def _maybe_autocast(self) -> AbstractContextManager[None]:
+        if self._dtype == torch.float32 or not self._amp:
+            return nullcontext()
+
+        return torch.autocast(device_type=self._dp_gang.device.type, dtype=self._dtype)
+
     def _publish_metrics(self, num_batches: int, elapsed_time: float) -> None:
         log.debug("Syncing metrics.")
 
diff --git a/src/fairseq2/recipes/lm/instruction_finetune.py b/src/fairseq2/recipes/lm/instruction_finetune.py
index f30f86ba5..1fe91dd66 100644
--- a/src/fairseq2/recipes/lm/instruction_finetune.py
+++ b/src/fairseq2/recipes/lm/instruction_finetune.py
@@ -71,7 +71,7 @@ class InstructionFinetuneConfig:
     train_split: str = "default"
     """The name of the train data split."""
 
-    valid_split: str = "valid"
+    valid_split: str | None = None
     """The name of the valid data split."""
 
     max_seq_len: int = 8192
@@ -80,7 +80,7 @@ class InstructionFinetuneConfig:
     max_num_tokens: int = 8192 * 2
     """The maximum number of tokens per batch."""
 
-    max_num_valid_tokens: int = 8192 * 2
+    max_num_valid_tokens: int | None = None
     """The maximum number of tokens per validation batch."""
 
     example_shuffle_window: int = 10_000
@@ -105,8 +105,14 @@ class InstructionFinetuneConfig:
     dtype: DataType = torch.bfloat16
     """The data type of the model."""
 
-    mixed_precision: bool = True
-    """If ``True``, the model will be trained in mixed precision."""
+    mixed_precision: Literal["none", "static", "dynamic"] = "static"
+    """
+    If 'none', the whole training will be run in `dtype`. If 'static', forward
+    and backward passes will be run in `dtype`, but the optimizer step will be
+    run in full precision. If 'dynamic', forward and backward passes will be run
+    with `torch.amp` in `dtype`, but the optimizer step will be run in full
+    precision.
+    """
 
     data_parallelism: Literal["ddp", "fsdp"] = "fsdp"
     """The data parallelism API to use."""
@@ -161,9 +167,6 @@ class InstructionFinetuneConfig:
     max_num_data_epochs: int | None = None
     """The maximum number of data epochs to train for."""
 
-    validate: bool = False
-    """If ``True``, runs validation."""
-
     validate_after_n_steps: int = 0
     """The number of steps after which to start validating the model."""
 
@@ -323,7 +326,7 @@ def load_instruction_finetuner(
 
     init_device = META
 
-    dtype = torch.float32 if config.mixed_precision else config.dtype
+    dtype = config.dtype if config.mixed_precision == "none" else torch.float32
 
     has_checkpoint = checkpoint_manager.has_checkpoint()
 
@@ -372,6 +375,8 @@ def load_instruction_finetuner(
 
     checkpoint_manager.save_model_metadata(base_asset=model_card.name)
 
+    mp_dtype = config.dtype if config.mixed_precision == "static" else None
+
     dp_model = to_data_parallel(
         model,
         dp_gang,
@@ -379,7 +384,7 @@ def load_instruction_finetuner(
         log,
         fsdp_broadcast_state=not has_checkpoint,
         fsdp_reshard_after_forward=config.fsdp_reshard_after_forward,
-        fsdp_mixed_precision_dtype=config.dtype if config.mixed_precision else None,
+        fsdp_mixed_precision_dtype=mp_dtype,
         fsdp_fp32_reduce=True,
         fsdp_wrap_granularity=config.fsdp_wrap_granularity,
     )
@@ -447,16 +452,18 @@ def load_instruction_finetuner(
         ) from ex
 
     # Initialize the validation unit.
-    if config.validate:
+    if config.valid_split is not None:
         valid_unit = InstructionValidUnit(criterion, dp_gang)
 
+        max_num_tokens = config.max_num_valid_tokens or config.max_num_tokens
+
         try:
             valid_data_reader = dataset.create_reader(
                 config.valid_split,
                 tokenizer,
                 dp_gang,
                 config.max_seq_len,
-                batching=LengthBatching(config.max_num_valid_tokens),
+                batching=LengthBatching(max_num_tokens),
                 example_shuffle_window=config.example_shuffle_window,
                 batch_shuffle_window=config.batch_shuffle_window,
                 sync_mode="until_last",
@@ -479,6 +486,12 @@ def load_instruction_finetuner(
 
     seed += 1
 
+    # TODO: Fix once we support static mixed precision on one device.
+    if config.mixed_precision == "static":
+        amp = root_gang.size == 1 or config.data_parallelism != "fsdp"
+    else:
+        amp = config.mixed_precision == "dynamic"
+
     # Initialize the trainer.
     return Trainer[SequenceBatch](
         unit=unit,
@@ -489,9 +502,9 @@ def load_instruction_finetuner(
         dtype=config.dtype,
         optimizer=optimizer,
         lr_scheduler=lr_scheduler,
-        amp=config.mixed_precision,
         fp16_loss_scale=config.fp16_loss_scale,
         max_gradient_norm=config.max_gradient_norm,
+        amp=amp,
         max_num_steps=config.max_num_steps,
         max_num_data_epochs=config.max_num_data_epochs,
         valid_units=valid_units,
diff --git a/src/fairseq2/recipes/lm/preference_finetune/recipe.py b/src/fairseq2/recipes/lm/preference_finetune/recipe.py
index f99b35252..e7ff79638 100644
--- a/src/fairseq2/recipes/lm/preference_finetune/recipe.py
+++ b/src/fairseq2/recipes/lm/preference_finetune/recipe.py
@@ -91,8 +91,14 @@ class PreferenceOptimizationConfig:
     dtype: DataType = torch.bfloat16
     """The data type of the model."""
 
-    mixed_precision: bool = True
-    """If ``True``, the model will be trained in mixed precision."""
+    mixed_precision: Literal["none", "static", "dynamic"] = "static"
+    """
+    If 'none', the whole training will be run in `dtype`. If 'static', forward
+    and backward passes will be run in `dtype`, but the optimizer step will be
+    run in full precision. If 'dynamic', forward and backward passes will be run
+    with `torch.amp` in `dtype`, but the optimizer step will be run in full
+    precision.
+    """
 
     data_parallelism: Literal["ddp", "fsdp"] = "fsdp"
     """The data parallelism API to use."""
@@ -289,7 +295,7 @@ def load_preference_finetuner(
 
     init_device = META
 
-    dtype = torch.float32 if config.mixed_precision else config.dtype
+    dtype = config.dtype if config.mixed_precision == "none" else torch.float32
 
     has_checkpoint = checkpoint_manager.has_checkpoint()
 
@@ -333,6 +339,8 @@ def load_preference_finetuner(
 
     checkpoint_manager.save_model_metadata(base_asset=model_card.name)
 
+    mp_dtype = config.dtype if config.mixed_precision == "static" else None
+
     dp_model = to_data_parallel(
         model,
         dp_gang,
@@ -340,7 +348,7 @@ def load_preference_finetuner(
         log,
         fsdp_broadcast_state=not has_checkpoint,
         fsdp_reshard_after_forward=config.fsdp_reshard_after_forward,
-        fsdp_mixed_precision_dtype=config.dtype if config.mixed_precision else None,
+        fsdp_mixed_precision_dtype=mp_dtype,
         fsdp_fp32_reduce=True,
         fsdp_wrap_granularity=config.fsdp_wrap_granularity,
     )
@@ -422,6 +430,12 @@ def load_preference_finetuner(
             "The learning rate scheduler cannot be created. See nested exception for details."
         ) from ex
 
+    # TODO: Fix once we support static mixed precision on one device.
+    if config.mixed_precision == "static":
+        amp = root_gang.size == 1 or config.data_parallelism != "fsdp"
+    else:
+        amp = config.mixed_precision == "dynamic"
+
     # Initialize the trainer.
     return Trainer[PreferenceOptimizationBatch](
         unit=unit,
@@ -432,9 +446,9 @@ def load_preference_finetuner(
         dtype=config.dtype,
         optimizer=optimizer,
         lr_scheduler=lr_scheduler,
-        amp=config.mixed_precision,
         fp16_loss_scale=config.fp16_loss_scale,
         max_gradient_norm=config.max_gradient_norm,
+        amp=amp,
         max_num_steps=config.max_num_steps,
         max_num_data_epochs=config.max_num_data_epochs,
         checkpoint_manager=checkpoint_manager,
diff --git a/src/fairseq2/recipes/lm/text_generate.py b/src/fairseq2/recipes/lm/text_generate.py
index 188c59499..ca99890cf 100644
--- a/src/fairseq2/recipes/lm/text_generate.py
+++ b/src/fairseq2/recipes/lm/text_generate.py
@@ -75,6 +75,9 @@ class TextGenerateConfig:
     dtype: DataType = torch.bfloat16
     """The data type of the model."""
 
+    amp: bool = False
+    """If ``True``, runs evaluation with ``torch.amp``."""
+
     tensor_parallel_size: int = 1
     """The size of tensor parallelism."""
 
@@ -301,6 +304,8 @@ def load_text_generator(
         root_gang=root_gang,
         dp_gang=dp_gang,
         tp_gang=tp_gang,
+        dtype=config.dtype,
+        amp=config.amp,
         metrics_dir=output_dir.joinpath("metrics"),
         seed=seed,
         wall_watch=wall_watch,
diff --git a/src/fairseq2/recipes/mt/eval.py b/src/fairseq2/recipes/mt/eval.py
index a06c89a0f..ef6236b3b 100644
--- a/src/fairseq2/recipes/mt/eval.py
+++ b/src/fairseq2/recipes/mt/eval.py
@@ -81,6 +81,9 @@ class MTEvalConfig:
     dtype: DataType = torch.float16
     """The data type of the model."""
 
+    amp: bool = False
+    """If ``True``, runs evaluation with ``torch.amp``."""
+
     # Loss
     label_smoothing: float = 0.1
     """The amount of label smoothing to apply while computing the loss."""
@@ -308,6 +311,8 @@ def load_mt_evaluator(
         units=units,
         data_readers=data_readers,
         root_gang=gang,
+        dtype=config.dtype,
+        amp=config.amp,
         tb_dir=output_dir.joinpath("tb"),
         metrics_dir=output_dir.joinpath("metrics"),
         seed=seed,
diff --git a/src/fairseq2/recipes/mt/train.py b/src/fairseq2/recipes/mt/train.py
index 787e02a2e..d6b0ab01e 100644
--- a/src/fairseq2/recipes/mt/train.py
+++ b/src/fairseq2/recipes/mt/train.py
@@ -424,6 +424,9 @@ def load_mt_trainer(config: MTTrainConfig, output_dir: Path) -> Trainer[Seq2SeqB
 
         valid_data_readers.append(valid_data_reader)
 
+    # TODO: Fix once we support static mixed precision on one device.
+    amp = gang.size == 1 or config.data_parallelism != "fsdp"
+
     # Initialize the trainer.
     return Trainer[Seq2SeqBatch](
         unit=unit,
@@ -434,6 +437,7 @@ def load_mt_trainer(config: MTTrainConfig, output_dir: Path) -> Trainer[Seq2SeqB
         lr_scheduler=lr_scheduler,
         fp16_loss_scale=config.fp16_loss_scale,
         max_gradient_norm=config.max_gradient_norm,
+        amp=amp,
         max_num_steps=config.max_num_steps,
         max_num_data_epochs=config.max_num_data_epochs,
         score_metric_name="chrf" if config.compute_bleu_chrf else None,
diff --git a/src/fairseq2/recipes/mt/translate.py b/src/fairseq2/recipes/mt/translate.py
index 48815f823..ed788cf3d 100644
--- a/src/fairseq2/recipes/mt/translate.py
+++ b/src/fairseq2/recipes/mt/translate.py
@@ -78,6 +78,9 @@ class TextTranslateConfig:
     dtype: DataType = torch.float16
     """The data type of the model."""
 
+    amp: bool = False
+    """If ``True``, runs evaluation with ``torch.amp``."""
+
     # Generation
     generator: str = "beam_search"
     """The sequence generator."""
@@ -245,6 +248,8 @@ def load_text_translator(
         unit=unit,
         data_reader=data_reader,
         root_gang=gang,
+        dtype=config.dtype,
+        amp=config.amp,
         metrics_dir=output_dir.joinpath("metrics"),
         seed=seed,
         wall_watch=wall_watch,
diff --git a/src/fairseq2/recipes/trainer.py b/src/fairseq2/recipes/trainer.py
index 71ca17e77..3491e15e1 100644
--- a/src/fairseq2/recipes/trainer.py
+++ b/src/fairseq2/recipes/trainer.py
@@ -124,6 +124,7 @@ class Trainer(StatefulObjectBag, Generic[BatchT]):
     _lr_scheduler: LRScheduler
     _loss_scaler: DynamicLossScaler
     _max_gradient_norm: float | None
+    _amp: bool
     _step_nr: int
     _max_num_steps: int | None
     _data_epoch_nr: int
@@ -179,13 +180,13 @@ def __init__(
         optimizer: Optimizer,
         checkpoint_manager: CheckpointManager,
         wall_watch: Stopwatch,
-        dtype: DataType = torch.float32,
         dp_gang: Gang | None = None,
         tp_gang: Gang | None = None,
+        dtype: DataType = torch.float32,
         lr_scheduler: LRScheduler | None = None,
-        amp: bool = True,
         fp16_loss_scale: tuple[float, float] = (128.0, 0.0001),
         max_gradient_norm: float | None = None,
+        amp: bool = False,
         max_num_steps: int | None = None,
         max_num_data_epochs: int | None = None,
         score_metric_name: str | None = None,
@@ -229,7 +230,7 @@ def __init__(
         :param wall_watch:
             The stopwatch to track process wall-time.
         :param dtype:
-            The data type to train with.
+            The data type of the model.
         :param dp_gang:
             The data parallel gang. If ``None``, ``gang`` will be used.
         :param tp_gang:
@@ -237,9 +238,7 @@ def __init__(
         :param lr_scheduler:
             The learning rate scheduler.
         :param amp:
-            If ``True``, enables automatic mixed precision (i.e. ``torch.amp``).
-            If the model is trained with mixed precision DDP or FSDP, it takes
-            precedence over this setting and no autocast will be applied.
+            If ``True``, enables ``torch.amp``.
         :param fp16_loss_scale:
             The initial and minimum loss scale for fp16 training.
         :param max_gradient_norm:
@@ -345,8 +344,6 @@ def __init__(
 
         self._lr_scheduler = lr_scheduler or NoopLR(optimizer)
 
-        self._amp = amp
-
         fp16_init_scale, fp16_min_scale = fp16_loss_scale
 
         self._loss_scaler = DynamicLossScaler(
@@ -361,6 +358,8 @@ def __init__(
 
         self._max_gradient_norm = max_gradient_norm
 
+        self._amp = amp
+
         self.register_stateful("_step_nr", 0)
 
         if max_num_steps == 0:
@@ -832,11 +831,6 @@ def _maybe_autocast(self) -> AbstractContextManager[None]:
         if self._dtype == torch.float32 or not self._amp:
             return nullcontext()
 
-        if self._model.training and isinstance(self._model, (DDP, FSDP)):
-            mp = self._model.mixed_precision
-            if mp is not None and mp.param_dtype is not None:
-                return nullcontext()
-
         return torch.autocast(device_type=self._dp_gang.device.type, dtype=self._dtype)
 
     def _should_publish_metrics(self) -> bool:
@@ -895,9 +889,9 @@ def _should_validate(self) -> bool:
     def _validate(self) -> None:
         log.info("Starting validation after step {}.", self._step_nr)
 
-        with summon_fsdp_for_validation(self._model):
-            self._model.eval()
+        self._model.eval()
 
+        with summon_fsdp_for_validation(self._model):
             unit_scores = []
 
             for unit, data_reader in zip(self._valid_units, self._valid_data_readers):
@@ -910,7 +904,7 @@ def _validate(self) -> None:
 
             self._valid_score = self._compute_valid_score(unit_scores)
 
-            self._model.train()
+        self._model.train()
 
         log.info("Validation complete.")
 
diff --git a/src/fairseq2/recipes/utils/sweep.py b/src/fairseq2/recipes/utils/sweep.py
index e1149c0d2..5db10cf74 100644
--- a/src/fairseq2/recipes/utils/sweep.py
+++ b/src/fairseq2/recipes/utils/sweep.py
@@ -74,6 +74,7 @@ def __init__(self, *, allow_set: set[object] | None = None) -> None:
             "max_num_steps",
             "max_num_tokens",
             "max_seq_len",
+            "mixed_precision",
            "model",
             "model_arch",
             "model_config",
diff --git a/src/fairseq2/recipes/wav2vec2/asr/eval.py b/src/fairseq2/recipes/wav2vec2/asr/eval.py
index 6017e0971..2aea92c49 100644
--- a/src/fairseq2/recipes/wav2vec2/asr/eval.py
+++ b/src/fairseq2/recipes/wav2vec2/asr/eval.py
@@ -80,6 +80,9 @@ class Wav2Vec2AsrEvalConfig:
     dtype: DataType = torch.float16
     """The data type of the model."""
 
+    amp: bool = False
+    """If ``True``, runs evaluation with ``torch.amp``."""
+
     # Misc
     seed: int = 2
     """The random number generator seed to use."""
@@ -231,6 +234,8 @@ def load_wav2vec2_asr_evaluator(
         units=[unit],
         data_readers=[data_reader],
         root_gang=gang,
+        dtype=config.dtype,
+        amp=config.amp,
         tb_dir=output_dir.joinpath("tb"),
         metrics_dir=output_dir.joinpath("metrics"),
         seed=seed,
diff --git a/src/fairseq2/recipes/wav2vec2/asr/train.py b/src/fairseq2/recipes/wav2vec2/asr/train.py
index 7bf4bfc40..b5550b837 100644
--- a/src/fairseq2/recipes/wav2vec2/asr/train.py
+++ b/src/fairseq2/recipes/wav2vec2/asr/train.py
@@ -465,6 +465,9 @@ def load_wav2vec2_asr_trainer(
 
     seed += 1
 
+    # TODO: Fix once we support static mixed precision on one device.
+    amp = gang.size == 1 or config.data_parallelism != "fsdp"
+
     # Initialize the trainer.
     return Trainer[Seq2SeqBatch](
         unit=unit,
@@ -475,6 +478,7 @@ def load_wav2vec2_asr_trainer(
         lr_scheduler=lr_scheduler,
         fp16_loss_scale=config.fp16_loss_scale,
         max_gradient_norm=config.max_gradient_norm,
+        amp=amp,
         max_num_steps=config.max_num_steps,
         max_num_data_epochs=config.max_num_data_epochs,
         score_metric_name="wer",
diff --git a/src/fairseq2/recipes/wav2vec2/eval.py b/src/fairseq2/recipes/wav2vec2/eval.py
index 179b6caa3..08ccdf61d 100644
--- a/src/fairseq2/recipes/wav2vec2/eval.py
+++ b/src/fairseq2/recipes/wav2vec2/eval.py
@@ -75,6 +75,9 @@ class Wav2Vec2EvalConfig:
     dtype: DataType = torch.float16
     """The data type of the model."""
 
+    amp: bool = False
+    """If ``True``, runs evaluation with ``torch.amp``."""
+
     # Loss
     diversity_loss_weight: float = 0.1
     """The weight of the diversity loss."""
@@ -199,6 +202,8 @@ def load_wav2vec2_evaluator(
         units=[unit],
         data_readers=[data_reader],
         root_gang=gang,
+        dtype=config.dtype,
+        amp=config.amp,
         tb_dir=output_dir.joinpath("tb"),
         metrics_dir=output_dir.joinpath("metrics"),
         seed=seed,
diff --git a/src/fairseq2/recipes/wav2vec2/train.py b/src/fairseq2/recipes/wav2vec2/train.py
index 6391a5f92..124b88e11 100644
--- a/src/fairseq2/recipes/wav2vec2/train.py
+++ b/src/fairseq2/recipes/wav2vec2/train.py
@@ -368,6 +368,9 @@ def load_wav2vec2_trainer(
 
     seed += 1
 
+    # TODO: Fix once we support static mixed precision on one device.
+    amp = gang.size == 1 or config.data_parallelism != "fsdp"
+
     # Initialize the trainer.
     return Trainer[SequenceBatch](
         unit=unit,
@@ -378,6 +381,7 @@ def load_wav2vec2_trainer(
         lr_scheduler=lr_scheduler,
         fp16_loss_scale=config.fp16_loss_scale,
         max_gradient_norm=config.max_gradient_norm,
+        amp=amp,
         max_num_steps=config.max_num_steps,
         max_num_data_epochs=config.max_num_data_epochs,
         score_metric_name="loss",
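Note on the shared pattern above: the `_maybe_autocast` helpers added to `Evaluator` and `Generator` are copies of the now-simplified `Trainer._maybe_autocast`. Below is a minimal, self-contained sketch of that gating logic; the free-standing `maybe_autocast` helper and its signature are illustrative, not fairseq2 API.

# Sketch of the autocast gating used above: full-precision runs and runs
# with `amp` disabled execute unmodified; otherwise each batch is wrapped
# in `torch.autocast` for the configured dtype and device.
from contextlib import AbstractContextManager, nullcontext

import torch


def maybe_autocast(
    dtype: torch.dtype, amp: bool, device: torch.device
) -> AbstractContextManager[None]:
    if dtype == torch.float32 or not amp:
        return nullcontext()

    return torch.autocast(device_type=device.type, dtype=dtype)


# Usage, mirroring Evaluator._evaluate_unit():
# with maybe_autocast(torch.bfloat16, amp=True, device=torch.device("cuda")):
#     unit(batch)

Returning the context manager rather than entering it keeps every call site a single `with` statement, whether or not autocast is active.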
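The tri-state `mixed_precision` option resolves into three independent knobs in the recipes above: the dtype the model is materialized in, the dtype handed to FSDP's mixed-precision wrapper, and the `amp` flag passed to `Trainer`. A condensed sketch of that resolution, mirroring `load_instruction_finetuner`; the `resolve_precision` helper and its signature are hypothetical, while the `config` fields are the ones shown in the diff.

from __future__ import annotations

import torch


def resolve_precision(
    config, root_gang_size: int
) -> tuple[torch.dtype, torch.dtype | None, bool]:
    # 'none': materialize and train the model itself in `config.dtype`;
    # 'static'/'dynamic': keep the parameters in full precision.
    model_dtype = config.dtype if config.mixed_precision == "none" else torch.float32

    # 'static': let FSDP's mixed-precision wrapper cast per forward/backward.
    mp_dtype = config.dtype if config.mixed_precision == "static" else None

    # Per the TODO in the diff: on a single device or under DDP, where the
    # FSDP wrapper does not apply, 'static' falls back to torch.amp.
    if config.mixed_precision == "static":
        amp = root_gang_size == 1 or config.data_parallelism != "fsdp"
    else:
        amp = config.mixed_precision == "dynamic"

    return model_dtype, mp_dtype, amp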
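The `amp` flag added to the evaluation and generation configs defaults to ``False``, so existing runs keep their current numerics. A hypothetical opt-in override, assuming `TextGenerateConfig` can be constructed from its dataclass defaults as shown in the hunk above:

import torch

from fairseq2.recipes.lm.text_generate import TextGenerateConfig

config = TextGenerateConfig()
config.dtype = torch.bfloat16  # matches the default shown in the diff
config.amp = True  # Generator then wraps each unit call in torch.autocast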