Skip to content

Commit

Permalink
fix qat_lora_test (#2131)
Browse files Browse the repository at this point in the history
Co-authored-by: Felipe Mello <[email protected]>
  • Loading branch information
felipemello1 and Felipe Mello authored Dec 7, 2024
1 parent 26b2200 commit 06a8379
Showing 1 changed file with 24 additions and 7 deletions.
31 changes: 24 additions & 7 deletions tests/recipes/test_qat_lora_finetune_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,14 @@
TOKENIZER_PATHS,
)
from torchtune import config

from torchtune.training.checkpointing._utils import (
ADAPTER_MODEL_FNAME,
get_largest_iter_folder,
RECIPE_STATE_DIRNAME,
safe_torch_load,
SHARD_FNAME,
)
from torchtune.training.quantization import _torchao_0_7_supported


Expand Down Expand Up @@ -166,17 +174,19 @@ def test_training_state_on_resume(
runpy.run_path(TUNE_PATH, run_name="__main__")

# Resume training
epoch_folder = get_largest_iter_folder(tmpdir)
epoch_folder_minus_one = f"epoch_{int(epoch_folder.split('_')[-1]) - 1}"
cmd_2 = f"""
tune run --nnodes 1 --nproc_per_node 2 qat_lora_finetune_distributed \
--config {config} \
batch_size=4 \
gradient_accumulation_steps=1 \
output_dir={tmpdir} \
checkpointer._component_={ckpt_component} \
checkpointer.checkpoint_dir={tmpdir} \
checkpointer.checkpoint_dir={ckpt_dir} \
checkpointer.checkpoint_files=[{ckpt_path}]\
checkpointer.adapter_checkpoint={os.path.join(tmpdir, "adapter_0.pt")}
checkpointer.recipe_checkpoint={os.path.join(tmpdir, "recipe_state.pt")}
checkpointer.adapter_checkpoint={os.path.join(epoch_folder_minus_one, f"{ADAPTER_MODEL_FNAME}.pt")}
checkpointer.recipe_checkpoint={os.path.join(RECIPE_STATE_DIRNAME, "recipe_state.pt")}
checkpointer.output_dir={tmpdir} \
checkpointer.model_type={model_type.upper()} \
tokenizer.path='{tokenizer_path}' \
Expand Down Expand Up @@ -254,17 +264,24 @@ def test_save_and_load_merged_weights(
model = config.instantiate(OmegaConf.from_dotlist(base_config).model)

# Load base model and trained adapter weights into LoRA model and call fwd
with open(f"{tmpdir}/adapter_1.pt", "rb") as f:
lora_sd = torch.load(f, weights_only=True)
epoch_folder = get_largest_iter_folder(tmpdir)
adpt_path = os.path.join(tmpdir, epoch_folder, f"{ADAPTER_MODEL_FNAME}.pt")
lora_sd = safe_torch_load(adpt_path, weights_only=True)

with open(ckpt_path, "rb") as f:
base_model_sd = torch.load(f, weights_only=True)
lora_model.load_state_dict(lora_sd, strict=False)
lora_model.load_state_dict(base_model_sd, strict=False)
baseline_out = lora_model(inputs)

# Load merged final ckpt directly into model and call fwd
with open(f"{tmpdir}/torchtune_model_1.pt", "rb") as f:
sd = torch.load(f, weights_only=True)
suffix = ".safetensors" if ckpt_type == "hf" else ".bin"
model_ckpt_fname = (
SHARD_FNAME.format(cpt_idx="1".zfill(5), num_shards="1".zfill(5)) + suffix
)
model_path = os.path.join(tmpdir, epoch_folder, model_ckpt_fname)
sd = safe_torch_load(model_path, weights_only=True)

model.load_state_dict(sd)
merged_ckpt_out = model(inputs)

Expand Down

0 comments on commit 06a8379

Please sign in to comment.