diff --git a/recipes/configs/llama3_2_vision/11B_lora_single_device.yaml b/recipes/configs/llama3_2_vision/11B_lora_single_device.yaml
index c248ccaee8..050c6b0383 100644
--- a/recipes/configs/llama3_2_vision/11B_lora_single_device.yaml
+++ b/recipes/configs/llama3_2_vision/11B_lora_single_device.yaml
@@ -45,7 +45,7 @@ checkpointer:
   output_dir: /tmp/Llama-3.2-11B-Vision-Instruct/
   model_type: LLAMA3_VISION
 resume_from_checkpoint: False
-save_adapter_weights_only: False # PeFT formatting not available yet. This will save it in torchtune format only.
+save_adapter_weights_only: False

 # Dataset
 dataset:
diff --git a/recipes/configs/llama3_2_vision/11B_qlora_single_device.yaml b/recipes/configs/llama3_2_vision/11B_qlora_single_device.yaml
index 8261d8eeac..2829cb4d43 100644
--- a/recipes/configs/llama3_2_vision/11B_qlora_single_device.yaml
+++ b/recipes/configs/llama3_2_vision/11B_qlora_single_device.yaml
@@ -45,7 +45,7 @@ checkpointer:
   output_dir: /tmp/Llama-3.2-11B-Vision-Instruct/
   model_type: LLAMA3_VISION
 resume_from_checkpoint: False
-save_adapter_weights_only: False # PeFT formatting not available yet. This will save it in torchtune format only.
+save_adapter_weights_only: False

 # Dataset
 dataset:
diff --git a/torchtune/models/llama3_2_vision/_convert_weights.py b/torchtune/models/llama3_2_vision/_convert_weights.py
index ae3f025c91..1957d354d9 100644
--- a/torchtune/models/llama3_2_vision/_convert_weights.py
+++ b/torchtune/models/llama3_2_vision/_convert_weights.py
@@ -427,3 +427,103 @@ def _permute(t, n_heads):
         converted_state_dict[new_key] = value

     return converted_state_dict
+
+
+# Mapping from torchtune LoRA module names to PEFT LoRA module names
+_TO_PEFT_KEYS = {
+    "lora_a": "lora_A",
+    "lora_b": "lora_B",
+    "magnitude": "lora_magnitude_vector",
+}
+
+
+def _get_peft_dict(tune_to_hf_dict: Dict[str, str]) -> Dict[str, str]:
+    """
+    Rather than recreate a separate mapping for LoRA adapter weights, we just
+    re-use the _FROM_HF mapping for base model weights. We iterate over it once
+    per entry in _TO_PEFT_KEYS, adding mappings for LoRA A, LoRA B, and DoRA magnitude.
+    """
+    new_mapping_dict = {}
+    for tune_peft, hf_peft in _TO_PEFT_KEYS.items():
+        for tune_key, hf_key in tune_to_hf_dict.items():
+            if hf_key is None or tune_key is None:
+                continue
+
+            if tune_peft == "magnitude":
+                # e.g. attn.q_proj.magnitude -> attn.q_proj.lora_magnitude_vector
+                tune_adapter = tune_key.replace(".weight", f".{tune_peft}")
+                hf_adapter = hf_key.replace(".weight", f".{hf_peft}")
+            else:
+                # e.g. attn.q_proj.lora_a.weight -> attn.q_proj.lora_A.weight
+                tune_adapter = tune_key.replace(".weight", f".{tune_peft}.weight")
+                hf_adapter = hf_key.replace(".weight", f".{hf_peft}.weight")
+
+            new_mapping_dict[tune_adapter] = hf_adapter
+    return new_mapping_dict
+
+
+def llama3_vision_tune_to_peft_adapter_weights(
+    state_dict: Dict[str, torch.Tensor],
+    num_heads: int = 32,
+    num_kv_heads: int = 32,
+    dim: int = 4096,
+    head_dim: Optional[int] = None,
+    cross_attention_layers: Optional[List[int]] = None,
+) -> Dict[str, torch.Tensor]:
+    """
+    Converter from a torchtune adapter state dict to a PEFT-format HF state dict. This handles:
+    - Updating the cross-attention layer numbers
+    - Skipping the rope embeddings
+    - Reshaping the q and k projections
+    """
+    converted_state_dict = {}
+    inverted_mapping_dict = {v: k for k, v in _FROM_HF.items()}
+    # missing keys in _FROM_HF due to naming collisions
+    missing_keys = {
+        "decoder.layers.{}.fusion_layer.ca_norm.scale": "language_model.model.layers.{}.input_layernorm.weight",
+        "decoder.layers.{}.fusion_layer.mlp_norm.scale": "language_model.model.layers.{}.post_attention_layernorm.weight",
+        "decoder.layers.{}.fusion_layer.mlp.w1.weight": "language_model.model.layers.{}.mlp.gate_proj.weight",
+        "decoder.layers.{}.fusion_layer.mlp.w3.weight": "language_model.model.layers.{}.mlp.up_proj.weight",
+        "decoder.layers.{}.fusion_layer.mlp.w2.weight": "language_model.model.layers.{}.mlp.down_proj.weight",
+        "decoder.tok_embeddings.fusion_embedding.weight": None,
+    }
+    inverted_mapping_dict.update(missing_keys)
+    inverted_mapping_dict = _get_peft_dict(inverted_mapping_dict)
+
+    if head_dim is None:
+        head_dim = dim // num_heads
+    if cross_attention_layers is None:
+        cross_attention_layers = []
+    # convert hf layer numbers to tune numbers
+    cross_attention_layers = [
+        l - i for i, l in enumerate(sorted(cross_attention_layers))
+    ]
+
+    def _permute_lora_matrix(t, n_heads):
+        rank = t.shape[-1]
+        return (
+            t.view(n_heads, head_dim // 2, 2, rank)
+            .transpose(1, 2)
+            .reshape((head_dim * n_heads), rank)
+        )
+
+    for key, value in state_dict.items():
+        new_key = get_mapped_key(key, inverted_mapping_dict)
+        if "decoder" in key:
+            if "layers" in key:  # Update layer numbers
+                layer = int(key.split(".")[2])
+                num_shifts = sum(layer > l for l in cross_attention_layers)
+                new_layer = layer + num_shifts
+                key_lst = new_key.split(".")
+                if layer in cross_attention_layers and "fusion_layer" not in key:
+                    new_layer += 1  # hf treats the fusion_layer as an additional layer
+                key_lst[3] = str(new_layer)
+                new_key = ".".join(key_lst)
+        if "q_proj" in key and "lora_B" in new_key and "cross_attn" not in new_key:
+            value = _permute_lora_matrix(value, num_heads)
+        elif (
+            "k_proj" in key and "lora_B" in new_key and "cross_attn" not in new_key
+        ):
+            value = _permute_lora_matrix(value, num_kv_heads)
+        converted_state_dict["base_model.model." + new_key] = value
+    return converted_state_dict
diff --git a/torchtune/training/checkpointing/_checkpointer.py b/torchtune/training/checkpointing/_checkpointer.py
index ec79f0a4ba..94a315cafc 100644
--- a/torchtune/training/checkpointing/_checkpointer.py
+++ b/torchtune/training/checkpointing/_checkpointer.py
@@ -651,20 +651,38 @@
                 logger.warning(
                     "Saving Phi-3 Mini adapter weights to PEFT format is not supported, saving to torchtune format instead"
                 )
-            elif self._model_type == ModelType.LLAMA3_VISION:
+            elif self._model_type == ModelType.QWEN2:
                 logger.warning(
-                    "Saving Llama3.2 Vision adapter weights to PEFT format is not supported, saving to torchtune format instead"
+                    "Saving QWEN2 adapter weights to PEFT format is not supported, saving to torchtune format instead"
                 )
             else:
-                state_dict[
-                    training.ADAPTER_KEY
-                ] = convert_weights.tune_to_peft_adapter_weights(
-                    state_dict[training.ADAPTER_KEY],
-                    num_heads=self._config["num_attention_heads"],
-                    num_kv_heads=self._config["num_key_value_heads"],
-                    dim=self._config["hidden_size"],
-                    head_dim=self._config.get("head_dim", None),
-                )
+                if self._model_type == ModelType.LLAMA3_VISION:
+                    from torchtune.models.llama3_2_vision._convert_weights import (
+                        llama3_vision_tune_to_peft_adapter_weights,
+                    )
+
+                    state_dict[
+                        training.ADAPTER_KEY
+                    ] = llama3_vision_tune_to_peft_adapter_weights(
+                        state_dict[training.ADAPTER_KEY],
+                        num_heads=text_config["num_attention_heads"],
+                        num_kv_heads=text_config["num_key_value_heads"],
+                        dim=text_config["hidden_size"],
+                        head_dim=text_config.get("head_dim", None),
+                        cross_attention_layers=text_config.get(
+                            "cross_attention_layers", None
+                        ),
+                    )
+                else:
+                    state_dict[
+                        training.ADAPTER_KEY
+                    ] = convert_weights.tune_to_peft_adapter_weights(
+                        state_dict[training.ADAPTER_KEY],
+                        num_heads=self._config["num_attention_heads"],
+                        num_kv_heads=self._config["num_key_value_heads"],
+                        dim=self._config["hidden_size"],
+                        head_dim=self._config.get("head_dim", None),
+                    )
                 peft_output_path = Path.joinpath(
                     self._output_dir, "adapter_model"
                 ).with_suffix(".bin")
@@ -680,28 +698,19 @@
                 )

         if training.ADAPTER_CONFIG in state_dict:
-            if self._model_type == ModelType.PHI3_MINI:
-                logger.warning(
-                    "PEFT integration for Phi-3 Mini is not supported, skipping adapter config save"
-                )
-            elif self._model_type == ModelType.LLAMA3_VISION:
-                logger.warning(
-                    "PEFT integration for Llama3.2 Vision is not supported, skipping adapter config save"
-                )
-            else:
-                state_dict[
-                    training.ADAPTER_CONFIG
-                ] = convert_weights.tune_to_peft_adapter_config(
-                    state_dict[training.ADAPTER_CONFIG]
-                )
-                output_path = Path.joinpath(self._output_dir, "adapter_config.json")
-                with open(output_path, "w") as f:
-                    json.dump(state_dict[training.ADAPTER_CONFIG], f)
-                logger.info(
-                    "Adapter checkpoint of size "
-                    f"{os.path.getsize(output_path) / 1000**3:.2f} GB "
-                    f"saved to {output_path}"
-                )
+            state_dict[
+                training.ADAPTER_CONFIG
+            ] = convert_weights.tune_to_peft_adapter_config(
+                state_dict[training.ADAPTER_CONFIG]
+            )
+            output_path = Path.joinpath(self._output_dir, "adapter_config.json")
+            with open(output_path, "w") as f:
+                json.dump(state_dict[training.ADAPTER_CONFIG], f)
+            logger.info(
+                "Adapter checkpoint of size "
+                f"{os.path.getsize(output_path) / 1000**3:.2f} GB "
+                f"saved to {output_path}"
+            )

         # If the recipe state needs to be output, first remove the model state dict
         # and if it exists, remove the adapter state dict as well