Skip to content

Commit

Permalink
...
Browse files Browse the repository at this point in the history
  • Loading branch information
Felipe Mello committed Dec 6, 2024
1 parent 15498c7 commit b34006f
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 6 deletions.
4 changes: 2 additions & 2 deletions tests/recipes/test_ppo_full_finetune_single_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ def test_training_state_on_resume(self, tmpdir, monkeypatch):
ref_policy_checkpointer.checkpoint_files=[{policy_ckpt_path}]\
value_checkpointer.checkpoint_dir='{ckpt_dir}' \
value_checkpointer.checkpoint_files=[{os.path.join(epoch_folder_minus_one, model_ckpt_fname)}]\
value_checkpointer.checkpoint_files=[{os.path.join(value_tmpdir, epoch_folder_minus_one, model_ckpt_fname)}]\
value_checkpointer.output_dir={value_tmpdir} \
reward_checkpointer.checkpoint_dir='{ckpt_dir}' \
Expand Down Expand Up @@ -368,7 +368,7 @@ def test_training_state_on_resume_with_optimizer_in_bwd(self, tmpdir, monkeypatc
ref_policy_checkpointer.checkpoint_files=[{policy_ckpt_path}]\
value_checkpointer.checkpoint_dir='{value_tmpdir}' \
value_checkpointer.checkpoint_files=[{os.path.join(epoch_folder_minus_one, model_ckpt_fname)}]\
value_checkpointer.checkpoint_files=[{os.path.join(value_tmpdir, epoch_folder_minus_one, model_ckpt_fname)}]\
value_checkpointer.output_dir={value_tmpdir} \
reward_checkpointer.checkpoint_dir='{ckpt_dir}' \
Expand Down
6 changes: 2 additions & 4 deletions torchtune/training/checkpointing/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,14 +508,12 @@ def get_model_checkpoint_path(

def validate_checkpoint_files(
checkpoint_files: Union[List[str]],
input_dir: Optional[Path] = None,
input_dir: Optional[Path],
missing_ok=False,
) -> List[Path]:
"""
Validates that the checkpoint files exist and sorts based on ID.
"""
if not input_dir:
input_dir = self._checkpoint_dir

checkpoint_paths: List[Path] = []
for f in checkpoint_files:
Expand Down Expand Up @@ -543,7 +541,7 @@ def validate_checkpoint_files(
input_dir = checkpoint_dir

# Case 2: Resuming from ckpt, but its full finetuning (no adapter)
elif has_adapter_checkpoint is None:
elif not has_adapter_checkpoint:
input_dir = output_dir

# Case 3: Resuming from ckpt and has an adapter.
Expand Down

0 comments on commit b34006f

Please sign in to comment.