Skip to content

Commit

Permalink
DoRA fixes (#2139)
Browse files (browse the repository at this point in the history)
Co-authored-by: Mircea Mironenco <[email protected]>
  • Loading branch information
ebsmothers and mirceamironenco authored Dec 11, 2024
1 parent f4d56e3 commit 9cfa288
Show file tree
Hide file tree
Showing 13 changed files with 259 additions and 53 deletions.
9 changes: 2 additions & 7 deletions recipes/knowledge_distillation_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
get_adapter_state_dict,
get_lora_module_names,
get_merged_lora_ckpt,
load_dora_magnitudes,
LoRALinear,
set_trainable_params,
validate_missing_and_unexpected_for_lora,
Expand Down Expand Up @@ -477,8 +476,7 @@ def _setup_model(
) and not lora_weights_state_dict:
# lora may not be covered in state dict
# if finetune for the 1st time
m.lora_a.to_empty(device=lora_device)
m.lora_b.to_empty(device=lora_device)
m.to_empty(device=lora_device)
m.initialize_parameters()
# RoPE is not covered in state dict
if hasattr(m, "rope_init"):
Expand All @@ -491,13 +489,10 @@ def _setup_model(
self._is_rank_zero,
cpu_offload=fsdp_cpu_offload,
)
is_dora = False
for m in model.modules():
if hasattr(m, "initialize_dora_magnitude"):
is_dora = True
m.initialize_dora_magnitude()
if is_dora:
load_dora_magnitudes(model)

validate_missing_and_unexpected_for_lora(
lora_attn_modules=self._lora_attn_modules,
apply_lora_to_mlp=self._apply_lora_to_mlp,
Expand Down
5 changes: 3 additions & 2 deletions recipes/knowledge_distillation_single_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
get_adapter_state_dict,
get_lora_module_names,
get_merged_lora_ckpt,
load_dora_magnitudes,
set_trainable_params,
validate_missing_and_unexpected_for_lora,
)
Expand Down Expand Up @@ -421,7 +420,9 @@ def _setup_model(
# This is for any adapters that need to be initialized after base weights
# have been loaded (e.g. DoRA).
if self._is_dora:
load_dora_magnitudes(model)
for m in model.modules():
if hasattr(m, "initialize_dora_magnitude"):
m.initialize_dora_magnitude()
if lora_weights_state_dict:
lora_missing, lora_unexpected = model.load_state_dict(
lora_weights_state_dict, strict=False
Expand Down
8 changes: 4 additions & 4 deletions recipes/lora_dpo_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
get_adapter_params,
get_adapter_state_dict,
get_merged_lora_ckpt,
load_dora_magnitudes,
LoRALinear,
set_trainable_params,
validate_missing_and_unexpected_for_lora,
Expand Down Expand Up @@ -400,8 +399,7 @@ def _setup_model(
) and not lora_weights_state_dict:
# lora may not be covered in state dict
# if finetune for the 1st time
m.lora_a.to_empty(device=lora_device)
m.lora_b.to_empty(device=lora_device)
m.to_empty(device=lora_device)
m.initialize_parameters()
# RoPE is not covered in state dict
if hasattr(m, "rope_init"):
Expand All @@ -420,7 +418,9 @@ def _setup_model(
is_dora = True
m.initialize_dora_magnitude()
if is_dora:
load_dora_magnitudes(model)
for m in model.modules():
if hasattr(m, "initialize_dora_magnitude"):
m.initialize_dora_magnitude()
validate_missing_and_unexpected_for_lora(
lora_attn_modules=self._lora_attn_modules,
apply_lora_to_mlp=self._apply_lora_to_mlp,
Expand Down
11 changes: 3 additions & 8 deletions recipes/lora_finetune_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
get_adapter_state_dict,
get_lora_module_names,
get_merged_lora_ckpt,
load_dora_magnitudes,
LoRALinear,
set_trainable_params,
validate_missing_and_unexpected_for_lora,
Expand Down Expand Up @@ -496,10 +495,9 @@ def _setup_model(
) and not lora_weights_state_dict:
# lora may not be covered in state dict
# if finetune for the 1st time
m.lora_a.to_empty(device=lora_device)
m.lora_b.to_empty(device=lora_device)
m.to_empty(device=lora_device)
m.initialize_parameters()
# RoPE is not covered in state dict

if hasattr(m, "rope_init"):
m.rope_init()

Expand All @@ -510,13 +508,10 @@ def _setup_model(
self._is_rank_zero,
cpu_offload=fsdp_cpu_offload,
)
is_dora = False
for m in model.modules():
if hasattr(m, "initialize_dora_magnitude"):
is_dora = True
m.initialize_dora_magnitude()
if is_dora:
load_dora_magnitudes(model)

validate_missing_and_unexpected_for_lora(
lora_attn_modules=self._lora_attn_modules,
apply_lora_to_mlp=self._apply_lora_to_mlp,
Expand Down
2 changes: 0 additions & 2 deletions recipes/lora_finetune_single_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
get_adapter_state_dict,
get_lora_module_names,
get_merged_lora_ckpt,
load_dora_magnitudes,
set_trainable_params,
validate_missing_and_unexpected_for_lora,
)
Expand Down Expand Up @@ -450,7 +449,6 @@ def _setup_model(
for m in model.modules():
if hasattr(m, "initialize_dora_magnitude"):
m.initialize_dora_magnitude()
load_dora_magnitudes(model)
if lora_weights_state_dict:
lora_missing, lora_unexpected = model.load_state_dict(
lora_weights_state_dict, strict=False
Expand Down
3 changes: 1 addition & 2 deletions recipes/qat_lora_finetune_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,8 +534,7 @@ def _setup_model(
) and not lora_weights_state_dict:
# lora may not be covered in state dict
# if finetune for the 1st time
m.lora_a.to_empty(device=lora_device)
m.lora_b.to_empty(device=lora_device)
m.to_empty(device=lora_device)
m.initialize_parameters()
# RoPE is not covered in state dict
if hasattr(m, "rope_init"):
Expand Down
14 changes: 7 additions & 7 deletions tests/recipes/test_lora_finetune_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,15 +215,15 @@ def test_training_state_on_resume(

@pytest.mark.integration_test
@pytest.mark.parametrize(
"recipe_config, model_type, ckpt_type",
"recipe_config, model_type, ckpt_type, use_dora",
[
("llama2/7B_lora", "llama2", "tune"),
("llama3/8B_lora", "llama3", "tune"),
("llama2/7B_lora", "llama2", "tune", True),
("llama3/8B_lora", "llama3", "tune", False),
],
)
@gpu_test(gpu_count=2)
def test_save_and_load_merged_weights(
self, recipe_config, model_type, ckpt_type, tmpdir, monkeypatch
self, recipe_config, model_type, ckpt_type, use_dora, tmpdir, monkeypatch
):
ckpt_component = CKPT_COMPONENT_MAP[ckpt_type]
ckpt = model_type + "_" + ckpt_type
Expand All @@ -249,9 +249,9 @@ def test_save_and_load_merged_weights(
enable_activation_checkpointing=True \
enable_activation_offloading=True \
""".split()

model_config = MODEL_TEST_CONFIGS[model_type + "_lora"]

model_config = MODEL_TEST_CONFIGS[
model_type + ("_dora" if use_dora else "_lora")
]
cmd = cmd + self._get_test_config_overrides() + model_config
monkeypatch.setattr(sys, "argv", cmd)
runpy.run_path(TUNE_PATH, run_name="__main__")
Expand Down
9 changes: 2 additions & 7 deletions tests/torchtune/models/llama2/scripts/compare_dora.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,7 @@
from torch import nn
from torchao.dtypes.nf4tensor import linear_nf4, to_nf4
from torchtune import training
from torchtune.modules.peft import (
DoRALinear,
get_merged_lora_ckpt,
load_dora_magnitudes,
LoRALinear,
)
from torchtune.modules.peft import DoRALinear, get_merged_lora_ckpt, LoRALinear
from torchtune.training.seed import set_seed


Expand Down Expand Up @@ -91,7 +86,7 @@ def _dora_is_the_same_as_lora():
# Verify that this is true.
assert not _dora_is_the_same_as_lora()
module.initialize_dora_magnitude()
load_dora_magnitudes(module)

assert _dora_is_the_same_as_lora()

def _compare_params():
Expand Down
Loading

0 comments on commit 9cfa288

Please sign in to comment.