From dcc92ccedb063370d45b6df011ae55672b12276b Mon Sep 17 00:00:00 2001 From: Philip Bontrager Date: Thu, 31 Oct 2024 12:58:10 -0700 Subject: [PATCH 1/5] initial version --- .../llama3_2_vision/_convert_weights.py | 31 +++++++++ .../training/checkpointing/_checkpointer.py | 69 +++++++++++-------- 2 files changed, 73 insertions(+), 27 deletions(-) diff --git a/torchtune/models/llama3_2_vision/_convert_weights.py b/torchtune/models/llama3_2_vision/_convert_weights.py index ae3f025c91..3458951f0d 100644 --- a/torchtune/models/llama3_2_vision/_convert_weights.py +++ b/torchtune/models/llama3_2_vision/_convert_weights.py @@ -345,6 +345,7 @@ def llama3_vision_tune_to_hf( tile_size: int = 448, num_tiles: int = 4, supported_aspect_ratios: List[Tuple[int, int]] = None, + peft_dict: bool = False, ) -> Dict[str, torch.Tensor]: """ Convertor from Tune state dict to HF state dict. This handles: @@ -364,6 +365,8 @@ def llama3_vision_tune_to_hf( "decoder.tok_embeddings.fusion_embedding.weight": None, } inverted_mapping_dict.update(missing_keys) + if peft_dict: + inverted_mapping_dict = _get_peft_dict(inverted_mapping_dict) if head_dim is None: head_dim = dim // num_heads @@ -427,3 +430,31 @@ def _permute(t, n_heads): converted_state_dict[new_key] = value return converted_state_dict + + +# Mapping from torchtune LoRA module names to PEFT LoRA module names +_TO_PEFT_KEYS = { + "lora_a": "lora_A", + "lora_b": "lora_B", + "magnitude": "lora_magnitude_vector", +} + + +def _get_peft_dict(mapping_dict: Dict[str, str]) -> Dict[str, torch.Tensor]: + """ + Rather than recreate a separate mapping for LoRA adapter weights, we just + re-use the _FROM_HF mapping for base model weights. We iterate over it twice: + once to add mappings for LoRA A matrices and once to add mappings for LoRA B matrices. 
+ """ + new_mapping_dict = {} + for k, v in _TO_PEFT_KEYS.items(): + new_mapping_dict.update( + { + vv.replace(".weight", f".{k}.weight"): kk.replace( + ".weight", f".{v}.weight" + ) + for kk, vv in mapping_dict.items() + if vv is not None + } + ) + return new_mapping_dict diff --git a/torchtune/training/checkpointing/_checkpointer.py b/torchtune/training/checkpointing/_checkpointer.py index 75a7fc950d..f826de8661 100644 --- a/torchtune/training/checkpointing/_checkpointer.py +++ b/torchtune/training/checkpointing/_checkpointer.py @@ -632,15 +632,35 @@ def save_checkpoint( "Saving Phi-3 Mini adapter weights to PEFT format is not supported, saving to torchtune format instead" ) else: - state_dict[ - training.ADAPTER_KEY - ] = convert_weights.tune_to_peft_adapter_weights( - state_dict[training.ADAPTER_KEY], - num_heads=self._config["num_attention_heads"], - num_kv_heads=self._config["num_key_value_heads"], - dim=self._config["hidden_size"], - head_dim=self._config.get("head_dim", None), - ) + if self._model_type == ModelType.LLAMA3_VISION: + state_dict[training.ADAPTER_KEY] = llama3_vision_tune_to_hf( + state_dict[training.MODEL_KEY], + num_heads=text_config["num_attention_heads"], + num_kv_heads=text_config["num_key_value_heads"], + dim=text_config["hidden_size"], + head_dim=text_config.get("head_dim", None), + vocab_size=text_config["vocab_size"], + cross_attention_layers=text_config.get( + "cross_attention_layers", None + ), + encoder_dim=vision_config["hidden_size"], + tile_size=vision_config["image_size"], + num_tiles=vision_config["max_num_tiles"], + supported_aspect_ratios=vision_config.get( + "supported_aspect_ratios", None + ), + peft_dict=True, + ) + else: + state_dict[ + training.ADAPTER_KEY + ] = convert_weights.tune_to_peft_adapter_weights( + state_dict[training.ADAPTER_KEY], + num_heads=self._config["num_attention_heads"], + num_kv_heads=self._config["num_key_value_heads"], + dim=self._config["hidden_size"], + head_dim=self._config.get("head_dim", None), + ) peft_output_path = Path.joinpath( self._output_dir, "adapter_model" ).with_suffix(".bin") @@ -656,24 +676,19 @@ def save_checkpoint( ) if training.ADAPTER_CONFIG in state_dict: - if self._model_type == ModelType.PHI3_MINI: - logger.warning( - "PEFT integration for Phi-3 Mini is not supported, skipping adapter config save" - ) - else: - state_dict[ - training.ADAPTER_CONFIG - ] = convert_weights.tune_to_peft_adapter_config( - state_dict[training.ADAPTER_CONFIG] - ) - output_path = Path.joinpath(self._output_dir, "adapter_config.json") - with open(output_path, "w") as f: - json.dump(state_dict[training.ADAPTER_CONFIG], f) - logger.info( - "Adapter checkpoint of size " - f"{os.path.getsize(output_path) / 1000**3:.2f} GB " - f"saved to {output_path}" - ) + state_dict[ + training.ADAPTER_CONFIG + ] = convert_weights.tune_to_peft_adapter_config( + state_dict[training.ADAPTER_CONFIG] + ) + output_path = Path.joinpath(self._output_dir, "adapter_config.json") + with open(output_path, "w") as f: + json.dump(state_dict[training.ADAPTER_CONFIG], f) + logger.info( + "Adapter checkpoint of size " + f"{os.path.getsize(output_path) / 1000**3:.2f} GB " + f"saved to {output_path}" + ) # If the recipe state needs to be output, first remove the model state dict # and if it exists, remove the adapter state dict as well From 1fe28258ae30b6df586cbb4bbe492bb475f6ecd3 Mon Sep 17 00:00:00 2001 From: Philip Bontrager Date: Thu, 31 Oct 2024 15:14:56 -0700 Subject: [PATCH 2/5] separate peft function --- .../llama3_2_vision/_convert_weights.py | 99 
++++++++++++++++--- .../training/checkpointing/_checkpointer.py | 18 ++-- 2 files changed, 93 insertions(+), 24 deletions(-) diff --git a/torchtune/models/llama3_2_vision/_convert_weights.py b/torchtune/models/llama3_2_vision/_convert_weights.py index 3458951f0d..ffbad74ca0 100644 --- a/torchtune/models/llama3_2_vision/_convert_weights.py +++ b/torchtune/models/llama3_2_vision/_convert_weights.py @@ -345,7 +345,6 @@ def llama3_vision_tune_to_hf( tile_size: int = 448, num_tiles: int = 4, supported_aspect_ratios: List[Tuple[int, int]] = None, - peft_dict: bool = False, ) -> Dict[str, torch.Tensor]: """ Convertor from Tune state dict to HF state dict. This handles: @@ -365,8 +364,6 @@ def llama3_vision_tune_to_hf( "decoder.tok_embeddings.fusion_embedding.weight": None, } inverted_mapping_dict.update(missing_keys) - if peft_dict: - inverted_mapping_dict = _get_peft_dict(inverted_mapping_dict) if head_dim is None: head_dim = dim // num_heads @@ -440,21 +437,95 @@ def _permute(t, n_heads): } -def _get_peft_dict(mapping_dict: Dict[str, str]) -> Dict[str, torch.Tensor]: +def _get_peft_dict(tune_to_hf_dict: Dict[str, str]) -> Dict[str, torch.Tensor]: """ Rather than recreate a separate mapping for LoRA adapter weights, we just re-use the _FROM_HF mapping for base model weights. We iterate over it twice: once to add mappings for LoRA A matrices and once to add mappings for LoRA B matrices. """ new_mapping_dict = {} - for k, v in _TO_PEFT_KEYS.items(): - new_mapping_dict.update( - { - vv.replace(".weight", f".{k}.weight"): kk.replace( - ".weight", f".{v}.weight" - ) - for kk, vv in mapping_dict.items() - if vv is not None - } - ) + for tune_peft, hf_peft in _TO_PEFT_KEYS.items(): + for tune_key, hf_key in tune_to_hf_dict.items(): + if hf_key is None or hf_val is None: + continue + + if peft_key == "magnitude": + # e.g. attn.q_proj.magnitude -> attn.q_proj.lora_magnitude_vector + tune_adapter = tune_key.replace(".weight", f".{tune_peft}") + hf_adapter = hf_key.replace(".weight", f".{hf_peft}") + else: + # e.g. attn.q_proj.lora_a.weight -> attn.q_proj.lora_A.weight + tune_adapter = tune_key.replace(".weight", f".{tune_peft}.weight") + hf_adapter = hf_key.replace(".weight", f".{hf_peft}.weight") + + new_mapping_dict[tune_adapter] = hf_adapter return new_mapping_dict + + +def llama3_vision_tune_to_peft_adapter_weights( + state_dict: Dict[str, torch.Tensor], + num_heads: int = 32, + num_kv_heads: int = 32, + dim: int = 4096, + head_dim: int = None, + cross_attention_layers: Optional[List[int]] = None, +) -> Dict[str, torch.Tensor]: + """ + Convertor from Tune state dict to HF state dict. 
This handles: + - Updateing the cross attention layer numbers + - skip loading the rope embeddings + - reshaping q, k projections + """ + converted_state_dict = {} + inverted_mapping_dict = {v: k for k, v in _FROM_HF.items()} + # missing keys in _FROM_HF due to naming collisions + missing_keys = { + "decoder.layers.{}.fusion_layer.ca_norm.scale": "language_model.model.layers.{}.input_layernorm.weight", + "decoder.layers.{}.fusion_layer.mlp_norm.scale": "language_model.model.layers.{}.post_attention_layernorm.weight", + "decoder.layers.{}.fusion_layer.mlp.w1.weight": "language_model.model.layers.{}.mlp.gate_proj.weight", + "decoder.layers.{}.fusion_layer.mlp.w3.weight": "language_model.model.layers.{}.mlp.up_proj.weight", + "decoder.layers.{}.fusion_layer.mlp.w2.weight": "language_model.model.layers.{}.mlp.down_proj.weight", + "decoder.tok_embeddings.fusion_embedding.weight": None, + } + inverted_mapping_dict.update(missing_keys) + inverted_mapping_dict = _get_peft_dict(inverted_mapping_dict) + + if head_dim is None: + head_dim = dim // num_heads + if cross_attention_layers is None: + cross_attention_layers = [] + # convert hf layer numbers to tune numbers + cross_attention_layers = [ + l - i for i, l in enumerate(sorted(cross_attention_layers)) + ] + + def _permute_lora_matrix(t, n_heads): + rank = t.shape[-1] + return ( + t.view(n_heads, head_dim // 2, 2, rank) + .transpose(1, 2) + .reshape((head_dim * n_heads), rank) + ) + + for key, value in state_dict.items(): + # if key == "decoder.layers.3.layer.attn.q_proj.lora_a.weight": + # import pdb; pdb.set_trace() + new_key = get_mapped_key(key, inverted_mapping_dict) + if "decoder" in key: + if "layers" in key: # Update layer numbers + layer = int(key.split(".")[2]) + num_shifts = sum(layer > l for l in cross_attention_layers) + new_layer = layer + num_shifts + key_lst = new_key.split(".") + if layer in cross_attention_layers and "fusion_layer" not in key: + new_layer += 1 # hf treats the fusion_layer as an additional layer + key_lst[3] = str(new_layer) + new_key = ".".join(key_lst) + if "q_proj" in key and "lora_B" in new_key and "cross_attn" not in new_key: + value = _permute_lora_matrix(value, num_heads) + elif ( + "k_proj" in key and "lora_B" in new_key and "cross_attn" not in new_key + ): + value = _permute_lora_matrix(value, num_kv_heads) + converted_state_dict["base_model.model." 
+ new_key] = value + return converted_state_dict diff --git a/torchtune/training/checkpointing/_checkpointer.py b/torchtune/training/checkpointing/_checkpointer.py index 95168cc833..a239ab1cf1 100644 --- a/torchtune/training/checkpointing/_checkpointer.py +++ b/torchtune/training/checkpointing/_checkpointer.py @@ -637,23 +637,21 @@ def save_checkpoint( ) else: if self._model_type == ModelType.LLAMA3_VISION: - state_dict[training.ADAPTER_KEY] = llama3_vision_tune_to_hf( - state_dict[training.MODEL_KEY], + from torchtune.models.llama3_2_vision._convert_weights import ( + llama3_vision_tune_to_peft_adapter_weights, + ) + + state_dict[ + training.ADAPTER_KEY + ] = llama3_vision_tune_to_peft_adapter_weights( + state_dict[training.ADAPTER_KEY], num_heads=text_config["num_attention_heads"], num_kv_heads=text_config["num_key_value_heads"], dim=text_config["hidden_size"], head_dim=text_config.get("head_dim", None), - vocab_size=text_config["vocab_size"], cross_attention_layers=text_config.get( "cross_attention_layers", None ), - encoder_dim=vision_config["hidden_size"], - tile_size=vision_config["image_size"], - num_tiles=vision_config["max_num_tiles"], - supported_aspect_ratios=vision_config.get( - "supported_aspect_ratios", None - ), - peft_dict=True, ) else: state_dict[ From 6be3747e6bb39afdc26aa36658e064c0f148881a Mon Sep 17 00:00:00 2001 From: Philip Bontrager Date: Tue, 5 Nov 2024 12:24:06 -0800 Subject: [PATCH 3/5] testing broken checkpoints --- .../llama3_2_vision/11B_lora_single_device.yaml | 2 +- .../llama3_2_vision/11B_qlora_single_device.yaml | 2 +- recipes/lora_finetune_single_device.py | 16 +++++++++++++++- test.py | 15 +++++++++++++++ .../models/llama3_2_vision/_convert_weights.py | 4 ++-- .../training/checkpointing/_checkpointer.py | 1 + 6 files changed, 35 insertions(+), 5 deletions(-) create mode 100644 test.py diff --git a/recipes/configs/llama3_2_vision/11B_lora_single_device.yaml b/recipes/configs/llama3_2_vision/11B_lora_single_device.yaml index 12e90d836c..121a50416e 100644 --- a/recipes/configs/llama3_2_vision/11B_lora_single_device.yaml +++ b/recipes/configs/llama3_2_vision/11B_lora_single_device.yaml @@ -45,7 +45,7 @@ checkpointer: output_dir: /tmp/Llama-3.2-11B-Vision-Instruct/ model_type: LLAMA3_VISION resume_from_checkpoint: False -save_adapter_weights_only: False # PeFT formatting not available yet. This will save it in torchtune format only. +save_adapter_weights_only: False # Dataset dataset: diff --git a/recipes/configs/llama3_2_vision/11B_qlora_single_device.yaml b/recipes/configs/llama3_2_vision/11B_qlora_single_device.yaml index 82da762539..5a02ff2406 100644 --- a/recipes/configs/llama3_2_vision/11B_qlora_single_device.yaml +++ b/recipes/configs/llama3_2_vision/11B_qlora_single_device.yaml @@ -45,7 +45,7 @@ checkpointer: output_dir: /tmp/Llama-3.2-11B-Vision-Instruct/ model_type: LLAMA3_VISION resume_from_checkpoint: False -save_adapter_weights_only: False # PeFT formatting not available yet. This will save it in torchtune format only. +save_adapter_weights_only: False # Dataset dataset: diff --git a/recipes/lora_finetune_single_device.py b/recipes/lora_finetune_single_device.py index 979728596b..c28b1ebb37 100644 --- a/recipes/lora_finetune_single_device.py +++ b/recipes/lora_finetune_single_device.py @@ -583,7 +583,21 @@ def save_checkpoint(self, epoch: int) -> None: } ) - adapter_state_dict = {k: v.cpu() for k, v in self.adapter_params.items()} + # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! Testing remove this !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! 
+ adapter_state_dict = {} + for k, v in self._model.named_modules(): + if hasattr(v, "adapter_params") and callable(v.adapter_params): + import pdb + + pdb.set_trace() + adapter_params = v.adapter_params() + for n, p in v.state_dict().items(): + if any(n.endswith(param) for param in adapter_params): + full_key = f"{k}.{n}" + adapter_state_dict[n] = p.cpu() + + # adapter_state_dict = {k: v.cpu() for k, v in self.adapter_params.items()} + # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! End Testing !!!!!!!! !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ckpt_dict.update({training.ADAPTER_KEY: adapter_state_dict}) if not self._save_adapter_weights_only: diff --git a/test.py b/test.py new file mode 100644 index 0000000000..42abd9e479 --- /dev/null +++ b/test.py @@ -0,0 +1,15 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# TODO: (philip) remove after tests + +from transformers import AutoModelForCausalLM # , AutoTokenizer + +model_id = "meta-llama/Llama-3.2-11B-Vision" +peft_model_id = "/tmp/Llama-3.2-11B-Vision-Instruct/" + +model = AutoModelForCausalLM.from_pretrained(model_id) +model.load_adapter(peft_model_id) diff --git a/torchtune/models/llama3_2_vision/_convert_weights.py b/torchtune/models/llama3_2_vision/_convert_weights.py index ffbad74ca0..7afdf5b4ef 100644 --- a/torchtune/models/llama3_2_vision/_convert_weights.py +++ b/torchtune/models/llama3_2_vision/_convert_weights.py @@ -446,10 +446,10 @@ def _get_peft_dict(tune_to_hf_dict: Dict[str, str]) -> Dict[str, torch.Tensor]: new_mapping_dict = {} for tune_peft, hf_peft in _TO_PEFT_KEYS.items(): for tune_key, hf_key in tune_to_hf_dict.items(): - if hf_key is None or hf_val is None: + if hf_key is None or tune_key is None: continue - if peft_key == "magnitude": + if tune_peft == "magnitude": # e.g. attn.q_proj.magnitude -> attn.q_proj.lora_magnitude_vector tune_adapter = tune_key.replace(".weight", f".{tune_peft}") hf_adapter = hf_key.replace(".weight", f".{hf_peft}") diff --git a/torchtune/training/checkpointing/_checkpointer.py b/torchtune/training/checkpointing/_checkpointer.py index a239ab1cf1..91e74649b9 100644 --- a/torchtune/training/checkpointing/_checkpointer.py +++ b/torchtune/training/checkpointing/_checkpointer.py @@ -615,6 +615,7 @@ def save_checkpoint( ) if training.ADAPTER_KEY in state_dict: + # import pdb; pdb.set_trace() # Save torchtune format adapter weights even if we save PEFT format # This way we can resume no matter what (and memory footprint of adapter weights is small) output_path = Path.joinpath( From 861a68a1733f9f6f203b5e898211749c1cac4bcf Mon Sep 17 00:00:00 2001 From: Philip Bontrager Date: Wed, 13 Nov 2024 11:13:23 -0800 Subject: [PATCH 4/5] remove test --- test.py | 15 --------------- 1 file changed, 15 deletions(-) delete mode 100644 test.py diff --git a/test.py b/test.py deleted file mode 100644 index 42abd9e479..0000000000 --- a/test.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- -# TODO: (philip) remove after tests - -from transformers import AutoModelForCausalLM # , AutoTokenizer - -model_id = "meta-llama/Llama-3.2-11B-Vision" -peft_model_id = "/tmp/Llama-3.2-11B-Vision-Instruct/" - -model = AutoModelForCausalLM.from_pretrained(model_id) -model.load_adapter(peft_model_id) From 8d960b5dc5a8fb78c92229b0b1c257cf33d47784 Mon Sep 17 00:00:00 2001 From: Philip Bontrager Date: Wed, 13 Nov 2024 11:16:01 -0800 Subject: [PATCH 5/5] remove comment --- torchtune/models/llama3_2_vision/_convert_weights.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/torchtune/models/llama3_2_vision/_convert_weights.py b/torchtune/models/llama3_2_vision/_convert_weights.py index 7afdf5b4ef..1957d354d9 100644 --- a/torchtune/models/llama3_2_vision/_convert_weights.py +++ b/torchtune/models/llama3_2_vision/_convert_weights.py @@ -508,8 +508,6 @@ def _permute_lora_matrix(t, n_heads): ) for key, value in state_dict.items(): - # if key == "decoder.layers.3.layer.attn.q_proj.lora_a.weight": - # import pdb; pdb.set_trace() new_key = get_mapped_key(key, inverted_mapping_dict) if "decoder" in key: if "layers" in key: # Update layer numbers
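Reviewer note: to sanity-check the exported PEFT adapter end to end, here is a minimal sketch along the lines of the since-removed test.py. The model id and output path come from that script and the configs above and are illustrative; `peft` must be installed for `load_adapter` to work, and the adapter files are the adapter_model.bin / adapter_config.json pair written by the checkpointer changes in this series.

    from transformers import AutoModelForCausalLM

    base_model_id = "meta-llama/Llama-3.2-11B-Vision"          # illustrative, from test.py
    peft_adapter_dir = "/tmp/Llama-3.2-11B-Vision-Instruct/"   # output_dir from the configs above

    # Load the HF base model, then attach the PEFT-format adapter saved by the
    # checkpointer; this exercises the tune -> PEFT key mapping added in this PR.
    model = AutoModelForCausalLM.from_pretrained(base_model_id)
    model.load_adapter(peft_adapter_dir)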