From b34006f7d91c94616ecdc21705365a45e37afeba Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Fri, 6 Dec 2024 12:16:25 -0800 Subject: [PATCH] ... --- tests/recipes/test_ppo_full_finetune_single_device.py | 4 ++-- torchtune/training/checkpointing/_utils.py | 6 ++---- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/tests/recipes/test_ppo_full_finetune_single_device.py b/tests/recipes/test_ppo_full_finetune_single_device.py index 9fc4dd84fe..36352cb0f1 100644 --- a/tests/recipes/test_ppo_full_finetune_single_device.py +++ b/tests/recipes/test_ppo_full_finetune_single_device.py @@ -239,7 +239,7 @@ def test_training_state_on_resume(self, tmpdir, monkeypatch): ref_policy_checkpointer.checkpoint_files=[{policy_ckpt_path}]\ value_checkpointer.checkpoint_dir='{ckpt_dir}' \ - value_checkpointer.checkpoint_files=[{os.path.join(epoch_folder_minus_one, model_ckpt_fname)}]\ + value_checkpointer.checkpoint_files=[{os.path.join(value_tmpdir, epoch_folder_minus_one, model_ckpt_fname)}]\ value_checkpointer.output_dir={value_tmpdir} \ reward_checkpointer.checkpoint_dir='{ckpt_dir}' \ @@ -368,7 +368,7 @@ def test_training_state_on_resume_with_optimizer_in_bwd(self, tmpdir, monkeypatc ref_policy_checkpointer.checkpoint_files=[{policy_ckpt_path}]\ value_checkpointer.checkpoint_dir='{value_tmpdir}' \ - value_checkpointer.checkpoint_files=[{os.path.join(epoch_folder_minus_one, model_ckpt_fname)}]\ + value_checkpointer.checkpoint_files=[{os.path.join(value_tmpdir, epoch_folder_minus_one, model_ckpt_fname)}]\ value_checkpointer.output_dir={value_tmpdir} \ reward_checkpointer.checkpoint_dir='{ckpt_dir}' \ diff --git a/torchtune/training/checkpointing/_utils.py b/torchtune/training/checkpointing/_utils.py index 0f03f37635..82f60fc7d8 100644 --- a/torchtune/training/checkpointing/_utils.py +++ b/torchtune/training/checkpointing/_utils.py @@ -508,14 +508,12 @@ def get_model_checkpoint_path( def validate_checkpoint_files( checkpoint_files: Union[List[str]], - input_dir: Optional[Path] = None, + input_dir: Optional[Path], missing_ok=False, ) -> List[Path]: """ Validates that the checkpoint files exist and sorts based on ID. """ - if not input_dir: - input_dir = self._checkpoint_dir checkpoint_paths: List[Path] = [] for f in checkpoint_files: @@ -543,7 +541,7 @@ def validate_checkpoint_files( input_dir = checkpoint_dir # Case 2: Resuming from ckpt, but its full finetuning (no adapter) - elif has_adapter_checkpoint is None: + elif not has_adapter_checkpoint: input_dir = output_dir # Case 3: Resuming from ckpt and has an adapter.