diff --git a/recipes/lora_finetune_single_device.py b/recipes/lora_finetune_single_device.py index 50a61c1c0b..979728596b 100644 --- a/recipes/lora_finetune_single_device.py +++ b/recipes/lora_finetune_single_device.py @@ -120,7 +120,6 @@ class LoRAFinetuneRecipeSingleDevice(FTRecipeInterface): """ def __init__(self, cfg: DictConfig) -> None: - self._device = utils.get_device(device=cfg.device) # Reduced precision logic self._dtype = training.get_dtype(cfg.dtype, device=self._device) @@ -438,6 +437,9 @@ def _setup_model( # This is for any adapters that need to be initialized after base weights # have been loaded (e.g. DoRA). if self._is_dora: + for m in model.modules(): + if hasattr(m, "initialize_dora_magnitude"): + m.initialize_dora_magnitude() load_dora_magnitudes(model) if lora_weights_state_dict: lora_missing, lora_unexpected = model.load_state_dict( diff --git a/tests/recipes/test_lora_finetune_single_device.py b/tests/recipes/test_lora_finetune_single_device.py index ca10076f5f..d2521e4821 100644 --- a/tests/recipes/test_lora_finetune_single_device.py +++ b/tests/recipes/test_lora_finetune_single_device.py @@ -259,8 +259,9 @@ def test_training_state_on_resume( loss_values, expected_loss_values, rtol=1e-5, atol=1e-5 ) + @pytest.mark.parametrize("use_dora", [False, True]) @pytest.mark.integration_test - def test_save_and_load_merged_weights(self, tmpdir, monkeypatch): + def test_save_and_load_merged_weights(self, tmpdir, monkeypatch, use_dora): ckpt = "llama2_tune" ckpt_path = Path(CKPT_MODEL_PATHS[ckpt]) ckpt_dir = ckpt_path.parent @@ -280,7 +281,10 @@ def test_save_and_load_merged_weights(self, tmpdir, monkeypatch): enable_activation_offloading=False \ """.split() - model_config = MODEL_TEST_CONFIGS["llama2_lora"] + if use_dora: + model_config = MODEL_TEST_CONFIGS["llama2_dora"] + else: + model_config = MODEL_TEST_CONFIGS["llama2_lora"] cmd = cmd + self._get_test_config_overrides() + model_config monkeypatch.setattr(sys, "argv", cmd) diff --git a/tests/recipes/utils.py b/tests/recipes/utils.py index a79f5be715..baa8ad23a9 100644 --- a/tests/recipes/utils.py +++ b/tests/recipes/utils.py @@ -135,6 +135,7 @@ def lora_llama2_test_config( lora_rank: int = 8, lora_alpha: float = 16, quantize_base: bool = False, + use_dora: bool = False, ) -> List[str]: return [ # Note: we explicitly use _component_ so that we can also call @@ -154,6 +155,7 @@ def lora_llama2_test_config( f"model.lora_alpha={lora_alpha}", "model.lora_dropout=0.0", f"model.quantize_base={quantize_base}", + f"model.use_dora={use_dora}", ] @@ -207,6 +209,14 @@ def write_hf_ckpt_config(ckpt_dir: str): lora_rank=8, lora_alpha=16, ), + "llama2_dora": lora_llama2_test_config( + lora_attn_modules=["q_proj", "k_proj", "v_proj", "output_proj"], + apply_lora_to_mlp=False, + apply_lora_to_output=False, + lora_rank=8, + lora_alpha=16, + use_dora=True, + ), "llama2_qlora": lora_llama2_test_config( lora_attn_modules=["q_proj", "k_proj", "v_proj", "output_proj"], apply_lora_to_mlp=True, diff --git a/torchtune/models/convert_weights.py b/torchtune/models/convert_weights.py index c0cf2f10fc..b96006d33a 100644 --- a/torchtune/models/convert_weights.py +++ b/torchtune/models/convert_weights.py @@ -6,7 +6,7 @@ import re -from typing import Any, Dict +from typing import Any, Dict, Optional import torch @@ -252,23 +252,28 @@ def tune_to_peft_adapter_weights( num_heads: int = 32, num_kv_heads: int = 32, dim: int = 4096, - head_dim: int = None, + head_dim: Optional[int] = None, ): converted_state_dict = {} full_mapping = {} - # Rather than 
recreate a separate mapping for LoRA adapter weights, we just - # re-use the _FROM_HF mapping for base model weights. We iterate over it twice: - # once to add mappings for LoRA A matrices and once to add mappings for LoRA B matrices. - for k, v in _TO_PEFT_KEYS.items(): - full_mapping.update( - { - vv.replace(".weight", f".{k}.weight"): kk.replace( - ".weight", f".{v}.weight" - ) - for kk, vv in _FROM_HF.items() - if vv is not None - } - ) + # Rather than recreate a separate mapping for LoRA adapter weights, we re-use the + # _FROM_HF mapping for base model weights. The mapping is adapted to account for: + # LoRA A matrices, LoRA B matrices and the dora magnitude parameter. + for peft_key, peft_val in _TO_PEFT_KEYS.items(): + for hf_key, hf_val in _FROM_HF.items(): + if hf_val is None: + continue + + if peft_key == "magnitude": + # e.g. attn.q_proj.magnitude -> attn.q_proj.lora_magnitude_vector + adapter_key = hf_val.replace(".weight", f".{peft_key}") + adapter_val = hf_key.replace(".weight", f".{peft_val}") + else: + # e.g. attn.q_proj.lora_a.weight -> attn.q_proj.lora_A.weight + adapter_key = hf_val.replace(".weight", f".{peft_key}.weight") + adapter_val = hf_key.replace(".weight", f".{peft_val}.weight") + + full_mapping.update({adapter_key: adapter_val}) if head_dim is None: head_dim = dim // num_heads diff --git a/torchtune/modules/peft/dora.py b/torchtune/modules/peft/dora.py index 52ad9c7321..9e1428418f 100644 --- a/torchtune/modules/peft/dora.py +++ b/torchtune/modules/peft/dora.py @@ -79,6 +79,7 @@ def initialize_parameters(self): _lora_a_init_params(self.lora_a) _lora_b_init_params(self.lora_b) + @torch.no_grad() def initialize_dora_magnitude(self): """ DoRA initializes the magnitude vector such that its outputs are initially @@ -87,7 +88,7 @@ def initialize_dora_magnitude(self): base_weight = self.weight.to(self.lora_a.weight.dtype) lora_weight = self.lora_b.weight @ self.lora_a.weight weight_norm = self._get_weight_norm(base_weight, lora_weight) - self.magnitude = nn.Parameter(weight_norm, requires_grad=True) + self.magnitude.copy_(weight_norm) def _create_weight_and_bias(self): """
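
A note on the adapter key mapping rewritten above in tune_to_peft_adapter_weights: the nested loop derives torchtune-to-PEFT key pairs from the base-weight mapping for LoRA A, LoRA B, and the DoRA magnitude, treating the magnitude specially because its PEFT key (lora_magnitude_vector) carries no trailing ".weight". The sketch below replays that logic on small stand-in tables; the dictionary contents are illustrative only and are not torchtune's real _FROM_HF and _TO_PEFT_KEYS.

# Stand-in tables for illustration; the real _FROM_HF covers every base-model
# weight and both tables live in torchtune/models/convert_weights.py.
_FROM_HF = {
    "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attn.q_proj.weight",
    "model.layers.{}.self_attn.rotary_emb.inv_freq": None,  # None entries are skipped
}
_TO_PEFT_KEYS = {
    "lora_a": "lora_A",
    "lora_b": "lora_B",
    "magnitude": "lora_magnitude_vector",
}

full_mapping = {}
for peft_key, peft_val in _TO_PEFT_KEYS.items():
    for hf_key, hf_val in _FROM_HF.items():
        if hf_val is None:
            continue
        if peft_key == "magnitude":
            # DoRA magnitude: torchtune uses "...q_proj.magnitude" (no ".weight"),
            # while PEFT expects "...q_proj.lora_magnitude_vector".
            adapter_key = hf_val.replace(".weight", f".{peft_key}")
            adapter_val = hf_key.replace(".weight", f".{peft_val}")
        else:
            # LoRA A/B: e.g. "...q_proj.lora_a.weight" -> "...q_proj.lora_A.weight"
            adapter_key = hf_val.replace(".weight", f".{peft_key}.weight")
            adapter_val = hf_key.replace(".weight", f".{peft_val}.weight")
        full_mapping[adapter_key] = adapter_val

# Resulting entries include, e.g.:
#   "layers.{}.attn.q_proj.lora_a.weight" -> "model.layers.{}.self_attn.q_proj.lora_A.weight"
#   "layers.{}.attn.q_proj.magnitude"     -> "model.layers.{}.self_attn.q_proj.lora_magnitude_vector"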
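
On the dora.py change: initialize_dora_magnitude now runs under torch.no_grad() and fills the existing magnitude parameter with copy_ rather than rebinding self.magnitude to a fresh nn.Parameter. That way the Parameter object registered at construction time stays in place for anything already holding a reference to it (optimizer state, state-dict hooks), and the initialization math stays off the autograd graph. Below is a rough, self-contained sketch of the idea, assuming the standard DoRA definition where the magnitude is the per-row L2 norm of the merged weight W + (alpha / rank) * B @ A; ToyDoRALinear is a hypothetical stand-in, not torchtune's DoRALinear.

import torch
import torch.nn as nn


class ToyDoRALinear(nn.Module):
    # Hypothetical module, trimmed to the magnitude-initialization logic only.
    def __init__(self, in_dim: int, out_dim: int, rank: int, alpha: float):
        super().__init__()
        self.scaling = alpha / rank
        self.weight = nn.Parameter(torch.randn(out_dim, in_dim), requires_grad=False)
        self.lora_a = nn.Linear(in_dim, rank, bias=False)
        self.lora_b = nn.Linear(rank, out_dim, bias=False)
        # Registered once here; later filled in place so existing references to
        # this Parameter remain valid.
        self.magnitude = nn.Parameter(torch.empty(out_dim))

    @torch.no_grad()
    def initialize_dora_magnitude(self):
        # Per-row L2 norm of the merged weight W + scaling * B @ A, so that the
        # DoRA layer initially reproduces the base layer's outputs.
        lora_weight = self.lora_b.weight @ self.lora_a.weight
        merged = self.weight + self.scaling * lora_weight
        self.magnitude.copy_(torch.linalg.norm(merged, dim=1))


# Mirrors the recipe change: only after the base (and any adapter) weights have
# been loaded, walk the model and initialize each DoRA magnitude in place.
model = nn.Sequential(ToyDoRALinear(16, 16, rank=4, alpha=8.0), nn.ReLU())
for m in model.modules():
    if hasattr(m, "initialize_dora_magnitude"):
        m.initialize_dora_magnitude()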