From 74cf2609c9fbd624092aa6c09e0ff2a65175b9c4 Mon Sep 17 00:00:00 2001
From: Can Balioglu
Date: Fri, 27 Sep 2024 20:16:56 +0000
Subject: [PATCH] Fix FSDP meta initialization for wav2vec2

---
 src/fairseq2/nn/fsdp.py                         | 12 +++++++++---
 src/fairseq2/recipes/lm/instruction_finetune.py |  1 -
 .../recipes/lm/preference_finetune/recipe.py    |  1 -
 src/fairseq2/recipes/mt/train.py                |  1 -
 src/fairseq2/recipes/wav2vec2/asr/train.py      |  1 -
 src/fairseq2/recipes/wav2vec2/train.py          |  7 ++-----
 6 files changed, 11 insertions(+), 12 deletions(-)

diff --git a/src/fairseq2/nn/fsdp.py b/src/fairseq2/nn/fsdp.py
index 124ab24c9..28d094d50 100644
--- a/src/fairseq2/nn/fsdp.py
+++ b/src/fairseq2/nn/fsdp.py
@@ -64,9 +64,7 @@ def to_fsdp(
     :param ignored_param_names:
         The ignored parameter names. Can contain regular expressions.
     :param skip_init:
-        If ``True``, skips initializing the parameters and buffers moved from
-        the meta device onto the device of ``gang``. Only relevant if ``module``
-        resides on the meta device.
+        Deprecated; has no effect and will be removed in a future release.
     :param broadcast_state:
         If ``True``, each FSDP module will broadcast its parameters and buffers
         from rank 0 to ensure that they are replicated across all processes.
@@ -115,10 +113,18 @@ def to_fsdp(
     if memory_policy is None:
         memory_policy = FSDP_STANDARD_MEMORY_POLICY
 
+    if skip_init:
+        log.warning("`skip_init` parameter has no effect and will be removed in a future release.")  # fmt: skip
+
     param_init_fn = None
 
     module_device = infer_device(module)
     if module_device.type == "meta":
+        if gang.rank == 0:
+            skip_init = not broadcast_state
+        else:
+            skip_init = True
+
         param_init_fn = FSDPParameterInitializer(gang.device, skip_init)
 
     if mixed_precision_dtype is None:
diff --git a/src/fairseq2/recipes/lm/instruction_finetune.py b/src/fairseq2/recipes/lm/instruction_finetune.py
index 8043ab5f2..dfa8e9c21 100644
--- a/src/fairseq2/recipes/lm/instruction_finetune.py
+++ b/src/fairseq2/recipes/lm/instruction_finetune.py
@@ -357,7 +357,6 @@ def load_instruction_finetuner(
         dp_gang,
         config.data_parallelism,
         log,
-        fsdp_skip_init=True,
         fsdp_broadcast_state=not has_checkpoint,
         fsdp_reshard_after_forward=config.fsdp_reshard_after_forward,
         fsdp_mixed_precision_dtype=config.dtype if config.mixed_precision else None,
diff --git a/src/fairseq2/recipes/lm/preference_finetune/recipe.py b/src/fairseq2/recipes/lm/preference_finetune/recipe.py
index 618ad2660..f99b35252 100644
--- a/src/fairseq2/recipes/lm/preference_finetune/recipe.py
+++ b/src/fairseq2/recipes/lm/preference_finetune/recipe.py
@@ -338,7 +338,6 @@ def load_preference_finetuner(
         dp_gang,
         config.data_parallelism,
         log,
-        fsdp_skip_init=True,
         fsdp_broadcast_state=not has_checkpoint,
         fsdp_reshard_after_forward=config.fsdp_reshard_after_forward,
         fsdp_mixed_precision_dtype=config.dtype if config.mixed_precision else None,
diff --git a/src/fairseq2/recipes/mt/train.py b/src/fairseq2/recipes/mt/train.py
index 4353b6b4e..787e02a2e 100644
--- a/src/fairseq2/recipes/mt/train.py
+++ b/src/fairseq2/recipes/mt/train.py
@@ -292,7 +292,6 @@ def load_mt_trainer(config: MTTrainConfig, output_dir: Path) -> Trainer[Seq2SeqB
         gang,
         config.data_parallelism,
         log,
-        fsdp_skip_init=has_checkpoint,
         fsdp_broadcast_state=not has_checkpoint,
         fsdp_mixed_precision_dtype=config.dtype,
         fsdp_fp32_reduce=True,
diff --git a/src/fairseq2/recipes/wav2vec2/asr/train.py b/src/fairseq2/recipes/wav2vec2/asr/train.py
index 2ed3157e2..63872485b 100644
--- a/src/fairseq2/recipes/wav2vec2/asr/train.py
+++ b/src/fairseq2/recipes/wav2vec2/asr/train.py
@@ -371,7 +371,6 @@ def load_wav2vec2_asr_trainer(
         config.data_parallelism,
         log,
         ddp_find_unused_parameters=config.freeze_encoder_for_n_steps > 0,
-        fsdp_skip_init=True,
         fsdp_broadcast_state=not has_checkpoint,
         fsdp_mixed_precision_dtype=config.dtype,
         fsdp_fp32_reduce=True,
diff --git a/src/fairseq2/recipes/wav2vec2/train.py b/src/fairseq2/recipes/wav2vec2/train.py
index 49f958cdb..3918ef817 100644
--- a/src/fairseq2/recipes/wav2vec2/train.py
+++ b/src/fairseq2/recipes/wav2vec2/train.py
@@ -23,7 +23,7 @@
 from fairseq2.logging import get_log_writer
 from fairseq2.models import create_model
 from fairseq2.models.sequence import SequenceBatch
-from fairseq2.models.wav2vec2 import Wav2Vec2Config, Wav2Vec2Model
+from fairseq2.models.wav2vec2 import Wav2Vec2Model
 from fairseq2.optim import AdamWConfig, create_optimizer
 from fairseq2.optim.lr_scheduler import PolynomialDecayLRConfig, create_lr_scheduler
 from fairseq2.recipes.trainer import AbstractTrainUnit, Trainer
@@ -186,9 +186,7 @@ class Wav2Vec2TrainConfig:
 def _base_960h() -> Wav2Vec2TrainConfig:
     config = Wav2Vec2TrainConfig()
 
-    assert isinstance(config.model_config, Wav2Vec2Config)
-
-    config.model_config.encoder_config.first_pass_dropout_p = 0.1
+    config.model_config = {"encoder_config": {"first_pass_dropout_p": 0.1}}
 
     return config
 
@@ -284,7 +282,6 @@ def load_wav2vec2_trainer(
         gang,
         config.data_parallelism,
         log,
-        fsdp_skip_init=True,
         fsdp_broadcast_state=not has_checkpoint,
         fsdp_mixed_precision_dtype=config.dtype,
         fsdp_fp32_reduce=True,
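
For reference, a minimal standalone sketch of the rank-dependent decision that the fsdp.py hunk above introduces for modules on the meta device. The helper name decide_skip_init is hypothetical (not part of fairseq2); it only mirrors the new branch inside `to_fsdp`.

    def decide_skip_init(rank: int, broadcast_state: bool) -> bool:
        """Mirror of the new meta-device branch in to_fsdp().

        Rank 0 must materialize real parameter values whenever it is expected
        to broadcast the module state to the other ranks; every other rank can
        skip the (potentially expensive) initialization because its tensors
        are overwritten by the broadcast from rank 0 anyway.
        """
        if rank == 0:
            return not broadcast_state

        return True


    # With state broadcasting enabled, only rank 0 initializes.
    assert decide_skip_init(rank=0, broadcast_state=True) is False
    assert decide_skip_init(rank=1, broadcast_state=True) is True

    # Without broadcasting (e.g. when a checkpoint will be restored), all
    # ranks skip initialization.
    assert decide_skip_init(rank=0, broadcast_state=False) is True

Since the recipes already pass `fsdp_broadcast_state=not has_checkpoint`, `to_fsdp` can now derive the skip decision itself, which is why the explicit `fsdp_skip_init` arguments are dropped from the recipe call sites.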