Skip to content

Commit

Permalink
fairseq_pretraining starting from checkpoint (#255)
Browse files Browse the repository at this point in the history
  • Loading branch information
AndreasPlt authored Jan 7, 2025
1 parent 666eb24 commit 4a1af01
Showing 1 changed file with 5 additions and 1 deletion.
Original file line number | Diff line number | Diff line change
Expand Up @@ -63,7 +63,7 @@ def get_fairseq_root(commit="e4a2e4e93efbcbaaae52a17ae6600beb2083fb33", fairseq_
return fairseq_root


def run_fairseq_pretraining(exp_name, commit, python_exe_hash_overwrite=None, **kwargs):
def run_fairseq_pretraining(exp_name, commit, python_exe_hash_overwrite=None, checkpoint=None, **kwargs):
"""
Runs a FairseqHydraTrainingJob to pretrain a wav2vec 2.0 model.
Expand All @@ -73,6 +73,8 @@ def run_fairseq_pretraining(exp_name, commit, python_exe_hash_overwrite=None, **
python_exe_hash_overwrite (Optional[str]): The hash overwrite for the fairseq_python_exe to use.
It should only be used to achieve compatibility with the previous setup structure and should be ignored
in all other cases.
checkpoint (Optional[tk.Path]): Path to a checkpoint from which to start training; it is written to the
fairseq config as `checkpoint.continue_once`. If None, the training will start from scratch.
**kwargs: Additional arguments to pass to the job. These will be used to overwrite the model configuration.
"""
# job requirements
Expand All @@ -93,6 +95,8 @@ def run_fairseq_pretraining(exp_name, commit, python_exe_hash_overwrite=None, **
# generate config
fairseq_args = get_fairseq_args(num_gpus=num_gpus)
fairseq_args["task"]["alignment"] = alignment
if checkpoint is not None:
fairseq_args["checkpoint"]["continue_once"] = checkpoint
for k, v in kwargs.items():
fairseq_args["model"][k] = v
fairseq_config = FairseqHydraConfig(fairseq_args)
Expand Down

0 comments on commit 4a1af01

Please sign in to comment.