Refactor Recipe State Dict Code (#1964)
pbontrager authored Nov 9, 2024
1 parent 550163b commit 08efaed
Showing 18 changed files with 226 additions and 129 deletions.
1 change: 1 addition & 0 deletions docs/source/api_ref_modules.rst
@@ -75,6 +75,7 @@ PEFT Components
peft.AdapterModule
peft.get_adapter_params
peft.set_trainable_params
peft.get_adapter_state_dict
peft.validate_missing_and_unexpected_for_lora
peft.validate_state_dict_for_lora
peft.disable_adapter
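A brief usage sketch of the newly listed peft.get_adapter_state_dict. The signature is inferred from the recipe call sites later in this commit; the toy state dict and its key names are illustrative assumptions, not torchtune output.

import torch
from torchtune.modules.peft import get_adapter_state_dict

# Toy stand-in for model.state_dict(), using LoRA-style key names (assumption)
full_sd = {
    "layers.0.attn.q_proj.weight": torch.randn(8, 8),
    "layers.0.attn.q_proj.lora_a.weight": torch.randn(2, 8),
    "layers.0.attn.q_proj.lora_b.weight": torch.randn(8, 2),
}

# Expected to return only the adapter entries (the lora_* keys above)
adapter_sd = get_adapter_state_dict(full_sd)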
1 change: 1 addition & 0 deletions docs/source/api_ref_training.rst
@@ -56,6 +56,7 @@ Utilities for enabling and working with distributed training.
get_world_size_and_rank
get_full_finetune_fsdp_wrap_policy
lora_fsdp_wrap_policy
gather_cpu_state_dict

.. _ac_label:

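A call-shape sketch for the newly listed gather_cpu_state_dict, matching how the recipes below invoke it; the wrapper function and its argument names are placeholders, not part of torchtune.

from typing import Any, Dict

import torch
from torchtune import training


def save_full_cpu_copy(
    model: torch.nn.Module, is_rank_zero: bool, device: torch.device
) -> Dict[str, Any]:
    # Gather the (possibly FSDP-sharded) state dict into a full CPU copy that
    # only rank 0 materializes, mirroring the recipe diffs in this commit.
    return training.gather_cpu_state_dict(
        model.state_dict(),
        is_rank_zero,
        device=device,
    )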
4 changes: 2 additions & 2 deletions recipes/full_finetune_distributed.py
@@ -645,8 +645,8 @@ def save_checkpoint(

# To prevent GPU memory from spiking during checkpoint save,
# we consolidate the full model and optim state dicts on CPU for rank 0
cpu_state_dict = training.get_full_model_state_dict(
self._model,
cpu_state_dict = training.gather_cpu_state_dict(
self._model.state_dict(),
self._is_rank_zero,
device=self._device,
)
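The comment in this hunk states the goal: consolidate the full model and optimizer state on CPU, on rank 0 only, so checkpointing does not spike GPU memory. Below is a rough sketch of how such a gather can work when parameters are FSDP-sharded as DTensors; it is an illustration under that assumption, not torchtune's actual gather_cpu_state_dict.

from typing import Any, Dict

import torch
from torch.distributed._tensor import DTensor


def gather_cpu_state_dict_sketch(
    sharded_sd: Dict[str, Any], is_rank_zero: bool
) -> Dict[str, torch.Tensor]:
    cpu_state_dict: Dict[str, torch.Tensor] = {}
    for name, param in sharded_sd.items():
        if isinstance(param, DTensor):
            # Every rank joins the collective that reassembles the full tensor
            param = param.full_tensor()
        if is_rank_zero:
            # Move to CPU right away so full tensors never accumulate on GPU
            cpu_state_dict[name] = param.cpu()
    return cpu_state_dict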
10 changes: 4 additions & 6 deletions recipes/knowledge_distillation_distributed.py
@@ -25,6 +25,7 @@
from torchtune.modules.peft import (
DoRALinear,
get_adapter_params,
get_adapter_state_dict,
get_lora_module_names,
get_merged_lora_ckpt,
load_dora_magnitudes,
@@ -707,8 +708,8 @@ def save_checkpoint(self, epoch: int) -> None:
intermediate_checkpoint = epoch + 1 < self.total_epochs
# To prevent GPU memory from spiking during checkpoint save,
# we consolidate the full model and optim state dicts on CPU for rank 0
cpu_state_dict = training.get_full_model_state_dict(
self._model,
cpu_state_dict = training.gather_cpu_state_dict(
self._model.state_dict(),
self._is_rank_zero,
device=self._device,
)
@@ -728,10 +729,7 @@ def save_checkpoint(self, epoch: int) -> None:

# Filter out the adapter keys and weights from the model state dict. These will
# be saved separately
adapter_key_filter = lambda x: x in self.adapter_params
adapter_state_dict = {
k: v for k, v in cpu_state_dict.items() if adapter_key_filter(k)
}
adapter_state_dict = get_adapter_state_dict(cpu_state_dict)
checkpoint_dict.update({training.ADAPTER_KEY: adapter_state_dict})

# merge the adapter weights and base weights to create the model checkpoint
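The hand-rolled adapter_key_filter lambda above is what get_adapter_state_dict now encapsulates. A plausible sketch of that helper, assuming adapter weights are identified by LoRA/DoRA key names and that device=None leaves tensors where they are (the real implementation may differ):

from typing import Dict, Optional

import torch


def get_adapter_state_dict_sketch(
    state_dict: Dict[str, torch.Tensor], device: Optional[str] = "cpu"
) -> Dict[str, torch.Tensor]:
    def is_adapter_key(key: str) -> bool:
        # Assumption: LoRA A/B matrices and DoRA magnitudes carry these substrings
        return "lora" in key or "magnitude" in key

    return {
        k: (v if device is None else v.to(device))
        for k, v in state_dict.items()
        if is_adapter_key(k)
    }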
6 changes: 2 additions & 4 deletions recipes/knowledge_distillation_single_device.py
@@ -23,6 +23,7 @@
from torchtune.datasets import ConcatDataset
from torchtune.modules.peft import (
get_adapter_params,
get_adapter_state_dict,
get_lora_module_names,
get_merged_lora_ckpt,
load_dora_magnitudes,
@@ -586,10 +587,7 @@ def save_checkpoint(self, epoch: int) -> None:
ckpt_dict.update({training.MODEL_KEY: merged_state_dict})

# Construct the adapter weights
adapter_key_filter = lambda x: x in self.adapter_params
adapter_state_dict = {
k: v for k, v in self._model.state_dict().items() if adapter_key_filter(k)
}
adapter_state_dict = get_adapter_state_dict(self._model.state_dict())
ckpt_dict.update({training.ADAPTER_KEY: adapter_state_dict})
adapter_config = {
"r": self._lora_rank,
29 changes: 16 additions & 13 deletions recipes/lora_dpo_distributed.py
@@ -25,6 +25,7 @@
disable_adapter,
DoRALinear,
get_adapter_params,
get_adapter_state_dict,
get_merged_lora_ckpt,
load_dora_magnitudes,
LoRALinear,
@@ -504,8 +505,12 @@ def save_checkpoint(
intermediate_checkpoint = epoch + 1 < self.total_epochs
# To prevent GPU memory from spiking during checkpoint save,
# we consolidate the full model and optim state dicts on CPU for rank 0
cpu_state_dict = training.get_full_model_state_dict(
self._model,
state_dict = self._model.state_dict()
if self._save_adapter_weights_only:
state_dict = get_adapter_state_dict(state_dict, device=None)

cpu_state_dict = training.gather_cpu_state_dict(
state_dict,
self._is_rank_zero,
device=self._device,
)
@@ -521,23 +526,21 @@ def save_checkpoint(
# Now that we have the model and opt state dict, create the actual checkpoint dict
# to be sent to the checkpointer and ultimately written to file
if self._is_rank_zero:

# Filter out the adapter keys and weights from the model state dict. These will
# be saved separately
adapter_key_filter = lambda x: x in self.adapter_params
adapter_state_dict = {
k: v for k, v in cpu_state_dict.items() if adapter_key_filter(k)
}
checkpoint_dict.update({training.ADAPTER_KEY: adapter_state_dict})

# merge the adapter weights and base weights to create the model checkpoint
if not self._save_adapter_weights_only:
if self._save_adapter_weights_only:
adapter_state_dict = cpu_state_dict
else:
# Filter out the adapter keys and weights from the model state dict. These will
# be saved separately
adapter_state_dict = get_adapter_state_dict(cpu_state_dict)

# merge the adapter weights and base weights to create the model checkpoint
merged_state_dict = get_merged_lora_ckpt(
cpu_state_dict,
rank=self._lora_rank,
alpha=self._lora_alpha,
)
checkpoint_dict.update({training.MODEL_KEY: merged_state_dict})
checkpoint_dict.update({training.ADAPTER_KEY: adapter_state_dict})

# if training is in-progress, checkpoint the optimizer state and recipe state
# as well.
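Condensing the save_checkpoint hunks above: when only adapter weights are being saved, they are filtered out before the gather (with device=None so shards stay in place), which means frozen base weights are never all-gathered; otherwise the full CPU copy is split into an adapter dict and a merged base checkpoint. The snippet below is a paraphrase of the diff using local stand-ins (model, save_adapter_weights_only, is_rank_zero, lora_rank, lora_alpha) for the recipe's attributes.

from torchtune import training
from torchtune.modules.peft import get_adapter_state_dict, get_merged_lora_ckpt

state_dict = model.state_dict()
if save_adapter_weights_only:
    # Drop frozen base weights up front so they are never all-gathered
    state_dict = get_adapter_state_dict(state_dict, device=None)

cpu_state_dict = training.gather_cpu_state_dict(
    state_dict, is_rank_zero, device=device
)

if is_rank_zero:
    if save_adapter_weights_only:
        adapter_state_dict = cpu_state_dict  # already adapter-only
    else:
        adapter_state_dict = get_adapter_state_dict(cpu_state_dict)
        checkpoint_dict[training.MODEL_KEY] = get_merged_lora_ckpt(
            cpu_state_dict, rank=lora_rank, alpha=lora_alpha
        )
    checkpoint_dict[training.ADAPTER_KEY] = adapter_state_dict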
3 changes: 2 additions & 1 deletion recipes/lora_dpo_single_device.py
@@ -23,6 +23,7 @@
from torchtune.modules.peft import (
disable_adapter,
get_adapter_params,
get_adapter_state_dict,
get_merged_lora_ckpt,
set_trainable_params,
validate_missing_and_unexpected_for_lora,
@@ -407,7 +408,7 @@ def save_checkpoint(self, epoch: int) -> None:
}
)

adapter_state_dict = {k: v.cpu() for k, v in self.adapter_params.items()}
adapter_state_dict = get_adapter_state_dict(self._model.state_dict())
ckpt_dict.update({training.ADAPTER_KEY: adapter_state_dict})
if not self._save_adapter_weights_only:
# Construct the full state dict with LoRA weights merged into base LLM weights
31 changes: 17 additions & 14 deletions recipes/lora_finetune_distributed.py
@@ -26,6 +26,7 @@
from torchtune.modules.peft import (
DoRALinear,
get_adapter_params,
get_adapter_state_dict,
get_lora_module_names,
get_merged_lora_ckpt,
load_dora_magnitudes,
@@ -452,8 +453,7 @@ def _setup_model(
with training.set_default_dtype(self._dtype), torch.device("meta"):
model = config.instantiate(cfg_model)

self.adapter_params = get_adapter_params(model)
set_trainable_params(model, self.adapter_params)
set_trainable_params(model, get_adapter_params(model))

if self._compile:
training.compile_model(model, verbose=self._is_rank_zero)
@@ -664,11 +664,14 @@ def save_checkpoint(

# To prevent GPU memory from spiking during checkpoint save,
# we consolidate the full model and optim state dicts on CPU for rank 0
cpu_state_dict = training.get_full_model_state_dict(
self._model,
state_dict = self._model.state_dict()
if self._save_adapter_weights_only:
state_dict = get_adapter_state_dict(state_dict, device=None)

cpu_state_dict = training.gather_cpu_state_dict(
state_dict,
self._is_rank_zero,
device=self._device,
trainable_only=self._save_adapter_weights_only,
)
if self._is_rank_zero:
log.info(
@@ -694,22 +697,22 @@ def save_checkpoint(
# to be sent to the checkpointer and ultimately written to file
if self._is_rank_zero:
start = time.perf_counter()
# Filter out the adapter keys and weights from the model state dict. These will
# be saved separately
adapter_key_filter = lambda x: x in self.adapter_params
adapter_state_dict = {
k: v for k, v in cpu_state_dict.items() if adapter_key_filter(k)
}
checkpoint_dict.update({training.ADAPTER_KEY: adapter_state_dict})

# merge the adapter weights and base weights to create the model checkpoint
if not self._save_adapter_weights_only:
if self._save_adapter_weights_only:
adapter_state_dict = cpu_state_dict
else:
# Filter out the adapter keys and weights from the model state dict. These will
# be saved separately
adapter_state_dict = get_adapter_state_dict(cpu_state_dict)

# merge the adapter weights and base weights to create the model checkpoint
merged_state_dict = get_merged_lora_ckpt(
cpu_state_dict,
rank=self._lora_rank,
alpha=self._lora_alpha,
)
checkpoint_dict.update({training.MODEL_KEY: merged_state_dict})
checkpoint_dict.update({training.ADAPTER_KEY: adapter_state_dict})

# if training is in-progress, checkpoint the optimizer state and recipe state
# as well.
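One side effect of moving key filtering into get_adapter_state_dict is visible in the _setup_model hunk above: the recipe no longer stores self.adapter_params for later checkpoint filtering, so setup reduces to marking the adapter parameters trainable. A minimal sketch (model stands in for the instantiated cfg_model):

from torchtune.modules.peft import get_adapter_params, set_trainable_params

# Only LoRA/DoRA adapter parameters require gradients; the name->param mapping
# itself no longer needs to live on the recipe for checkpointing.
set_trainable_params(model, get_adapter_params(model))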
3 changes: 2 additions & 1 deletion recipes/lora_finetune_single_device.py
@@ -24,6 +24,7 @@
from torchtune.datasets import ConcatDataset
from torchtune.modules.peft import (
get_adapter_params,
get_adapter_state_dict,
get_lora_module_names,
get_merged_lora_ckpt,
load_dora_magnitudes,
@@ -592,7 +593,7 @@ def save_checkpoint(self, epoch: int) -> None:
}
)

adapter_state_dict = {k: v.cpu() for k, v in self.adapter_params.items()}
adapter_state_dict = get_adapter_state_dict(self._model.state_dict())
ckpt_dict.update({training.ADAPTER_KEY: adapter_state_dict})

if not self._save_adapter_weights_only:
4 changes: 2 additions & 2 deletions recipes/qat_distributed.py
@@ -673,8 +673,8 @@ def save_checkpoint(

# To prevent GPU memory from spiking during checkpoint save,
# we consolidate the full model and optim state dicts on CPU for rank 0
cpu_state_dict = training.get_full_model_state_dict(
self._model,
cpu_state_dict = training.gather_cpu_state_dict(
self._model.state_dict(),
self._is_rank_zero,
device=self._device,
)
88 changes: 87 additions & 1 deletion tests/recipes/test_full_finetune_distributed.py
@@ -4,8 +4,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import os
import runpy

import sys
from pathlib import Path

@@ -113,3 +113,89 @@ def test_loss(
torch.testing.assert_close(
loss_values, expected_loss_values, rtol=1e-4, atol=1e-4
)

@pytest.mark.integration_test
@pytest.mark.parametrize(
"config, model_type, ckpt_type, micro_batch_size, gradient_accumulation_steps, optim_in_bwd",
[
("llama3/8B_full", "llama3", "tune", 1, 4, False),
],
)
@gpu_test(gpu_count=2)
def test_training_state_on_resume(
self,
micro_batch_size,
gradient_accumulation_steps,
config,
model_type,
ckpt_type,
optim_in_bwd,
tmpdir,
monkeypatch,
):
ckpt_component = CKPT_COMPONENT_MAP[ckpt_type]
ckpt = model_type + "_" + ckpt_type
ckpt_path = Path(CKPT_MODEL_PATHS[ckpt])
tokenizer_path = Path(TOKENIZER_PATHS[model_type])
ckpt_dir = ckpt_path.parent
log_file = gen_log_file_name(tmpdir)

# Config file needed for model conversion.
# Create a second copy for training resume
write_hf_ckpt_config(ckpt_dir)
write_hf_ckpt_config(tmpdir)

# Train for two epochs
cmd_1 = f"""
tune run --nnodes 1 --nproc_per_node 2 full_finetune_distributed \
--config {config} \
batch_size={micro_batch_size} \
gradient_accumulation_steps={gradient_accumulation_steps} \
output_dir={tmpdir} \
checkpointer._component_={ckpt_component} \
checkpointer.checkpoint_dir='{ckpt_dir}' \
checkpointer.checkpoint_files=[{ckpt_path}]\
checkpointer.output_dir={tmpdir} \
checkpointer.model_type={model_type.upper()} \
tokenizer.path='{tokenizer_path}' \
tokenizer.prompt_template=null \
clip_grad_norm=100 \
""".split()

model_config = MODEL_TEST_CONFIGS[model_type]
cmd_1 = cmd_1 + self._get_test_config_overrides() + model_config

monkeypatch.setattr(sys, "argv", cmd_1)
runpy.run_path(TUNE_PATH, run_name="__main__")

# Resume training
cmd_2 = f"""
tune run --nnodes 1 --nproc_per_node 2 full_finetune_distributed \
--config {config} \
batch_size={micro_batch_size} \
gradient_accumulation_steps={gradient_accumulation_steps} \
output_dir={tmpdir} \
checkpointer._component_={ckpt_component} \
checkpointer.checkpoint_dir='{tmpdir}' \
checkpointer.checkpoint_files=[{os.path.join(tmpdir, "torchtune_model_0.pt")}]\
checkpointer.recipe_checkpoint={os.path.join(tmpdir, "recipe_state.pt")}\
checkpointer.output_dir={tmpdir} \
checkpointer.model_type={model_type.upper()} \
tokenizer.path='{tokenizer_path}' \
tokenizer.prompt_template=null \
resume_from_checkpoint=True \
metric_logger.filename={log_file} \
clip_grad_norm=100 \
""".split()

cmd_2 = cmd_2 + self._get_test_config_overrides() + model_config

monkeypatch.setattr(sys, "argv", cmd_2)
runpy.run_path(TUNE_PATH, run_name="__main__")

expected_loss_values = self._fetch_expected_loss_values(model_type)[2:]

loss_values = get_loss_values_from_metric_logger(log_file)
torch.testing.assert_close(
loss_values, expected_loss_values, rtol=1e-4, atol=1e-4
)
2 changes: 1 addition & 1 deletion tests/recipes/test_full_finetune_single_device.py
@@ -181,7 +181,7 @@ def test_training_state_on_resume(self, tmpdir, monkeypatch):
checkpointer._component_=torchtune.training.FullModelHFCheckpointer \
checkpointer.checkpoint_dir={tmpdir} \
checkpointer.checkpoint_files=[{os.path.join(tmpdir, "hf_model_0001_0.pt")}]\
checkpointer.recipe_checkpoint={os.path.join(tmpdir, "recipe_state.pt")}
checkpointer.recipe_checkpoint={os.path.join(tmpdir, "recipe_state.pt")}\
checkpointer.output_dir={tmpdir} \
checkpointer.model_type=LLAMA2 \
tokenizer.path=/tmp/test-artifacts/tokenizer.model \
