Fix FSDP meta initialization for wav2vec2 #817

Merged · 1 commit · Sep 27, 2024
12 changes: 9 additions & 3 deletions src/fairseq2/nn/fsdp.py
@@ -64,9 +64,7 @@ def to_fsdp(
     :param ignored_param_names:
         The ignored parameter names. Can contain regular expressions.
     :param skip_init:
-        If ``True``, skips initializing the parameters and buffers moved from
-        the meta device onto the device of ``gang``. Only relevant if ``module``
-        resides on the meta device.
+        Not used.
     :param broadcast_state:
         If ``True``, each FSDP module will broadcast its parameters and buffers
         from rank 0 to ensure that they are replicated across all processes.
@@ -115,10 +113,18 @@ def to_fsdp(
     if memory_policy is None:
         memory_policy = FSDP_STANDARD_MEMORY_POLICY
 
+    if skip_init:
+        log.warning("`skip_init` parameter has no effect and will be removed in a future release.")  # fmt: skip
+
     param_init_fn = None
 
     module_device = infer_device(module)
     if module_device.type == "meta":
+        if gang.rank == 0:
+            skip_init = not broadcast_state
+        else:
+            skip_init = True
+
         param_init_fn = FSDPParameterInitializer(gang.device, skip_init)
 
     if mixed_precision_dtype is None:
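The net effect: `to_fsdp` now derives skip_init on its own whenever `module` resides on the meta device. Rank 0 skips initialization only when it is not going to broadcast its state (for example, when a checkpoint will supply the weights), and every other rank always skips it, since its parameters are overwritten by the broadcast from rank 0 or by the checkpoint load. A minimal sketch of that decision as a standalone, hypothetical helper (not the library code itself):

    # Hypothetical helper mirroring the rank-aware skip_init decision above;
    # `rank` and `broadcast_state` correspond to gang.rank and the
    # broadcast_state argument of to_fsdp.
    def should_skip_init(rank: int, broadcast_state: bool) -> bool:
        if rank == 0:
            # Rank 0 needs real initial weights only if it must broadcast them
            # to the other ranks; otherwise a checkpoint supplies the state.
            return not broadcast_state
        # Non-zero ranks always skip: their parameters arrive via the rank-0
        # broadcast or via a (sharded) checkpoint load.
        return True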
1 change: 0 additions & 1 deletion src/fairseq2/recipes/lm/instruction_finetune.py
@@ -357,7 +357,6 @@ def load_instruction_finetuner(
         dp_gang,
         config.data_parallelism,
         log,
-        fsdp_skip_init=True,
         fsdp_broadcast_state=not has_checkpoint,
         fsdp_reshard_after_forward=config.fsdp_reshard_after_forward,
         fsdp_mixed_precision_dtype=config.dtype if config.mixed_precision else None,
1 change: 0 additions & 1 deletion src/fairseq2/recipes/lm/preference_finetune/recipe.py
@@ -338,7 +338,6 @@ def load_preference_finetuner(
         dp_gang,
         config.data_parallelism,
         log,
-        fsdp_skip_init=True,
         fsdp_broadcast_state=not has_checkpoint,
         fsdp_reshard_after_forward=config.fsdp_reshard_after_forward,
         fsdp_mixed_precision_dtype=config.dtype if config.mixed_precision else None,
1 change: 0 additions & 1 deletion src/fairseq2/recipes/mt/train.py
@@ -292,7 +292,6 @@ def load_mt_trainer(config: MTTrainConfig, output_dir: Path) -> Trainer[Seq2SeqB
         gang,
         config.data_parallelism,
         log,
-        fsdp_skip_init=has_checkpoint,
         fsdp_broadcast_state=not has_checkpoint,
         fsdp_mixed_precision_dtype=config.dtype,
         fsdp_fp32_reduce=True,
1 change: 0 additions & 1 deletion src/fairseq2/recipes/wav2vec2/asr/train.py
@@ -371,7 +371,6 @@ def load_wav2vec2_asr_trainer(
         config.data_parallelism,
         log,
         ddp_find_unused_parameters=config.freeze_encoder_for_n_steps > 0,
-        fsdp_skip_init=True,
         fsdp_broadcast_state=not has_checkpoint,
         fsdp_mixed_precision_dtype=config.dtype,
         fsdp_fp32_reduce=True,
7 changes: 2 additions & 5 deletions src/fairseq2/recipes/wav2vec2/train.py
@@ -23,7 +23,7 @@
 from fairseq2.logging import get_log_writer
 from fairseq2.models import create_model
 from fairseq2.models.sequence import SequenceBatch
-from fairseq2.models.wav2vec2 import Wav2Vec2Config, Wav2Vec2Model
+from fairseq2.models.wav2vec2 import Wav2Vec2Model
 from fairseq2.optim import AdamWConfig, create_optimizer
 from fairseq2.optim.lr_scheduler import PolynomialDecayLRConfig, create_lr_scheduler
 from fairseq2.recipes.trainer import AbstractTrainUnit, Trainer
@@ -186,9 +186,7 @@ class Wav2Vec2TrainConfig:
 def _base_960h() -> Wav2Vec2TrainConfig:
     config = Wav2Vec2TrainConfig()
 
-    assert isinstance(config.model_config, Wav2Vec2Config)
-
-    config.model_config.encoder_config.first_pass_dropout_p = 0.1
+    config.model_config = {"encoder_config": {"first_pass_dropout_p": 0.1}}
 
     return config
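The `_base_960h` preset now records its tweak as a plain nested override dict on `config.model_config` instead of importing `Wav2Vec2Config` and mutating a concrete instance. As a rough illustration of the pattern only, assuming a dataclass-style config (the helper below is made up, not fairseq2 API, and the library's actual merge logic may differ):

    from dataclasses import is_dataclass
    from typing import Any

    # Illustrative only: recursively fold a nested override dict such as
    # {"encoder_config": {"first_pass_dropout_p": 0.1}} into a structured,
    # dataclass-style config object.
    def apply_overrides(config: Any, overrides: dict[str, Any]) -> Any:
        for name, value in overrides.items():
            current = getattr(config, name)
            if isinstance(value, dict) and is_dataclass(current):
                apply_overrides(current, value)  # descend into nested sections
            else:
                setattr(config, name, value)
        return config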

@@ -284,7 +282,6 @@ def load_wav2vec2_trainer(
         gang,
         config.data_parallelism,
         log,
-        fsdp_skip_init=True,
         fsdp_broadcast_state=not has_checkpoint,
         fsdp_mixed_precision_dtype=config.dtype,
         fsdp_fp32_reduce=True,