Refactor Recipe State Dict Code #1964

Merged · 3 commits · Nov 9, 2024
Changes from 1 commit
1 change: 1 addition & 0 deletions docs/source/api_ref_modules.rst
@@ -75,6 +75,7 @@ PEFT Components
peft.AdapterModule
peft.get_adapter_params
peft.set_trainable_params
peft.get_adapter_state_dict
peft.validate_missing_and_unexpected_for_lora
peft.validate_state_dict_for_lora
peft.disable_adapter
4 changes: 2 additions & 2 deletions recipes/full_finetune_distributed.py
@@ -645,8 +645,8 @@ def save_checkpoint(

# To prevent GPU memory from spiking during checkpoint save,
# we consolidate the full model and optim state dicts on CPU for rank 0
cpu_state_dict = training.get_full_model_state_dict(
self._model,
cpu_state_dict = training.gather_cpu_state_dict(
self._model.state_dict(),
self._is_rank_zero,
device=self._device,
)
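For readers new to the API, the key change in these distributed recipes is that a (possibly sharded) state dict is now passed to `training.gather_cpu_state_dict`, instead of handing the whole model to the old `get_full_model_state_dict`. A minimal sketch of the resulting call pattern is below; only the `gather_cpu_state_dict` call mirrors the diff, while the wrapper function and its parameters are illustrative.

```python
# Minimal sketch, assuming a torchtune distributed recipe. Only the
# training.gather_cpu_state_dict call mirrors the diff above; the wrapper
# function name and arguments are illustrative.
from typing import Any, Dict

import torch
import torch.nn as nn

from torchtune import training


def consolidate_for_save(
    model: nn.Module, is_rank_zero: bool, device: torch.device
) -> Dict[str, Any]:
    # Gather the sharded model state dict onto CPU on rank 0 so that
    # checkpoint saving does not spike GPU memory.
    return training.gather_cpu_state_dict(
        model.state_dict(),
        is_rank_zero,
        device=device,
    )
```

The same pattern repeats in the other distributed recipes touched below.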
10 changes: 4 additions & 6 deletions recipes/knowledge_distillation_distributed.py
@@ -25,6 +25,7 @@
from torchtune.modules.peft import (
DoRALinear,
get_adapter_params,
get_adapter_state_dict,
get_lora_module_names,
get_merged_lora_ckpt,
load_dora_magnitudes,
@@ -707,8 +708,8 @@ def save_checkpoint(self, epoch: int) -> None:
intermediate_checkpoint = epoch + 1 < self.total_epochs
# To prevent GPU memory from spiking during checkpoint save,
# we consolidate the full model and optim state dicts on CPU for rank 0
cpu_state_dict = training.get_full_model_state_dict(
self._model,
cpu_state_dict = training.gather_cpu_state_dict(
self._model.state_dict(),
self._is_rank_zero,
device=self._device,
)
@@ -728,10 +729,7 @@ def save_checkpoint(self, epoch: int) -> None:

# Filter out the adapter keys and weights from the model state dict. These will
# be saved separately
adapter_key_filter = lambda x: x in self.adapter_params
adapter_state_dict = {
k: v for k, v in cpu_state_dict.items() if adapter_key_filter(k)
}
adapter_state_dict = get_adapter_state_dict(cpu_state_dict)
checkpoint_dict.update({training.ADAPTER_KEY: adapter_state_dict})

# merge the adapter weights and base weights to create the model checkpoint
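The per-recipe lambda over `self.adapter_params` is replaced by the shared `get_adapter_state_dict` helper, with the merged checkpoint still built afterwards via `get_merged_lora_ckpt`. A hedged sketch of the rank-zero save flow the LoRA recipes now share is below; the wrapper function and argument plumbing are illustrative, and it assumes `cpu_state_dict` was produced by `gather_cpu_state_dict` as shown above.

```python
# Sketch only: the torchtune helpers and checkpoint keys come from the diff;
# the function name and argument plumbing are illustrative.
from typing import Any, Dict

from torchtune import training
from torchtune.modules.peft import get_adapter_state_dict, get_merged_lora_ckpt


def build_lora_checkpoint(
    cpu_state_dict: Dict[str, Any], lora_rank: int, lora_alpha: float
) -> Dict[str, Any]:
    checkpoint: Dict[str, Any] = {}
    # Adapter-only weights are saved separately under training.ADAPTER_KEY.
    checkpoint[training.ADAPTER_KEY] = get_adapter_state_dict(cpu_state_dict)
    # Adapter and base weights are then merged into the full model checkpoint.
    checkpoint[training.MODEL_KEY] = get_merged_lora_ckpt(
        cpu_state_dict, lora_rank, lora_alpha
    )
    return checkpoint
```

Extracting the adapter weights before merging matters, since `get_merged_lora_ckpt` folds the LoRA matrices into the base weights.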
6 changes: 2 additions & 4 deletions recipes/knowledge_distillation_single_device.py
@@ -23,6 +23,7 @@
from torchtune.datasets import ConcatDataset
from torchtune.modules.peft import (
get_adapter_params,
get_adapter_state_dict,
get_lora_module_names,
get_merged_lora_ckpt,
load_dora_magnitudes,
@@ -586,10 +587,7 @@ def save_checkpoint(self, epoch: int) -> None:
ckpt_dict.update({training.MODEL_KEY: merged_state_dict})

# Construct the adapter weights
adapter_key_filter = lambda x: x in self.adapter_params
adapter_state_dict = {
k: v for k, v in self._model.state_dict().items() if adapter_key_filter(k)
}
adapter_state_dict = get_adapter_state_dict(self._model.state_dict())
ckpt_dict.update({training.ADAPTER_KEY: adapter_state_dict})
adapter_config = {
"r": self._lora_rank,
10 changes: 4 additions & 6 deletions recipes/lora_dpo_distributed.py
@@ -25,6 +25,7 @@
disable_adapter,
DoRALinear,
get_adapter_params,
get_adapter_state_dict,
get_merged_lora_ckpt,
load_dora_magnitudes,
LoRALinear,
@@ -504,8 +505,8 @@ def save_checkpoint(
intermediate_checkpoint = epoch + 1 < self.total_epochs
# To prevent GPU memory from spiking during checkpoint save,
# we consolidate the full model and optim state dicts on CPU for rank 0
cpu_state_dict = training.get_full_model_state_dict(
self._model,
cpu_state_dict = training.gather_cpu_state_dict(
self._model.state_dict(),
self._is_rank_zero,
device=self._device,
)
@@ -524,10 +525,7 @@ def save_checkpoint(

# Filter out the adapter keys and weights from the model state dict. These will
# be saved separately
adapter_key_filter = lambda x: x in self.adapter_params
adapter_state_dict = {
k: v for k, v in cpu_state_dict.items() if adapter_key_filter(k)
}
adapter_state_dict = get_adapter_state_dict(cpu_state_dict)
checkpoint_dict.update({training.ADAPTER_KEY: adapter_state_dict})

# merge the adapter weights and base weights to create the model checkpoint
3 changes: 2 additions & 1 deletion recipes/lora_dpo_single_device.py
@@ -23,6 +23,7 @@
from torchtune.modules.peft import (
disable_adapter,
get_adapter_params,
get_adapter_state_dict,
get_merged_lora_ckpt,
set_trainable_params,
validate_missing_and_unexpected_for_lora,
@@ -407,7 +408,7 @@ def save_checkpoint(self, epoch: int) -> None:
}
)

adapter_state_dict = {k: v.cpu() for k, v in self.adapter_params.items()}
adapter_state_dict = get_adapter_state_dict(self._model.state_dict())
ckpt_dict.update({training.ADAPTER_KEY: adapter_state_dict})
if not self._save_adapter_weights_only:
# Construct the full state dict with LoRA weights merged into base LLM weights
13 changes: 5 additions & 8 deletions recipes/lora_finetune_distributed.py
@@ -26,6 +26,7 @@
from torchtune.modules.peft import (
DoRALinear,
get_adapter_params,
get_adapter_state_dict,
get_lora_module_names,
get_merged_lora_ckpt,
load_dora_magnitudes,
@@ -452,8 +453,7 @@ def _setup_model(
with training.set_default_dtype(self._dtype), torch.device("meta"):
model = config.instantiate(cfg_model)

self.adapter_params = get_adapter_params(model)
set_trainable_params(model, self.adapter_params)
set_trainable_params(model, get_adapter_params(model))

if self._compile:
training.compile_model(model, verbose=self._is_rank_zero)
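Note that the recipe no longer stashes `self.adapter_params`; the adapter parameters are marked trainable inline. A small sketch of that pattern on a toy module follows; the toy model and the `LoRALinear` hyperparameters are illustrative, not from the PR.

```python
# Illustrative toy model; only get_adapter_params/set_trainable_params
# mirror the recipe change.
import torch.nn as nn

from torchtune.modules.peft import LoRALinear, get_adapter_params, set_trainable_params

model = nn.Sequential(
    LoRALinear(in_dim=16, out_dim=16, rank=4, alpha=8.0),  # adapter-bearing layer
    nn.Linear(16, 16),  # plain base layer
)
set_trainable_params(model, get_adapter_params(model))

# Only the LoRA A/B parameters remain trainable; everything else is frozen.
trainable = {n for n, p in model.named_parameters() if p.requires_grad}
assert trainable and all("lora" in n for n in trainable)
```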
@@ -664,8 +664,8 @@ def save_checkpoint(

# To prevent GPU memory from spiking during checkpoint save,
# we consolidate the full model and optim state dicts on CPU for rank 0
cpu_state_dict = training.get_full_model_state_dict(
self._model,
cpu_state_dict = training.gather_cpu_state_dict(
self._model.state_dict(),
self._is_rank_zero,
device=self._device,
trainable_only=self._save_adapter_weights_only,
@@ -696,10 +696,7 @@ def save_checkpoint(
start = time.perf_counter()
# Filter out the adapter keys and weights from the model state dict. These will
# be saved separately
adapter_key_filter = lambda x: x in self.adapter_params
adapter_state_dict = {
k: v for k, v in cpu_state_dict.items() if adapter_key_filter(k)
}
adapter_state_dict = get_adapter_state_dict(cpu_state_dict)
checkpoint_dict.update({training.ADAPTER_KEY: adapter_state_dict})

# merge the adapter weights and base weights to create the model checkpoint
3 changes: 2 additions & 1 deletion recipes/lora_finetune_single_device.py
@@ -24,6 +24,7 @@
from torchtune.datasets import ConcatDataset
from torchtune.modules.peft import (
get_adapter_params,
get_adapter_state_dict,
get_lora_module_names,
get_merged_lora_ckpt,
load_dora_magnitudes,
@@ -587,7 +588,7 @@ def save_checkpoint(self, epoch: int) -> None:
}
)

adapter_state_dict = {k: v.cpu() for k, v in self.adapter_params.items()}
adapter_state_dict = get_adapter_state_dict(self._model.state_dict())
ckpt_dict.update({training.ADAPTER_KEY: adapter_state_dict})

if not self._save_adapter_weights_only:
4 changes: 2 additions & 2 deletions recipes/qat_distributed.py
@@ -555,8 +555,8 @@ def save_checkpoint(
intermediate_checkpoint = epoch + 1 < self.total_epochs
# To prevent GPU memory from spiking during checkpoint save,
# we consolidate the full model and optim state dicts on CPU for rank 0
cpu_state_dict = training.get_full_model_state_dict(
self._model,
cpu_state_dict = training.gather_cpu_state_dict(
self._model.state_dict(),
self._is_rank_zero,
)

33 changes: 24 additions & 9 deletions tests/torchtune/modules/peft/test_utils.py
@@ -16,6 +16,7 @@
disable_adapter,
DoRALinear,
get_adapter_params,
get_adapter_state_dict,
get_merged_lora_ckpt,
LoRALinear,
set_trainable_params,
@@ -38,30 +39,30 @@
class DummyAdapterModule(nn.Module, AdapterModule):
def __init__(self, in_dim, out_dim):
super().__init__()
self.adapter = nn.Linear(in_dim, out_dim, bias=False)
self.lora = nn.Linear(in_dim, out_dim, bias=False)
Contributor:
So as of now, what is the actual value of AdapterModule? Is it just for setting trainable params?

Contributor Author:
I think AdapterModule is the right way to go, but since we are already ignoring it for ckpt merging, I'm not really changing anything by not using it for get_adapter_state_dict. I think we should move to using AdapterModule for all of these functions, but that doesn't need to be solved in this PR.

self.linear = nn.Linear(in_dim, out_dim)

def adapter_params(self):
return ["adapter.weight"]
return ["lora.weight"]

def forward(self, x):
return self.adapter(x) + self.non_adapter(x)
return self.lora(x) + self.non_adapter(x)


class DummyAdapterParentModel(nn.Module, AdapterModule):
def __init__(self, in_dim, out_dim):
super().__init__()
self.dummy_adapter_module = DummyAdapterModule(in_dim, out_dim)
self.parent_adapter = nn.Linear(in_dim, out_dim)
self.parent_lora = nn.Linear(in_dim, out_dim)
self.parent_base_model = nn.Linear(in_dim, out_dim)

def adapter_params(self):
return ["parent_adapter.weight", "parent_adapter.bias"]
return ["parent_lora.weight", "parent_lora.bias"]

def forward(self, x):
return (
self.dummy_adapter_module(x)
+ self.parent_adapter(x)
+ self.parent_lora(x)
+ self.parent_base_model(x)
)

@@ -79,9 +80,9 @@ def dummy_model_expected_adapter_keys():
for i in range(N_LAYERS):
keys.extend(
[
f"{i}.parent_adapter.weight",
f"{i}.parent_adapter.bias",
f"{i}.dummy_adapter_module.adapter.weight",
f"{i}.parent_lora.weight",
f"{i}.parent_lora.bias",
f"{i}.dummy_adapter_module.lora.weight",
]
)
return keys
@@ -204,6 +205,20 @@ def test_get_adapter_params(self, request, model_name, expected_keys):
expected = request.getfixturevalue(expected_keys)
assert set(expected) == set(adapter_params.keys())

@pytest.mark.parametrize(
"model_name, expected_keys",
[
("dummy_adapter_parent_model", "dummy_model_expected_adapter_keys"),
("lora_llama2_model", "lora_llama2_expected_adapter_keys"),
("dora_llama2_model", "dora_llama2_expected_adapter_keys"),
],
)
def test_get_adapter_state_dict(self, request, model_name, expected_keys):
model = request.getfixturevalue(model_name)
adapter_state_dict = get_adapter_state_dict(model.state_dict())
expected = request.getfixturevalue(expected_keys)
assert set(expected) == set(adapter_state_dict.keys())

@pytest.mark.parametrize(
"model_name, expected_trainable_keys, expected_frozen_keys",
[
8 changes: 4 additions & 4 deletions tests/torchtune/training/test_distributed.py
@@ -312,8 +312,8 @@ def test_lora_state_dict(self):
fsdp_optim_to_save.zero_grad()
expected_model_sd = base_model.state_dict()
expected_optim_sd = base_optim.state_dict()
model_full_sd = training.get_full_model_state_dict(
fsdp_model_to_save, is_rank_zero
model_full_sd = training.gather_cpu_state_dict(
fsdp_model_to_save.state_dict(), is_rank_zero
)
optim_full_sd = training.get_full_optimizer_state_dict(
fsdp_optim_to_save,
@@ -467,8 +467,8 @@ def _test_qlora_state_dict(self, enable_activation_checkpointing: bool):
fsdp_model_to_save(inp)

expected_model_sd = {k: v.cpu() for k, v in base_model.state_dict().items()}
model_full_sd = training.get_full_model_state_dict(
fsdp_model_to_save, is_rank_zero
model_full_sd = training.gather_cpu_state_dict(
fsdp_model_to_save.state_dict(), is_rank_zero
)
if is_rank_zero:
self.assertEqual(set(model_full_sd.keys()), set(expected_model_sd.keys()))
2 changes: 2 additions & 0 deletions torchtune/modules/peft/__init__.py
@@ -8,6 +8,7 @@
AdapterModule,
disable_adapter,
get_adapter_params,
get_adapter_state_dict,
get_lora_module_names,
get_merged_lora_ckpt,
load_dora_magnitudes,
@@ -30,6 +31,7 @@
"validate_state_dict_for_lora",
"load_dora_magnitudes",
"disable_adapter",
"get_adapter_state_dict",
"get_merged_lora_ckpt",
"get_lora_module_names",
]
18 changes: 18 additions & 0 deletions torchtune/modules/peft/_utils.py
@@ -107,6 +107,24 @@ def get_lora_module_names(
return lora_module_keys


def get_adapter_state_dict(state_dict: Dict[str, Any]) -> Dict[str, Any]:
"""
Return the subset of the full state_dict from a model that corresponds to an adapter.
Assumes that "lora" and "magnitude" are unique names for adapter parameters, and
that the state_dict is not sharded. All returned parameters are moved to CPU.

Args:
state_dict (Dict[str, Any]): Full model state dict.

Returns:
Dict[str, Any]: the subset of model's state dict containing
only adapter parameters.

"""
adapter_key_filter = lambda x: "lora" in x or "magnitude" in x
return {k: v.cpu() for k, v in state_dict.items() if adapter_key_filter(k)}
Contributor:
Do we want to make the move to CPU optional here?

Contributor Author:
If it's already on CPU this is a no-op.


def validate_state_dict_for_lora(
lora_attn_modules: List[LORA_ATTN_MODULES],
apply_lora_to_mlp: bool,
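A quick usage sketch of the new helper on a toy state dict; the key names below are invented purely to illustrate the "lora"/"magnitude" substring filter.

```python
# Illustrative only: made-up key names, real helper. Keys containing "lora"
# or "magnitude" are kept and moved to CPU (a no-op if already on CPU).
import torch

from torchtune.modules.peft import get_adapter_state_dict

state_dict = {
    "layers.0.attn.q_proj.weight": torch.randn(8, 8),          # base weight: dropped
    "layers.0.attn.q_proj.lora_a.weight": torch.randn(2, 8),   # LoRA A: kept
    "layers.0.attn.q_proj.lora_b.weight": torch.randn(8, 2),   # LoRA B: kept
    "layers.0.attn.q_proj.magnitude": torch.randn(8),          # DoRA magnitude: kept
}

adapter_sd = get_adapter_state_dict(state_dict)
assert set(adapter_sd) == {
    "layers.0.attn.q_proj.lora_a.weight",
    "layers.0.attn.q_proj.lora_b.weight",
    "layers.0.attn.q_proj.magnitude",
}
```

As the review thread above notes, the `.cpu()` move is a no-op for tensors that already live on CPU.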
4 changes: 2 additions & 2 deletions torchtune/training/__init__.py
@@ -12,8 +12,8 @@
from torchtune.training._distributed import (
contains_fsdp,
FSDPPolicyType,
gather_cpu_state_dict,
get_full_finetune_fsdp_wrap_policy,
get_full_model_state_dict,
get_full_optimizer_state_dict,
get_shard_conditions,
get_world_size_and_rank,
@@ -120,7 +120,7 @@
"FSDPPolicyType",
"get_full_finetune_fsdp_wrap_policy",
"lora_fsdp_wrap_policy",
"get_full_model_state_dict",
"gather_cpu_state_dict",
"get_full_optimizer_state_dict",
"load_from_full_model_state_dict",
"load_from_full_optimizer_state_dict",