diff --git a/tests/recipes/test_qat_lora_finetune_distributed.py b/tests/recipes/test_qat_lora_finetune_distributed.py
index 4d7c4b6899..6c43adcc73 100644
--- a/tests/recipes/test_qat_lora_finetune_distributed.py
+++ b/tests/recipes/test_qat_lora_finetune_distributed.py
@@ -27,6 +27,14 @@
     TOKENIZER_PATHS,
 )
 from torchtune import config
+
+from torchtune.training.checkpointing._utils import (
+    ADAPTER_MODEL_FNAME,
+    get_largest_iter_folder,
+    RECIPE_STATE_DIRNAME,
+    safe_torch_load,
+    SHARD_FNAME,
+)
 from torchtune.training.quantization import _torchao_0_7_supported
 
 
@@ -166,6 +174,8 @@ def test_training_state_on_resume(
         runpy.run_path(TUNE_PATH, run_name="__main__")
 
         # Resume training
+        epoch_folder = get_largest_iter_folder(tmpdir)
+        epoch_folder_minus_one = f"epoch_{int(epoch_folder.split('_')[-1]) - 1}"
         cmd_2 = f"""
         tune run --nnodes 1 --nproc_per_node 2 qat_lora_finetune_distributed \
             --config {config} \
@@ -173,10 +183,10 @@
             gradient_accumulation_steps=1 \
             output_dir={tmpdir} \
             checkpointer._component_={ckpt_component} \
-            checkpointer.checkpoint_dir={tmpdir} \
+            checkpointer.checkpoint_dir={ckpt_dir} \
             checkpointer.checkpoint_files=[{ckpt_path}]\
-            checkpointer.adapter_checkpoint={os.path.join(tmpdir, "adapter_0.pt")}
-            checkpointer.recipe_checkpoint={os.path.join(tmpdir, "recipe_state.pt")}
+            checkpointer.adapter_checkpoint={os.path.join(epoch_folder_minus_one, f"{ADAPTER_MODEL_FNAME}.pt")}
+            checkpointer.recipe_checkpoint={os.path.join(RECIPE_STATE_DIRNAME, "recipe_state.pt")}
             checkpointer.output_dir={tmpdir} \
             checkpointer.model_type={model_type.upper()} \
             tokenizer.path='{tokenizer_path}' \
@@ -254,8 +264,10 @@ def test_save_and_load_merged_weights(
         model = config.instantiate(OmegaConf.from_dotlist(base_config).model)
 
         # Load base model and trained adapter weights into LoRA model and call fwd
-        with open(f"{tmpdir}/adapter_1.pt", "rb") as f:
-            lora_sd = torch.load(f, weights_only=True)
+        epoch_folder = get_largest_iter_folder(tmpdir)
+        adpt_path = os.path.join(tmpdir, epoch_folder, f"{ADAPTER_MODEL_FNAME}.pt")
+        lora_sd = safe_torch_load(adpt_path, weights_only=True)
+
         with open(ckpt_path, "rb") as f:
             base_model_sd = torch.load(f, weights_only=True)
         lora_model.load_state_dict(lora_sd, strict=False)
@@ -263,8 +275,13 @@
         baseline_out = lora_model(inputs)
 
         # Load merged final ckpt directly into model and call fwd
-        with open(f"{tmpdir}/torchtune_model_1.pt", "rb") as f:
-            sd = torch.load(f, weights_only=True)
+        suffix = ".safetensors" if ckpt_type == "hf" else ".bin"
+        model_ckpt_fname = (
+            SHARD_FNAME.format(cpt_idx="1".zfill(5), num_shards="1".zfill(5)) + suffix
+        )
+        model_path = os.path.join(tmpdir, epoch_folder, model_ckpt_fname)
+        sd = safe_torch_load(model_path, weights_only=True)
+
         model.load_state_dict(sd)
         merged_ckpt_out = model(inputs)
 
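
For reviewers unfamiliar with the new per-epoch checkpoint layout this diff targets, here is a rough sketch of how the paths exercised above are assembled. The constant values mirror `torchtune.training.checkpointing._utils` as assumed at the time of this change; the two helpers are hypothetical illustrations, not library API.

```python
import os

# Assumed constants, mirroring torchtune.training.checkpointing._utils at the
# time of this change; verify against the source tree before relying on them.
SHARD_FNAME = "model-{cpt_idx}-of-{num_shards}"
ADAPTER_MODEL_FNAME = "adapter_model"


def merged_ckpt_path(output_dir: str, epoch: int, ckpt_type: str) -> str:
    """Hypothetical helper: path of a single-shard merged model checkpoint."""
    # Checkpoints land in per-epoch folders named "epoch_{n}"; HF-style
    # checkpoints are saved as .safetensors, other formats as .bin.
    suffix = ".safetensors" if ckpt_type == "hf" else ".bin"
    fname = SHARD_FNAME.format(cpt_idx="1".zfill(5), num_shards="1".zfill(5)) + suffix
    return os.path.join(output_dir, f"epoch_{epoch}", fname)


def adapter_ckpt_path(output_dir: str, epoch: int) -> str:
    """Hypothetical helper: path of the saved adapter weights for an epoch."""
    return os.path.join(output_dir, f"epoch_{epoch}", f"{ADAPTER_MODEL_FNAME}.pt")


print(merged_ckpt_path("/tmp/run", 1, "hf"))
# -> /tmp/run/epoch_1/model-00001-of-00001.safetensors
print(adapter_ckpt_path("/tmp/run", 0))
# -> /tmp/run/epoch_0/adapter_model.pt
```

This is also why the resume test computes `epoch_folder_minus_one`: `get_largest_iter_folder` returns the most recent epoch folder, so resuming from the end of the previous epoch means loading the adapter from the folder one epoch back, while the recipe state lives in its own directory (presumably `RECIPE_STATE_DIRNAME`) rather than alongside the model weights.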