Fix FSDP meta initialization for wav2vec2
cbalioglu committed Sep 27, 2024
1 parent 7089eb4 commit 74cf260
Showing 6 changed files with 11 additions and 12 deletions.
12 changes: 9 additions & 3 deletions src/fairseq2/nn/fsdp.py
@@ -64,9 +64,7 @@ def to_fsdp(
     :param ignored_param_names:
         The ignored parameter names. Can contain regular expressions.
     :param skip_init:
-        If ``True``, skips initializing the parameters and buffers moved from
-        the meta device onto the device of ``gang``. Only relevant if ``module``
-        resides on the meta device.
+        Not used.
     :param broadcast_state:
         If ``True``, each FSDP module will broadcast its parameters and buffers
         from rank 0 to ensure that they are replicated across all processes.

@@ -115,10 +113,18 @@ def to_fsdp(
     if memory_policy is None:
         memory_policy = FSDP_STANDARD_MEMORY_POLICY
 
+    if skip_init:
+        log.warning("`skip_init` parameter has no effect and will be removed in a future release.")  # fmt: skip
+
     param_init_fn = None
 
     module_device = infer_device(module)
     if module_device.type == "meta":
+        if gang.rank == 0:
+            skip_init = not broadcast_state
+        else:
+            skip_init = True
+
         param_init_fn = FSDPParameterInitializer(gang.device, skip_init)
 
     if mixed_precision_dtype is None:
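The heart of the fix is the per-rank decision added to `to_fsdp` above: when `module` still resides on the meta device, rank 0 initializes real weights only if it is about to broadcast them, while every other rank skips initialization. A minimal sketch of that decision, mirroring the hunk (`should_skip_init` is a hypothetical helper for illustration; fairseq2 inlines this logic rather than exposing it):

# Sketch only; mirrors the semantics of the hunk above.
def should_skip_init(rank: int, broadcast_state: bool) -> bool:
    """Return ``True`` if this rank may leave its materialized parameters uninitialized."""
    if rank == 0:
        # Rank 0 needs real weights only when it will broadcast them to the
        # other ranks; otherwise a checkpoint load overwrites them anyway.
        return not broadcast_state
    # Non-zero ranks never initialize: they either receive rank 0's broadcast
    # or load their own checkpoint shards.
    return True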
1 change: 0 additions & 1 deletion src/fairseq2/recipes/lm/instruction_finetune.py
@@ -357,7 +357,6 @@ def load_instruction_finetuner(
         dp_gang,
         config.data_parallelism,
         log,
-        fsdp_skip_init=True,
         fsdp_broadcast_state=not has_checkpoint,
         fsdp_reshard_after_forward=config.fsdp_reshard_after_forward,
         fsdp_mixed_precision_dtype=config.dtype if config.mixed_precision else None,
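With that decision moved inside `to_fsdp`, this and the following recipes no longer pass `fsdp_skip_init`. A sketch of the caller-side effect, assuming these keyword arguments are forwarded to `to_fsdp` (`setup_model` is a hypothetical stand-in; the actual wrapper is not shown in the diff):

# Before: each recipe hard-coded skip-init, which presumably left rank 0 of a
# fresh (no-checkpoint) run broadcasting uninitialized meta-device weights.
setup_model(..., fsdp_skip_init=True, fsdp_broadcast_state=not has_checkpoint)

# After: the recipe states only what it knows, namely whether a checkpoint
# exists; `to_fsdp` derives the per-rank skip-init from `broadcast_state`.
setup_model(..., fsdp_broadcast_state=not has_checkpoint)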
1 change: 0 additions & 1 deletion src/fairseq2/recipes/lm/preference_finetune/recipe.py
@@ -338,7 +338,6 @@ def load_preference_finetuner(
         dp_gang,
         config.data_parallelism,
         log,
-        fsdp_skip_init=True,
         fsdp_broadcast_state=not has_checkpoint,
         fsdp_reshard_after_forward=config.fsdp_reshard_after_forward,
         fsdp_mixed_precision_dtype=config.dtype if config.mixed_precision else None,
1 change: 0 additions & 1 deletion src/fairseq2/recipes/mt/train.py
@@ -292,7 +292,6 @@ def load_mt_trainer(config: MTTrainConfig, output_dir: Path) -> Trainer[Seq2SeqB
         gang,
         config.data_parallelism,
         log,
-        fsdp_skip_init=has_checkpoint,
         fsdp_broadcast_state=not has_checkpoint,
         fsdp_mixed_precision_dtype=config.dtype,
         fsdp_fp32_reduce=True,
1 change: 0 additions & 1 deletion src/fairseq2/recipes/wav2vec2/asr/train.py
@@ -371,7 +371,6 @@ def load_wav2vec2_asr_trainer(
         config.data_parallelism,
         log,
         ddp_find_unused_parameters=config.freeze_encoder_for_n_steps > 0,
-        fsdp_skip_init=True,
         fsdp_broadcast_state=not has_checkpoint,
         fsdp_mixed_precision_dtype=config.dtype,
         fsdp_fp32_reduce=True,
7 changes: 2 additions & 5 deletions src/fairseq2/recipes/wav2vec2/train.py
@@ -23,7 +23,7 @@
 from fairseq2.logging import get_log_writer
 from fairseq2.models import create_model
 from fairseq2.models.sequence import SequenceBatch
-from fairseq2.models.wav2vec2 import Wav2Vec2Config, Wav2Vec2Model
+from fairseq2.models.wav2vec2 import Wav2Vec2Model
 from fairseq2.optim import AdamWConfig, create_optimizer
 from fairseq2.optim.lr_scheduler import PolynomialDecayLRConfig, create_lr_scheduler
 from fairseq2.recipes.trainer import AbstractTrainUnit, Trainer

@@ -186,9 +186,7 @@ class Wav2Vec2TrainConfig:
 def _base_960h() -> Wav2Vec2TrainConfig:
     config = Wav2Vec2TrainConfig()
 
-    assert isinstance(config.model_config, Wav2Vec2Config)
-
-    config.model_config.encoder_config.first_pass_dropout_p = 0.1
+    config.model_config = {"encoder_config": {"first_pass_dropout_p": 0.1}}
 
     return config

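The dict form above drops the `Wav2Vec2Config` import and the in-place mutation; the overrides are presumably deep-merged into the architecture's default config when `create_model` builds the model. A minimal sketch of such a merge, assuming dataclass-based configs (`apply_overrides` is a hypothetical helper, not fairseq2 API):

from dataclasses import is_dataclass
from typing import Any

def apply_overrides(config: Any, overrides: dict[str, Any]) -> None:
    """Recursively apply a nested dict of overrides to a dataclass config."""
    for name, value in overrides.items():
        current = getattr(config, name)
        if isinstance(value, dict) and is_dataclass(current):
            # Descend into nested sub-configs such as ``encoder_config``.
            apply_overrides(current, value)
        else:
            setattr(config, name, value)

# Mirrors the diff above: only one field deviates from the base architecture.
# apply_overrides(default_config, {"encoder_config": {"first_pass_dropout_p": 0.1}})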
@@ -284,7 +282,6 @@ def load_wav2vec2_trainer(
         gang,
         config.data_parallelism,
         log,
-        fsdp_skip_init=True,
         fsdp_broadcast_state=not has_checkpoint,
         fsdp_mixed_precision_dtype=config.dtype,
         fsdp_fp32_reduce=True,
