Refactor Recipe State Dict Code #1964

Merged: 3 commits, merged on Nov 9, 2024
1 change: 1 addition & 0 deletions docs/source/api_ref_modules.rst
@@ -75,6 +75,7 @@ PEFT Components
     peft.AdapterModule
     peft.get_adapter_params
     peft.set_trainable_params
+    peft.get_adapter_state_dict
     peft.validate_missing_and_unexpected_for_lora
     peft.validate_state_dict_for_lora
     peft.disable_adapter
1 change: 1 addition & 0 deletions docs/source/api_ref_training.rst
@@ -56,6 +56,7 @@ Utilities for enabling and working with distributed training.
     get_world_size_and_rank
     get_full_finetune_fsdp_wrap_policy
     lora_fsdp_wrap_policy
+    gather_cpu_state_dict

 .. _ac_label:

4 changes: 2 additions & 2 deletions recipes/full_finetune_distributed.py
@@ -645,8 +645,8 @@ def save_checkpoint(

         # To prevent GPU memory from spiking during checkpoint save,
         # we consolidate the full model and optim state dicts on CPU for rank 0
-        cpu_state_dict = training.get_full_model_state_dict(
-            self._model,
+        cpu_state_dict = training.gather_cpu_state_dict(
+            self._model.state_dict(),
             self._is_rank_zero,
             device=self._device,
         )
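
Note: the new `training.gather_cpu_state_dict` takes the (sharded) state dict directly rather than the model. The sketch below illustrates the general idea behind such a helper, assuming FSDP-style `DTensor` shards; it is not torchtune's actual implementation, and the function name is hypothetical.

```python
# Illustration only: gather sharded parameters so rank 0 ends up with a full
# copy of every tensor on CPU, keeping GPU memory flat during checkpoint save.
from torch.distributed._tensor import DTensor


def gather_cpu_state_dict_sketch(sharded_sd, is_rank_zero):
    cpu_state_dict = {}
    for name, param in sharded_sd.items():
        if isinstance(param, DTensor):
            # Collective call: every rank must participate in the all-gather
            param = param.full_tensor()
        if is_rank_zero:
            # Only rank 0 keeps the materialized tensor, and only on CPU
            cpu_state_dict[name] = param.cpu()
    return cpu_state_dict
```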
10 changes: 4 additions & 6 deletions recipes/knowledge_distillation_distributed.py
@@ -25,6 +25,7 @@
 from torchtune.modules.peft import (
     DoRALinear,
     get_adapter_params,
+    get_adapter_state_dict,
     get_lora_module_names,
     get_merged_lora_ckpt,
     load_dora_magnitudes,
@@ -707,8 +708,8 @@ def save_checkpoint(self, epoch: int) -> None:
         intermediate_checkpoint = epoch + 1 < self.total_epochs
         # To prevent GPU memory from spiking during checkpoint save,
         # we consolidate the full model and optim state dicts on CPU for rank 0
-        cpu_state_dict = training.get_full_model_state_dict(
-            self._model,
+        cpu_state_dict = training.gather_cpu_state_dict(
+            self._model.state_dict(),
             self._is_rank_zero,
             device=self._device,
         )
@@ -728,10 +729,7 @@ def save_checkpoint(self, epoch: int) -> None:

             # Filter out the adapter keys and weights from the model state dict. These will
             # be saved separately
-            adapter_key_filter = lambda x: x in self.adapter_params
-            adapter_state_dict = {
-                k: v for k, v in cpu_state_dict.items() if adapter_key_filter(k)
-            }
+            adapter_state_dict = get_adapter_state_dict(cpu_state_dict)
             checkpoint_dict.update({training.ADAPTER_KEY: adapter_state_dict})

             # merge the adapter weights and base weights to create the model checkpoint
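
The hand-rolled `adapter_key_filter` lambdas are replaced by `peft.get_adapter_state_dict` throughout the recipes. A hypothetical sketch of what such a helper could look like follows; the key patterns and the `device` default are assumptions for illustration, not torchtune's real implementation.

```python
# Hypothetical sketch: pull only adapter parameters (e.g. LoRA A/B matrices,
# DoRA magnitudes) out of a full model state dict, optionally moving them to
# a target device. Key patterns below are assumptions for illustration.
from typing import Any, Dict, Optional

ADAPTER_KEY_PATTERNS = ("lora_a", "lora_b", "magnitude")


def get_adapter_state_dict_sketch(
    state_dict: Dict[str, Any], device: Optional[str] = "cpu"
) -> Dict[str, Any]:
    adapter_sd = {}
    for k, v in state_dict.items():
        if any(pattern in k for pattern in ADAPTER_KEY_PATTERNS):
            # device=None keeps tensors wherever they already live
            adapter_sd[k] = v.to(device) if device is not None else v
    return adapter_sd
```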
6 changes: 2 additions & 4 deletions recipes/knowledge_distillation_single_device.py
@@ -23,6 +23,7 @@
 from torchtune.datasets import ConcatDataset
 from torchtune.modules.peft import (
     get_adapter_params,
+    get_adapter_state_dict,
     get_lora_module_names,
     get_merged_lora_ckpt,
     load_dora_magnitudes,
@@ -586,10 +587,7 @@ def save_checkpoint(self, epoch: int) -> None:
         ckpt_dict.update({training.MODEL_KEY: merged_state_dict})

         # Construct the adapter weights
-        adapter_key_filter = lambda x: x in self.adapter_params
-        adapter_state_dict = {
-            k: v for k, v in self._model.state_dict().items() if adapter_key_filter(k)
-        }
+        adapter_state_dict = get_adapter_state_dict(self._model.state_dict())
         ckpt_dict.update({training.ADAPTER_KEY: adapter_state_dict})
         adapter_config = {
             "r": self._lora_rank,
29 changes: 16 additions & 13 deletions recipes/lora_dpo_distributed.py
@@ -25,6 +25,7 @@
     disable_adapter,
     DoRALinear,
     get_adapter_params,
+    get_adapter_state_dict,
     get_merged_lora_ckpt,
     load_dora_magnitudes,
     LoRALinear,
@@ -504,8 +505,12 @@ def save_checkpoint(
         intermediate_checkpoint = epoch + 1 < self.total_epochs
         # To prevent GPU memory from spiking during checkpoint save,
         # we consolidate the full model and optim state dicts on CPU for rank 0
-        cpu_state_dict = training.get_full_model_state_dict(
-            self._model,
+        state_dict = self._model.state_dict()
+        if self._save_adapter_weights_only:
+            state_dict = get_adapter_state_dict(state_dict, device=None)
+
+        cpu_state_dict = training.gather_cpu_state_dict(
+            state_dict,
             self._is_rank_zero,
             device=self._device,
         )
@@ -521,23 +526,21 @@
         # Now that we have the model and opt state dict, create the actual checkpoint dict
         # to be sent to the checkpointer and ultimately written to file
         if self._is_rank_zero:
-
-            # Filter out the adapter keys and weights from the model state dict. These will
-            # be saved separately
-            adapter_key_filter = lambda x: x in self.adapter_params
-            adapter_state_dict = {
-                k: v for k, v in cpu_state_dict.items() if adapter_key_filter(k)
-            }
-            checkpoint_dict.update({training.ADAPTER_KEY: adapter_state_dict})
-
-            # merge the adapter weights and base weights to create the model checkpoint
-            if not self._save_adapter_weights_only:
+            if self._save_adapter_weights_only:
+                adapter_state_dict = cpu_state_dict
+            else:
+                # Filter out the adapter keys and weights from the model state dict. These will
+                # be saved separately
+                adapter_state_dict = get_adapter_state_dict(cpu_state_dict)
+
+                # merge the adapter weights and base weights to create the model checkpoint
                 merged_state_dict = get_merged_lora_ckpt(
                     cpu_state_dict,
                     rank=self._lora_rank,
                     alpha=self._lora_alpha,
                 )
                 checkpoint_dict.update({training.MODEL_KEY: merged_state_dict})
+            checkpoint_dict.update({training.ADAPTER_KEY: adapter_state_dict})

             # if training is in-progress, checkpoint the optimizer state and recipe state
             # as well.
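
For reference, `get_merged_lora_ckpt` folds the low-rank adapters back into the base weights so the merged checkpoint can be loaded without any PEFT code. The sketch below shows the standard LoRA merge it performs per linear layer; the tensor names are illustrative, not torchtune's checkpoint keys.

```python
import torch


def merge_lora_weight(
    base_weight: torch.Tensor,  # shape [out_dim, in_dim]
    lora_a: torch.Tensor,       # shape [rank, in_dim]
    lora_b: torch.Tensor,       # shape [out_dim, rank]
    rank: int,
    alpha: float,
) -> torch.Tensor:
    # Standard LoRA merge: W' = W + (alpha / rank) * B @ A
    return base_weight + (alpha / rank) * (lora_b @ lora_a)
```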
3 changes: 2 additions & 1 deletion recipes/lora_dpo_single_device.py
@@ -23,6 +23,7 @@
 from torchtune.modules.peft import (
     disable_adapter,
     get_adapter_params,
+    get_adapter_state_dict,
     get_merged_lora_ckpt,
     set_trainable_params,
     validate_missing_and_unexpected_for_lora,
@@ -407,7 +408,7 @@ def save_checkpoint(self, epoch: int) -> None:
             }
         )

-        adapter_state_dict = {k: v.cpu() for k, v in self.adapter_params.items()}
+        adapter_state_dict = get_adapter_state_dict(self._model.state_dict())
         ckpt_dict.update({training.ADAPTER_KEY: adapter_state_dict})
         if not self._save_adapter_weights_only:
             # Construct the full state dict with LoRA weights merged into base LLM weights
31 changes: 17 additions & 14 deletions recipes/lora_finetune_distributed.py
@@ -26,6 +26,7 @@
 from torchtune.modules.peft import (
     DoRALinear,
     get_adapter_params,
+    get_adapter_state_dict,
     get_lora_module_names,
     get_merged_lora_ckpt,
     load_dora_magnitudes,
@@ -452,8 +453,7 @@ def _setup_model(
         with training.set_default_dtype(self._dtype), torch.device("meta"):
             model = config.instantiate(cfg_model)

-        self.adapter_params = get_adapter_params(model)
-        set_trainable_params(model, self.adapter_params)
+        set_trainable_params(model, get_adapter_params(model))

         if self._compile:
             training.compile_model(model, verbose=self._is_rank_zero)
@@ -664,11 +664,14 @@ def save_checkpoint(

         # To prevent GPU memory from spiking during checkpoint save,
         # we consolidate the full model and optim state dicts on CPU for rank 0
-        cpu_state_dict = training.get_full_model_state_dict(
-            self._model,
+        state_dict = self._model.state_dict()
+        if self._save_adapter_weights_only:
+            state_dict = get_adapter_state_dict(state_dict, device=None)
+
+        cpu_state_dict = training.gather_cpu_state_dict(
+            state_dict,
             self._is_rank_zero,
             device=self._device,
-            trainable_only=self._save_adapter_weights_only,
         )
         if self._is_rank_zero:
             log.info(
@@ -694,22 +697,22 @@ def save_checkpoint(
         # to be sent to the checkpointer and ultimately written to file
         if self._is_rank_zero:
             start = time.perf_counter()
-            # Filter out the adapter keys and weights from the model state dict. These will
-            # be saved separately
-            adapter_key_filter = lambda x: x in self.adapter_params
-            adapter_state_dict = {
-                k: v for k, v in cpu_state_dict.items() if adapter_key_filter(k)
-            }
-            checkpoint_dict.update({training.ADAPTER_KEY: adapter_state_dict})

-            # merge the adapter weights and base weights to create the model checkpoint
-            if not self._save_adapter_weights_only:
+            if self._save_adapter_weights_only:
+                adapter_state_dict = cpu_state_dict
+            else:
+                # Filter out the adapter keys and weights from the model state dict. These will
+                # be saved separately
+                adapter_state_dict = get_adapter_state_dict(cpu_state_dict)
+
+                # merge the adapter weights and base weights to create the model checkpoint
                 merged_state_dict = get_merged_lora_ckpt(
                     cpu_state_dict,
                     rank=self._lora_rank,
                     alpha=self._lora_alpha,
                 )
                 checkpoint_dict.update({training.MODEL_KEY: merged_state_dict})
+            checkpoint_dict.update({training.ADAPTER_KEY: adapter_state_dict})

             # if training is in-progress, checkpoint the optimizer state and recipe state
             # as well.
3 changes: 2 additions & 1 deletion recipes/lora_finetune_single_device.py
@@ -24,6 +24,7 @@
 from torchtune.datasets import ConcatDataset
 from torchtune.modules.peft import (
     get_adapter_params,
+    get_adapter_state_dict,
     get_lora_module_names,
     get_merged_lora_ckpt,
     load_dora_magnitudes,
@@ -587,7 +588,7 @@ def save_checkpoint(self, epoch: int) -> None:
             }
         )

-        adapter_state_dict = {k: v.cpu() for k, v in self.adapter_params.items()}
+        adapter_state_dict = get_adapter_state_dict(self._model.state_dict())
         ckpt_dict.update({training.ADAPTER_KEY: adapter_state_dict})

         if not self._save_adapter_weights_only:
4 changes: 2 additions & 2 deletions recipes/qat_distributed.py
@@ -555,8 +555,8 @@ def save_checkpoint(
         intermediate_checkpoint = epoch + 1 < self.total_epochs
         # To prevent GPU memory from spiking during checkpoint save,
         # we consolidate the full model and optim state dicts on CPU for rank 0
-        cpu_state_dict = training.get_full_model_state_dict(
-            self._model,
+        cpu_state_dict = training.gather_cpu_state_dict(
+            self._model.state_dict(),
             self._is_rank_zero,
         )

88 changes: 87 additions & 1 deletion tests/recipes/test_full_finetune_distributed.py
@@ -4,8 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+import os
 import runpy
-
 import sys
 from pathlib import Path

@@ -113,3 +113,89 @@ def test_loss(
         torch.testing.assert_close(
             loss_values, expected_loss_values, rtol=1e-4, atol=1e-4
         )
+
+    @pytest.mark.integration_test
+    @pytest.mark.parametrize(
+        "config, model_type, ckpt_type, micro_batch_size, gradient_accumulation_steps, optim_in_bwd",
+        [
+            ("llama3/8B_full", "llama3", "tune", 1, 4, False),
+        ],
+    )
+    @gpu_test(gpu_count=2)
+    def test_training_state_on_resume(
+        self,
+        micro_batch_size,
+        gradient_accumulation_steps,
+        config,
+        model_type,
+        ckpt_type,
+        optim_in_bwd,
+        tmpdir,
+        monkeypatch,
+    ):
+        ckpt_component = CKPT_COMPONENT_MAP[ckpt_type]
+        ckpt = model_type + "_" + ckpt_type
+        ckpt_path = Path(CKPT_MODEL_PATHS[ckpt])
+        tokenizer_path = Path(TOKENIZER_PATHS[model_type])
+        ckpt_dir = ckpt_path.parent
+        log_file = gen_log_file_name(tmpdir)
+
+        # Config file needed for model conversion.
+        # Create a second copy for training resume
+        write_hf_ckpt_config(ckpt_dir)
+        write_hf_ckpt_config(tmpdir)
+
+        # Train for two epochs
+        cmd_1 = f"""
+        tune run --nnodes 1 --nproc_per_node 2 full_finetune_distributed \
+            --config {config} \
+            batch_size={micro_batch_size} \
+            gradient_accumulation_steps={gradient_accumulation_steps} \
+            output_dir={tmpdir} \
+            checkpointer._component_={ckpt_component} \
+            checkpointer.checkpoint_dir='{ckpt_dir}' \
+            checkpointer.checkpoint_files=[{ckpt_path}]\
+            checkpointer.output_dir={tmpdir} \
+            checkpointer.model_type={model_type.upper()} \
+            tokenizer.path='{tokenizer_path}' \
+            tokenizer.prompt_template=null \
+            clip_grad_norm=100 \
+        """.split()
+
+        model_config = MODEL_TEST_CONFIGS[model_type]
+        cmd_1 = cmd_1 + self._get_test_config_overrides() + model_config
+
+        monkeypatch.setattr(sys, "argv", cmd_1)
+        runpy.run_path(TUNE_PATH, run_name="__main__")
+
+        # Resume training
+        cmd_2 = f"""
+        tune run --nnodes 1 --nproc_per_node 2 full_finetune_distributed \
+            --config {config} \
+            batch_size={micro_batch_size} \
+            gradient_accumulation_steps={gradient_accumulation_steps} \
+            output_dir={tmpdir} \
+            checkpointer._component_={ckpt_component} \
+            checkpointer.checkpoint_dir='{tmpdir}' \
+            checkpointer.checkpoint_files=[{os.path.join(tmpdir, "torchtune_model_0.pt")}]\
+            checkpointer.recipe_checkpoint={os.path.join(tmpdir, "recipe_state.pt")}\
+            checkpointer.output_dir={tmpdir} \
+            checkpointer.model_type={model_type.upper()} \
+            tokenizer.path='{tokenizer_path}' \
+            tokenizer.prompt_template=null \
+            resume_from_checkpoint=True \
+            metric_logger.filename={log_file} \
+            clip_grad_norm=100 \
+        """.split()
+
+        cmd_2 = cmd_2 + self._get_test_config_overrides() + model_config
+
+        monkeypatch.setattr(sys, "argv", cmd_2)
+        runpy.run_path(TUNE_PATH, run_name="__main__")
+
+        expected_loss_values = self._fetch_expected_loss_values(model_type)[2:]
+
+        loss_values = get_loss_values_from_metric_logger(log_file)
+        torch.testing.assert_close(
+            loss_values, expected_loss_values, rtol=1e-4, atol=1e-4
+        )
2 changes: 1 addition & 1 deletion tests/recipes/test_full_finetune_single_device.py
@@ -181,7 +181,7 @@ def test_training_state_on_resume(self, tmpdir, monkeypatch):
             checkpointer._component_=torchtune.training.FullModelHFCheckpointer \
             checkpointer.checkpoint_dir={tmpdir} \
             checkpointer.checkpoint_files=[{os.path.join(tmpdir, "hf_model_0001_0.pt")}]\
-            checkpointer.recipe_checkpoint={os.path.join(tmpdir, "recipe_state.pt")}
+            checkpointer.recipe_checkpoint={os.path.join(tmpdir, "recipe_state.pt")}\
             checkpointer.output_dir={tmpdir} \
             checkpointer.model_type=LLAMA2 \
             tokenizer.path=/tmp/test-artifacts/tokenizer.model \