Remove unused FSDP components #2016

Merged: 7 commits on Nov 16, 2024

Changes from 6 commits
1 change: 0 additions & 1 deletion docs/source/api_ref_modules.rst

@@ -77,7 +77,6 @@ PEFT Components
peft.set_trainable_params
peft.get_adapter_state_dict
peft.validate_missing_and_unexpected_for_lora
peft.validate_state_dict_for_lora
@SalmanMohammadi (Collaborator) commented on Nov 16, 2024:
Line 209 in the LoRA finetune tutorial:

.. note::
    Whenever loading weights with :code:`strict=False`, you should verify that any missing or extra keys in
    the loaded :code:`state_dict` are as expected. torchtune's LoRA recipes do this by default via e.g.
    :func:`validate_state_dict_for_lora() <torchtune.modules.peft.validate_state_dict_for_lora>` or
    :func:`validate_missing_and_unexpected_for_lora() <torchtune.modules.peft.validate_missing_and_unexpected_for_lora>`.

Needs to be updated
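
As context (not part of this PR), here is a minimal sketch of the pattern the updated tutorial note could describe, using only the retained validate_missing_and_unexpected_for_lora. The keyword names are taken from this PR's recipe and test diffs; model, cfg_model, base_model_state_dict, and lora_weights_state_dict are assumed to be defined as in the recipe's _setup_model.

    # Sketch only: load weights non-strictly, then validate the resulting
    # missing/unexpected keys rather than pre-validating full state dicts.
    base_missing, base_unexpected = model.load_state_dict(
        base_model_state_dict, strict=False
    )
    if lora_weights_state_dict is not None:
        lora_missing, lora_unexpected = model.load_state_dict(
            lora_weights_state_dict, strict=False
        )
    else:
        lora_missing, lora_unexpected = None, None
    validate_missing_and_unexpected_for_lora(
        lora_attn_modules=cfg_model.lora_attn_modules,
        apply_lora_to_mlp=cfg_model.apply_lora_to_mlp,
        apply_lora_to_output=getattr(cfg_model, "apply_lora_to_output", False),
        base_missing=base_missing,
        base_unexpected=base_unexpected,
        lora_missing=lora_missing,
        lora_unexpected=lora_unexpected,
    )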

peft.disable_adapter


3 changes: 0 additions & 3 deletions docs/source/api_ref_training.rst

@@ -50,12 +50,9 @@ Utilities for enabling and working with distributed training.
:toctree: generated/
:nosignatures:

FSDPPolicyType
init_distributed
is_distributed
get_world_size_and_rank
get_full_finetune_fsdp_wrap_policy
Collaborator comment:
There's a reference to this in the QAT tutorial too.

lora_fsdp_wrap_policy
gather_cpu_state_dict

.. _ac_label:
14 changes: 0 additions & 14 deletions recipes/lora_dpo_single_device.py

@@ -27,7 +27,6 @@
get_merged_lora_ckpt,
set_trainable_params,
validate_missing_and_unexpected_for_lora,
validate_state_dict_for_lora,
)
from torchtune.recipe_interfaces import FTRecipeInterface

@@ -271,19 +270,6 @@ def _setup_model(
            model, auto_wrap_policy={modules.TransformerSelfAttentionLayer}
        )

        validate_state_dict_for_lora(
            lora_attn_modules=cfg_model.lora_attn_modules,
            apply_lora_to_mlp=cfg_model.apply_lora_to_mlp,
            apply_lora_to_output=getattr(cfg_model, "apply_lora_to_output", False),
            full_model_state_dict_keys=model.state_dict().keys(),
            lora_state_dict_keys=(
                lora_weights_state_dict.keys()
                if lora_weights_state_dict is not None
                else None
            ),
            base_model_state_dict_keys=base_model_state_dict.keys(),
        )

        base_missing, base_unexpected = model.load_state_dict(
            base_model_state_dict, strict=False
        )
195 changes: 62 additions & 133 deletions tests/torchtune/modules/peft/test_utils.py

@@ -21,7 +21,6 @@
LoRALinear,
set_trainable_params,
validate_missing_and_unexpected_for_lora,
validate_state_dict_for_lora,
)

N_LAYERS = 3
@@ -261,9 +260,10 @@ def test_set_trainable_params(
lora_attn_modules,
apply_lora_to_mlp,
apply_lora_to_output,
full_model_state_dict_keys,
lora_state_dict_keys,
base_model_state_dict_keys,
base_missing,
base_unexpected,
lora_missing,
lora_unexpected,
expected
"""
),
@@ -272,188 +272,117 @@
["q_proj", "k_proj"],
False,
False,
["q_proj.lora_a.weight", "dummy_param.weight"],
["q_proj.lora_a.weight"],
[],
["dummy_param.weight"],
[],
"",
),
(
["v_proj"],
False,
False,
["param_a", "param_b"],
None,
["param_a", "param_b"],
"",
),
(["v_proj"], False, False, [], [], ["param_a", "param_b"], [], ""),
(
["output_proj"],
False,
True,
["output_proj.weight", "output_proj.lora_a.weight"],
["output_proj.lora_a.weight"],
[],
["output_proj.weight"],
[],
"",
),
(["q_proj"], False, False, ["param_a"], [], [], "Missing non-LoRA"),
(
["k_proj", "output_proj"],
["q_proj"],
False,
True,
["k_proj.lora_a.weight", "param_a"],
["k_proj.lora_a.weight", "param_a"],
False,
["param_a"],
[],
["param_a"],
"found in LoRA",
[],
"Missing non-LoRA",
),
(
["k_proj"],
False,
["k_proj", "output_proj"],
False,
["k_proj.lora_a.weight"],
True,
[],
[],
["k_proj.lora_a.weight"],
"found in base model",
[],
"Missing LoRA key",
),
(
["k_proj"],
False,
["q_proj", "k_proj"],
True,
False,
["k_proj.lora_a.weight"],
["k_proj.lora"],
[],
["q_proj.lora"],
[],
None,
"Missing LoRA",
),
(["q_proj"], False, False, [], ["a"], ["a"], "overlapping"),
(
["v_proj"],
False,
False,
["dummy_param.weight"],
["v_proj.lora_a.weight"],
["dummy_param.weight"],
"Extra",
),
(
["w1", "w2", "w3"],
["q_proj", "k_proj"],
True,
False,
["w1.lora_a.weight", "w2.weight", "q_proj.weight"],
["w1.lora_a.weight"],
["q_proj.weight"],
"Missing non-LoRA key",
["k_proj.lora"],
[],
["q_proj.magnitude"],
[],
"Missing LoRA",
),
(
["q_proj", "output"],
False,
["q_proj", "k_proj"],
True,
[
"q_proj.lora_a",
"output.weight",
"output.lora_a",
"output_proj.lora_b",
],
["q_proj.lora_a", "output.lora_a", "output_proj.lora_b"],
["output.weight"],
"Missing non-LoRA key",
),
(
["q_proj", "v_proj"],
False,
False,
"lora_llama2_model_all_keys",
"lora_llama2_expected_adapter_keys",
"lora_llama2_expected_base_model_keys",
"",
["output_proj.lora"],
[],
["q_proj.lora"],
[],
"Missing non-LoRA",
),
(
["q_proj", "v_proj"],
False,
["q_proj", "k_proj"],
True,
False,
"dora_llama2_model_all_keys",
"dora_llama2_expected_adapter_keys",
"lora_llama2_expected_base_model_keys",
"",
),
],
)
def test_validate_lora_state_dict(
self,
request,
lora_attn_modules,
apply_lora_to_mlp,
apply_lora_to_output,
full_model_state_dict_keys,
lora_state_dict_keys,
base_model_state_dict_keys,
expected,
):
if isinstance(full_model_state_dict_keys, str):
full_model_state_dict_keys = request.getfixturevalue(
full_model_state_dict_keys
)
if isinstance(lora_state_dict_keys, str):
lora_state_dict_keys = request.getfixturevalue(lora_state_dict_keys)
if isinstance(base_model_state_dict_keys, str):
base_model_state_dict_keys = request.getfixturevalue(
base_model_state_dict_keys
)
if expected:
with pytest.raises(AssertionError, match=expected):
validate_state_dict_for_lora(
lora_attn_modules,
apply_lora_to_mlp,
apply_lora_to_output,
full_model_state_dict_keys=full_model_state_dict_keys,
lora_state_dict_keys=lora_state_dict_keys,
base_model_state_dict_keys=base_model_state_dict_keys,
)
else:
validate_state_dict_for_lora(
lora_attn_modules,
apply_lora_to_mlp,
apply_lora_to_output,
full_model_state_dict_keys=full_model_state_dict_keys,
lora_state_dict_keys=lora_state_dict_keys,
base_model_state_dict_keys=base_model_state_dict_keys,
)

@pytest.mark.parametrize(
(
"""
base_missing,
base_unexpected,
lora_missing,
lora_unexpected,
expected
"""
),
[
(["k_proj.lora"], [], ["q_proj.lora"], [], "Missing LoRA"),
(["k_proj.lora"], [], ["q_proj.magnitude"], [], "Missing LoRA"),
(["output_proj.lora"], [], ["q_proj.lora"], [], "Missing non-LoRA"),
(
["k_proj.lora"],
["output.weight"],
["q_proj.base_weight"],
[],
"loading base model",
),
(
["q_proj", "k_proj"],
True,
False,
["k_proj.lora"],
[],
["q_proj.base_weight"],
["output.weight"],
"loading adapter",
),
(["k_proj.lora"], [], ["q_proj.base_weight"], [], ""),
(
["q_proj", "k_proj"],
True,
False,
["k_proj.lora"],
[],
["q_proj.base_weight"],
[],
"",
),
],
)
def test_validate_missing_and_unexpected_for_lora(
self, base_missing, base_unexpected, lora_missing, lora_unexpected, expected
self,
lora_attn_modules,
apply_lora_to_mlp,
apply_lora_to_output,
base_missing,
base_unexpected,
lora_missing,
lora_unexpected,
expected,
):
lora_attn_modules = ["q_proj", "k_proj"]
apply_lora_to_mlp = True
apply_lora_to_output = False

if expected:
with pytest.raises(AssertionError, match=expected):
validate_missing_and_unexpected_for_lora(
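
The new test body is cut off at this point in the diff view. As an illustration only, here is a plausible sketch of the full body, mirroring the removed test_validate_lora_state_dict body above and the parameter names visible in this diff; the actual code in the PR may differ.

    # Hypothetical continuation (not taken from the PR): assert on the expected
    # error message, or call the validator directly when no error is expected.
    if expected:
        with pytest.raises(AssertionError, match=expected):
            validate_missing_and_unexpected_for_lora(
                lora_attn_modules,
                apply_lora_to_mlp,
                apply_lora_to_output,
                base_missing=base_missing,
                base_unexpected=base_unexpected,
                lora_missing=lora_missing,
                lora_unexpected=lora_unexpected,
            )
    else:
        validate_missing_and_unexpected_for_lora(
            lora_attn_modules,
            apply_lora_to_mlp,
            apply_lora_to_output,
            base_missing=base_missing,
            base_unexpected=base_unexpected,
            lora_missing=lora_missing,
            lora_unexpected=lora_unexpected,
        )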