From bedce7ebb03159f9842ab36776d1fe8b2838d375 Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Thu, 11 Jul 2024 12:03:50 +0100 Subject: [PATCH 1/5] generalizing reward models --- docs/source/api_ref_models.rst | 9 +- tests/torchtune/utils/test_checkpointer.py | 2 +- torchtune/models/llama2/__init__.py | 18 +- .../models/llama2/_component_builders.py | 225 +++++++++++++++++- torchtune/models/llama2/_model_builders.py | 83 +++++++ torchtune/models/mistral/__init__.py | 16 +- .../models/mistral/_component_builders.py | 1 + torchtune/models/mistral/_model_builders.py | 20 +- torchtune/modules/rlhf/__init__.py | 5 + torchtune/modules/rlhf/utils/__init__.py | 12 + .../rlhf/utils}/_convert_weights.py | 18 +- .../utils/_checkpointing/_checkpointer.py | 13 +- .../_checkpointing/_checkpointer_utils.py | 5 +- 13 files changed, 375 insertions(+), 52 deletions(-) create mode 100644 torchtune/modules/rlhf/__init__.py create mode 100644 torchtune/modules/rlhf/utils/__init__.py rename torchtune/{models/mistral => modules/rlhf/utils}/_convert_weights.py (90%) diff --git a/docs/source/api_ref_models.rst b/docs/source/api_ref_models.rst index cb3dd7fb67..089805ec46 100644 --- a/docs/source/api_ref_models.rst +++ b/docs/source/api_ref_models.rst @@ -56,6 +56,9 @@ Pre-trained models can be downloaded from the Hugging Face Hub with the followin llama2.qlora_llama2_70b llama2.llama2_tokenizer llama2.Llama2Tokenizer + llama2.llama2_reward_7b + llama2.lora_llama2_reward_7b + llama2.qlora_llama2_reward_7b code llama @@ -124,9 +127,9 @@ Pre-trained models can be downloaded from the Hugging Face Hub with the followin mistral.mistral_7b mistral.lora_mistral_7b mistral.qlora_mistral_7b - mistral.mistral_classifier_7b - mistral.lora_mistral_classifier_7b - mistral.qlora_mistral_classifier_7b + mistral.mistral_reward_7b + mistral.lora_mistral_reward_7b + mistral.qlora_mistral_reward_7b mistral.mistral_tokenizer mistral.MistralTokenizer diff --git a/tests/torchtune/utils/test_checkpointer.py b/tests/torchtune/utils/test_checkpointer.py index 71161a2460..ca8274005d 100644 --- a/tests/torchtune/utils/test_checkpointer.py +++ b/tests/torchtune/utils/test_checkpointer.py @@ -528,7 +528,7 @@ def single_file_checkpointer( return FullModelHFCheckpointer( checkpoint_dir=tmp_path, checkpoint_files=[checkpoint_file], - model_type="MISTRAL_REWARD", + model_type="REWARD", output_dir=tmp_path, ) diff --git a/torchtune/models/llama2/__init__.py b/torchtune/models/llama2/__init__.py index 83bebaf478..8ab4ee0dea 100644 --- a/torchtune/models/llama2/__init__.py +++ b/torchtune/models/llama2/__init__.py @@ -4,19 +4,27 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-from ._component_builders import llama2, lora_llama2 +from ._component_builders import ( + llama2, + llama2_classifier, + lora_llama2, + lora_llama2_classifier, +) from ._model_builders import ( # noqa llama2_13b, llama2_70b, llama2_7b, + llama2_reward_7b, llama2_tokenizer, lora_llama2_13b, lora_llama2_70b, lora_llama2_7b, + lora_llama2_reward_7b, qlora_llama2_13b, qlora_llama2_70b, qlora_llama2_7b, + qlora_llama2_reward_7b, ) from ._model_utils import scale_hidden_dim_for_mlp from ._tokenizer import Llama2Tokenizer @@ -24,6 +32,14 @@ __all__ = [ "Llama2Tokenizer", "llama2", + "llama2_classifier_7b", + "llama2_classifier", + "lora_llama2_classifier", + "lora_llama2_classifier_7b", + "qlora_llama2_classifier", + "llama2_reward_7b", + "lora_llama2_reward_7b", + "qlora_llama2_reward_7b", "lora_llama2", "llama2_13b", "llama2_70b", diff --git a/torchtune/models/llama2/_component_builders.py b/torchtune/models/llama2/_component_builders.py index f9420a5682..80e7b7220d 100644 --- a/torchtune/models/llama2/_component_builders.py +++ b/torchtune/models/llama2/_component_builders.py @@ -96,9 +96,7 @@ def llama2( max_seq_len=max_seq_len, attn_dropout=attn_dropout, ) - hidden_dim = ( - intermediate_dim if intermediate_dim else scale_hidden_dim_for_mlp(embed_dim) - ) + hidden_dim = intermediate_dim if intermediate_dim else scale_hidden_dim_for_mlp(embed_dim) mlp = llama2_mlp(dim=embed_dim, hidden_dim=hidden_dim) layer = TransformerDecoderLayer( attn=self_attn, @@ -208,9 +206,7 @@ def lora_llama2( quantize_base=quantize_base, ) - hidden_dim = ( - intermediate_dim if intermediate_dim else scale_hidden_dim_for_mlp(embed_dim) - ) + hidden_dim = intermediate_dim if intermediate_dim else scale_hidden_dim_for_mlp(embed_dim) if apply_lora_to_mlp: mlp = lora_llama2_mlp( dim=embed_dim, @@ -312,9 +308,7 @@ def lora_llama2_self_attention( ValueError: If lora_modules arg is an empty list """ if not lora_modules: - raise ValueError( - f"Must pass one or more of {LORA_ATTN_MODULES} as lora_modules" - ) + raise ValueError(f"Must pass one or more of {LORA_ATTN_MODULES} as lora_modules") head_dim = embed_dim // num_heads num_kv_heads = num_kv_heads if num_kv_heads else num_heads @@ -422,3 +416,216 @@ def lora_llama2_mlp( down_proj=down_proj, up_proj=up_proj, ) + + +# ------------------ Llama2 Classifier ------------------ + + +def llama2_classifier( + num_classes: int, + *, + vocab_size: int, + num_layers: int, + num_heads: int, + num_kv_heads: int, + embed_dim: int, + max_seq_len: int, + attn_dropout: float = 0.0, + intermediate_dim: Optional[int] = None, + norm_eps: float = 1e-5, +) -> TransformerDecoder: + """ + Build a base Llama2 model with the final projection replaced with a classification layer. + + Args: + num_classes (int): number of classes for classification. + vocab_size (int): number of tokens in vocabulary. + num_layers (int): number of layers in the transformer decoder. + num_heads (int): number of query heads. For MHA this is also the + number of heads for key and value + num_kv_heads (int): number of key and value heads. If specified, + user should ensure `num_heads` % `num_kv_heads` == 0. Default value is + `None`, in which case this is the same as MHA + embed_dim (int): embedding dimension for self-attention + max_seq_len (int): maximum sequence length the model will be run with, as used + by :func:`~torchtune.modules.KVCache` + attn_dropout (float): dropout value passed onto scaled_dot_product_attention. + Default: 0.0 + intermediate_dim (Optional[int]): intermediate dimension for MLP. 
If not specified, + this is computed using :func:`~torchtune.modules.scale_hidden_dim_for_mlp` + norm_eps (float): epsilon in RMS norms. + + Returns: + TransformerDecoder: Instantiation of Llama2 model. + """ + head_dim = embed_dim // num_heads + num_kv_heads = num_kv_heads if num_kv_heads else num_heads + + rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len) + self_attn = CausalSelfAttention( + embed_dim=embed_dim, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False), + k_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False), + v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False), + output_proj=nn.Linear(embed_dim, embed_dim, bias=False), + pos_embeddings=rope, + kv_cache=None, + max_seq_len=max_seq_len, + attn_dropout=attn_dropout, + ) + hidden_dim = intermediate_dim if intermediate_dim else scale_hidden_dim_for_mlp(embed_dim) + mlp = llama2_mlp(dim=embed_dim, hidden_dim=hidden_dim) + layer = TransformerDecoderLayer( + attn=self_attn, + mlp=mlp, + sa_norm=RMSNorm(dim=embed_dim, eps=norm_eps), + mlp_norm=RMSNorm(dim=embed_dim, eps=norm_eps), + ) + tok_embeddings = nn.Embedding(vocab_size, embed_dim) + output_proj = nn.Linear(embed_dim, num_classes, bias=False) + return TransformerDecoder( + tok_embeddings=tok_embeddings, + layer=layer, + num_layers=num_layers, + max_seq_len=max_seq_len, + num_heads=num_heads, + head_dim=head_dim, + norm=RMSNorm(embed_dim, eps=norm_eps), + output=output_proj, + ) + + +def lora_llama2_classifier( + lora_attn_modules: List[LORA_ATTN_MODULES], + apply_lora_to_mlp: bool = False, + apply_lora_to_output: bool = False, + *, + # llama2 classifier args, + num_classes: int, + # llama2 args + vocab_size: int, + num_layers: int, + num_heads: int, + num_kv_heads: int, + embed_dim: int, + max_seq_len: int, + intermediate_dim: Optional[int] = None, + attn_dropout: float = 0.0, + norm_eps: float = 1e-5, + # LoRA args + lora_rank: int, + lora_alpha: float, + lora_dropout: float = 0.0, + # Quantization args + quantize_base: bool = False, +) -> TransformerDecoder: + """ + Return a version of Llama2 (an instance of :func:`~torchtune.modules.TransformerDecoder`) + with LoRA applied based on the passed in configuration. + + Args: + lora_attn_modules (List[LORA_ATTN_MODULES]): list of which linear layers + LoRA should be applied to in each self-attention block. Options are + ``{"q_proj", "k_proj", "v_proj", "output_proj"}``. + apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer. + Default: False + apply_lora_to_output (bool): whether to apply LoRA to the model's final output projection. + Default: False + num_classes (int): number of classes for classification. + vocab_size (int): number of tokens in vocabulary. + num_layers (int): number of layers in the transformer decoder. + num_heads (int): number of query heads. For MHA this is also the + number of heads for key and value + num_kv_heads (int): number of key and value heads. User should ensure + `num_heads` % `num_kv_heads` == 0. For standard MHA set `num_kv_heads` == `num_heads`, + for GQA `num_kv_heads` < `num_heads`, and for MQA set `num_kv_heads` == 1. + embed_dim (int): embedding dimension for self-attention + max_seq_len (int): maximum sequence length the model will be run with, as used + by :func:`~torchtune.modules.KVCache` + attn_dropout (float): dropout value passed onto scaled_dot_product_attention. 
+ Default: 0.0 + intermediate_dim (Optional[int]): intermediate dimension for MLP. If not specified, + this is computed using :func:`~torchtune.modules.scale_hidden_dim_for_mlp` + norm_eps (float): epsilon in RMS norms. + lora_rank (int): rank of each low-rank approximation + lora_alpha (float): scaling factor for the low-rank approximation + lora_dropout (float): LoRA dropout probability. Default: 0.0 + quantize_base: (bool): Whether to quantize base model weights or not. Only applied to base + weights within linear layers LoRA is applied to. The final output linear projection is not + supported for quantization currently. + + Returns: + TransformerDecoder: Instantiation of Llama2 model with LoRA applied to + a subset of the attention projections in each layer. + + """ + + self_attn = lora_llama2_self_attention( + lora_modules=lora_attn_modules, + embed_dim=embed_dim, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + max_seq_len=max_seq_len, + attn_dropout=attn_dropout, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + quantize_base=quantize_base, + ) + + hidden_dim = intermediate_dim if intermediate_dim else scale_hidden_dim_for_mlp(embed_dim) + if apply_lora_to_mlp: + mlp = lora_llama2_mlp( + dim=embed_dim, + hidden_dim=hidden_dim, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + quantize_base=quantize_base, + lora_dropout=lora_dropout, + ) + else: + mlp = llama2_mlp(dim=embed_dim, hidden_dim=hidden_dim) + + layer = TransformerDecoderLayer( + attn=self_attn, + mlp=mlp, + sa_norm=RMSNorm(dim=embed_dim, eps=norm_eps), + mlp_norm=RMSNorm(dim=embed_dim, eps=norm_eps), + ) + + tok_embeddings = nn.Embedding(vocab_size, embed_dim) + + # TODO: quantize_base is not applied to final output_proj currently. + output_proj = ( + LoRALinear(embed_dim, num_classes, rank=lora_rank, alpha=lora_alpha, dropout=lora_dropout) + if apply_lora_to_output + else nn.Linear(embed_dim, num_classes, bias=False) + ) + model = TransformerDecoder( + tok_embeddings=tok_embeddings, + layer=layer, + num_layers=num_layers, + max_seq_len=max_seq_len, + num_heads=num_heads, + head_dim=(embed_dim // num_heads), + norm=RMSNorm(embed_dim, eps=norm_eps), + output=output_proj, + ) + + if quantize_base: + # For QLoRA, we reparametrize 4-bit tensors to higher precision, and offload to CPU on the fly + # so as to not increase peak memory + model._register_state_dict_hook( + partial( + reparametrize_as_dtype_state_dict_post_hook, + # TODO this is clowny, figure out a better way to get what precision the rest + # of the model is in + dtype=tok_embeddings.weight.dtype, + offload_to_cpu=True, + ) + ) + + return model diff --git a/torchtune/models/llama2/_model_builders.py b/torchtune/models/llama2/_model_builders.py index 8e21ced41a..e061332f70 100644 --- a/torchtune/models/llama2/_model_builders.py +++ b/torchtune/models/llama2/_model_builders.py @@ -39,6 +39,7 @@ def llama2_7b() -> TransformerDecoder: norm_eps=1e-5, ) + def llama2_tokenizer(path: str) -> Llama2Tokenizer: """ Tokenizer for Llama2. @@ -263,9 +264,91 @@ def lora_llama2_70b( quantize_base=quantize_base, ) + qlora_llama2_70b = partial(lora_llama2_70b, quantize_base=True) qlora_llama2_70b.__doc__ = """ Builder for creating a Llama2 70B model with QLoRA enabled. Base model weights in linear layers that LoRA is applied to are quantized per the QLoRA paper: https://arxiv.org/abs/2305.14314. Please see `lora_llama2_70b` for full API arguments. 
""" + + +def llama2_reward_7b() -> TransformerDecoder: + """ + Builder for creating a Llama2 model initialized w/ the default 7B parameter values + from https://arxiv.org/abs/2307.09288, where the output layer is a classification layer + projecting to a single class for reward modelling. + + Returns: + TransformerDecoder: Instantiation of Llama2 7B model + """ + return llama2_classifier_7b( + num_classes=1, + vocab_size=32_000, + num_layers=32, + num_heads=32, + num_kv_heads=32, + embed_dim=4096, + max_seq_len=4096, + attn_dropout=0.0, + norm_eps=1e-5, + ) + + +def lora_llama2_reward_7b( + lora_attn_modules: List[LORA_ATTN_MODULES], + apply_lora_to_mlp: bool = False, + apply_lora_to_output: bool = False, + lora_rank: int = 8, + lora_alpha: float = 16, + lora_dropout: float = 0.05, + quantize_base: bool = False, +) -> TransformerDecoder: + """ + Builder for creating a Llama2 7B reward model with LoRA enabled. + + The Llama2 classifier defaults are the same as in :func:`~torchtune.models.llama2.llama2_reward_7b`, + while LoRA default params are based on + https://github.com/tloen/alpaca-lora/blob/8bb8579e403dc78e37fe81ffbb253c413007323f/finetune.py#L41-L43. + + Args: + lora_attn_modules (List[LORA_ATTN_MODULES]): list of which linear layers + LoRA should be applied to in each self-attention block. Options are + ``{"q_proj", "k_proj", "v_proj", "output_proj"}``. + apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer. + Default: False + apply_lora_to_output (bool): whether to apply LoRA to the model's final output projection. + Default: False + lora_rank (int): rank of each low-rank approximation + lora_alpha (float): scaling factor for the low-rank approximation + quantize_base (bool): Whether to quantize base model weights + + Returns: + TransformerDecoder: Instantiation of Llama2 7B model with LoRA applied + """ + return lora_llama2( + lora_attn_modules=lora_attn_modules, + apply_lora_to_mlp=apply_lora_to_mlp, + apply_lora_to_output=apply_lora_to_output, + num_classes=1, + vocab_size=32_000, + num_layers=32, + num_heads=32, + num_kv_heads=32, + embed_dim=4096, + max_seq_len=4096, + attn_dropout=0.0, + norm_eps=1e-5, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + quantize_base=quantize_base, + ) + + +qlora_llama2_reward_7b = partial(lora_llama2_7b, quantize_base=True) +qlora_llama2_reward_7b.__doc__ = """ +Builder for creating a Llama2 reward 7b model with QLoRA enabled. Base model weights in linear layers +that LoRA is applied to are quantized per the QLoRA paper: https://arxiv.org/abs/2305.14314. +Please see `lora_llama2_reward_7b` for full API arguments. 
+""" diff --git a/torchtune/models/mistral/__init__.py b/torchtune/models/mistral/__init__.py index 31149461a1..d02a5cac80 100644 --- a/torchtune/models/mistral/__init__.py +++ b/torchtune/models/mistral/__init__.py @@ -10,18 +10,14 @@ mistral, mistral_classifier, ) -from ._convert_weights import ( # noqa - mistral_reward_hf_to_tune, - mistral_reward_tune_to_hf, -) from ._model_builders import ( lora_mistral_7b, - lora_mistral_classifier_7b, + lora_mistral_reward_7b, mistral_7b, - mistral_classifier_7b, + mistral_reward_7b, mistral_tokenizer, qlora_mistral_7b, - qlora_mistral_classifier_7b, + qlora_mistral_reward_7b, ) from ._tokenizer import MistralTokenizer @@ -34,10 +30,10 @@ "mistral_reward_hf_to_tune", "mistral_reward_tune_to_hf", "lora_mistral_7b", - "lora_mistral_classifier_7b", + "lora_mistral_reward_7b", "mistral_7b", - "mistral_classifier_7b", + "mistral_reward_7b", "mistral_tokenizer", "qlora_mistral_7b", - "qlora_mistral_classifier_7b", + "qlora_mistral_reward_7b", ] diff --git a/torchtune/models/mistral/_component_builders.py b/torchtune/models/mistral/_component_builders.py index bae85fe00a..ece4852db9 100644 --- a/torchtune/models/mistral/_component_builders.py +++ b/torchtune/models/mistral/_component_builders.py @@ -526,6 +526,7 @@ def lora_mistral_classifier( Default: False apply_lora_to_output (bool): whether to apply LoRA to the model's final output projection. Default: False + num_classes (int): number of classes for the classification layer. vocab_size (int): number of tokens in vocabulary. num_layers (int): number of layers in the transformer decoder. num_heads (int): number of query heads. For MHA this is also the diff --git a/torchtune/models/mistral/_model_builders.py b/torchtune/models/mistral/_model_builders.py index 360d3112fc..44b7f520ce 100644 --- a/torchtune/models/mistral/_model_builders.py +++ b/torchtune/models/mistral/_model_builders.py @@ -45,6 +45,7 @@ def mistral_7b() -> TransformerDecoder: norm_eps=1e-5, ) + def mistral_tokenizer(path: str) -> MistralTokenizer: """ Tokenizer for Mistral models. @@ -57,6 +58,7 @@ def mistral_tokenizer(path: str) -> MistralTokenizer: """ return MistralTokenizer(path) + def lora_mistral_7b( lora_attn_modules: List[LORA_ATTN_MODULES], apply_lora_to_mlp: bool = False, @@ -113,12 +115,12 @@ def lora_mistral_7b( """ -def mistral_classifier_7b() -> TransformerDecoder: +def mistral_reward_7b() -> TransformerDecoder: """ - Builder for creating a Mistral 7B classifier model initialized w/ the default 7b + Builder for creating a Mistral 7B model initialized w/ the default 7b parameter values from: https://huggingface.co/Ray2333/reward-model-Mistral-7B-instruct-Unified-Feedback - + where the output layer is a classification layer projecting to a single class for reward modelling. Returns: TransformerClassifier: Instantiation of Mistral 7B classifier model @@ -137,7 +139,7 @@ def mistral_classifier_7b() -> TransformerDecoder: ) -def lora_mistral_classifier_7b( +def lora_mistral_reward_7b( lora_attn_modules: List[LORA_ATTN_MODULES], apply_lora_to_mlp: bool = False, apply_lora_to_output: bool = False, @@ -146,7 +148,7 @@ def lora_mistral_classifier_7b( quantize_base: bool = False, ) -> TransformerDecoder: """ - Builder for creating a Mistral classifier 7B model with LoRA enabled. + Builder for creating a Mistral reward 7B model with LoRA enabled. 
Args: lora_attn_modules (List[LORA_ATTN_MODULES]): list of which linear layers @@ -185,10 +187,10 @@ def lora_mistral_classifier_7b( ) -qlora_mistral_classifier_7b = partial(lora_mistral_classifier_7b, quantize_base=True) +qlora_mistral_reward_7b = partial(lora_mistral_reward_7b, quantize_base=True) -qlora_mistral_classifier_7b.__doc__ = """ -Builder for creating a Mistral classifier model with QLoRA enabled. Base model weights in linear layers +qlora_mistral_reward_7b.__doc__ = """ +Builder for creating a Mistral reward 7B model with QLoRA enabled. Base model weights in linear layers that LoRA is applied to are quantized per the QLoRA paper: https://arxiv.org/abs/2305.14314. -Please see `lora_mistral_classifier_7b` for full API arguments. +Please see `lora_mistral_reward_7b` for full API arguments. """ diff --git a/torchtune/modules/rlhf/__init__.py b/torchtune/modules/rlhf/__init__.py new file mode 100644 index 0000000000..2e41cd717f --- /dev/null +++ b/torchtune/modules/rlhf/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/torchtune/modules/rlhf/utils/__init__.py b/torchtune/modules/rlhf/utils/__init__.py new file mode 100644 index 0000000000..e8fa398f2f --- /dev/null +++ b/torchtune/modules/rlhf/utils/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from ._convert_weights import reward_hf_to_tune, reward_tune_to_hf # noqa + +__all__ = [ + "reward_hf_to_tune", + "reward_tune_to_hf", +] diff --git a/torchtune/models/mistral/_convert_weights.py b/torchtune/modules/rlhf/utils/_convert_weights.py similarity index 90% rename from torchtune/models/mistral/_convert_weights.py rename to torchtune/modules/rlhf/utils/_convert_weights.py index e9ac3b6c13..37263c917a 100644 --- a/torchtune/models/mistral/_convert_weights.py +++ b/torchtune/modules/rlhf/utils/_convert_weights.py @@ -10,7 +10,7 @@ from torchtune.models.convert_weights import get_mapped_key -_MISTRAL_REWARD = { +_REWARD = { "model.embed_tokens.weight": "tok_embeddings.weight", "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attn.q_proj.weight", "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attn.k_proj.weight", @@ -26,7 +26,7 @@ } -def mistral_reward_hf_to_tune( +def reward_hf_to_tune( state_dict: Dict[str, torch.Tensor], num_heads: int = 32, num_kv_heads: int = 32, @@ -35,7 +35,7 @@ def mistral_reward_hf_to_tune( ) -> Dict[str, torch.Tensor]: """ Convert a state dict from HF's format to torchtune's format, which contains the weights - of a Mistral reward model. + of a reward model (i.e. a classifier with a single class). State dicts from multiple checkpoint files should be consolidated into a single state dict before calling this function. The logic is identical to :func:`~torchtune.models.convert_weights.hf_to_tune`, but with a different mapping. 
@@ -67,7 +67,7 @@ def _permute(t, n_heads): for key, value in state_dict.items(): if "rotary_emb.inv_freq" not in key: # Skip loading the position embeddings - new_key = get_mapped_key(key, _MISTRAL_REWARD) + new_key = get_mapped_key(key, _REWARD) if "q_proj" in key: value = _permute(value, num_heads) elif "k_proj" in key: @@ -76,17 +76,17 @@ def _permute(t, n_heads): return converted_state_dict -def mistral_reward_tune_to_hf( +def reward_tune_to_hf( state_dict: Dict[str, torch.Tensor], num_heads: int = 32, num_kv_heads: int = 32, dim: int = 4096, ) -> Dict[str, torch.Tensor]: """ - Convert a state dict from torchtune's format to Hugging Face's format for a Mistral reward model. + Convert a state dict from torchtune's format to Hugging Face's format for a reward model. - This function takes a state dictionary in torchtune's format, which contains the weights of a Mistral reward model, - and converts it into a format that can be loaded into a Hugging Face model. + This function takes a state dictionary in torchtune's format, which contains the weights of a reward model + (i.e. a classifier with a single class), and converts it into a format that can be loaded into a Hugging Face model. The logic is identical to :func:`~torchtune.models.convert_weights.tune_to_hf`, but with a different mapping. Args: @@ -100,7 +100,7 @@ def mistral_reward_tune_to_hf( """ converted_state_dict = {} - inverted_mapping_dict = {v: k for k, v in _MISTRAL_REWARD.items()} + inverted_mapping_dict = {v: k for k, v in _REWARD.items()} head_dim = dim // num_heads def _permute(t, n_heads): diff --git a/torchtune/utils/_checkpointing/_checkpointer.py b/torchtune/utils/_checkpointing/_checkpointer.py index 107154230b..80c2e272c7 100644 --- a/torchtune/utils/_checkpointing/_checkpointer.py +++ b/torchtune/utils/_checkpointing/_checkpointer.py @@ -17,11 +17,8 @@ from torchtune.models import convert_weights from torchtune.models.gemma import gemma_hf_to_tune, gemma_tune_to_hf -from torchtune.models.mistral import ( - mistral_reward_hf_to_tune, - mistral_reward_tune_to_hf, -) from torchtune.models.phi3 import phi3_hf_to_tune, phi3_tune_to_hf +from torchtune.modules.rlhf.utils import reward_hf_to_tune, reward_tune_to_hf from torchtune.utils._checkpointing._checkpointer_utils import ( get_path, ModelType, @@ -420,8 +417,8 @@ def load_checkpoint(self) -> Dict[str, Any]: "Note that conversion of adapter weights into PEFT format is not supported." 
) converted_state_dict[utils.MODEL_KEY] = phi3_hf_to_tune(merged_state_dict) - elif self._model_type == ModelType.MISTRAL_REWARD: - converted_state_dict[utils.MODEL_KEY] = mistral_reward_hf_to_tune( + elif self._model_type == ModelType.REWARD: + converted_state_dict[utils.MODEL_KEY] = reward_hf_to_tune( merged_state_dict, num_heads=self._config["num_attention_heads"], num_kv_heads=self._config["num_key_value_heads"], @@ -478,8 +475,8 @@ def save_checkpoint( # convert the state_dict back to hf format; do this inplace if self._model_type == ModelType.PHI3_MINI: state_dict[utils.MODEL_KEY] = phi3_tune_to_hf(state_dict[utils.MODEL_KEY]) - elif self._model_type == ModelType.MISTRAL_REWARD: - state_dict[utils.MODEL_KEY] = mistral_reward_tune_to_hf( + elif self._model_type == ModelType.REWARD: + state_dict[utils.MODEL_KEY] = reward_tune_to_hf( state_dict[utils.MODEL_KEY], num_heads=self._config["num_attention_heads"], num_kv_heads=self._config["num_key_value_heads"], diff --git a/torchtune/utils/_checkpointing/_checkpointer_utils.py b/torchtune/utils/_checkpointing/_checkpointer_utils.py index 253e285144..fd7a691445 100644 --- a/torchtune/utils/_checkpointing/_checkpointer_utils.py +++ b/torchtune/utils/_checkpointing/_checkpointer_utils.py @@ -42,8 +42,9 @@ class ModelType(Enum): PHI3_MINI = "phi3_mini" """Phi-3 family of models. See :func:`~torchtune.models.phi3.phi3`""" - MISTRAL_REWARD = "mistral_reward" - """Mistral model with a classification head. See :func:`~torchtune.models.mistral.mistral_classifier`""" + REWARD = "reward" + """A Llama2, Llama3, or Mistral model with a classification head projecting to a single class for reward modelling. + See :func:`~torchtune.models.mistral.mistral_classifier` or :func:`~torchtune.models.llama2.llama2_classifier`""" def get_path(input_dir: Path, filename: str, missing_ok: bool = False) -> Path: From 415d32f5b68968fbfb98af3bed91e023da13d754 Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Thu, 11 Jul 2024 12:10:20 +0100 Subject: [PATCH 2/5] updating docs --- torchtune/utils/_checkpointing/_checkpointer_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtune/utils/_checkpointing/_checkpointer_utils.py b/torchtune/utils/_checkpointing/_checkpointer_utils.py index fd7a691445..d01d552985 100644 --- a/torchtune/utils/_checkpointing/_checkpointer_utils.py +++ b/torchtune/utils/_checkpointing/_checkpointer_utils.py @@ -44,7 +44,7 @@ class ModelType(Enum): REWARD = "reward" """A Llama2, Llama3, or Mistral model with a classification head projecting to a single class for reward modelling. 
- See :func:`~torchtune.models.mistral.mistral_classifier` or :func:`~torchtune.models.llama2.llama2_classifier`""" + See :func:`~torchtune.models.mistral.mistral_reward_7b` or :func:`~torchtune.models.llama2.llama2_reward_7b`""" def get_path(input_dir: Path, filename: str, missing_ok: bool = False) -> Path: From 575b0de362b9aca94fd7329bc0a4fd38017de791 Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Thu, 11 Jul 2024 12:41:59 +0100 Subject: [PATCH 3/5] updating weight conversion logic to remove output bias in HF state dicts, updating test --- tests/torchtune/utils/test_checkpointer.py | 10 ++++++---- torchtune/modules/rlhf/utils/_convert_weights.py | 5 ++++- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/tests/torchtune/utils/test_checkpointer.py b/tests/torchtune/utils/test_checkpointer.py index ca8274005d..e02ec23409 100644 --- a/tests/torchtune/utils/test_checkpointer.py +++ b/tests/torchtune/utils/test_checkpointer.py @@ -480,6 +480,8 @@ def state_dict(self, weight_dtype): ), "model.norm.weight": randn(_DIM, dtype=weight_dtype), "score.weight": randn(1, _DIM, dtype=weight_dtype), + # adding bias to ensure it doesn't cause an unexpected key + "score.bias": randn(1, _DIM, dtype=weight_dtype), } return state_dict @@ -554,12 +556,12 @@ def test_load_save_checkpoint_single_file( # Converted state dict from the checkpointer state_dict = single_file_checkpointer.load_checkpoint() - # Check that we've loaded all the keys - assert len(state_dict["model"].keys()) == len(orig_state_dict.keys()) + # Check that we've loaded all the keys minus the output bias + assert len(state_dict["model"].keys()) == len(orig_state_dict.keys()) - 1 # the keys in original state dict should match up with the keys in the weight_map for key in orig_state_dict.keys(): - if "inv_freq" in key: + if "inv_freq" in key or "output.bias" in key: continue assert key in single_file_checkpointer._weight_map @@ -584,7 +586,7 @@ def test_load_save_checkpoint_single_file( output_file = Path.joinpath(checkpoint_file.parent, "hf_model_0001_1.pt") output_state_dict = safe_torch_load(output_file) - assert len(output_state_dict.keys()) == len(orig_state_dict.keys()) + assert len(output_state_dict.keys()) == len(orig_state_dict.keys()) - 1 class TestHFGemmaFullModelCheckpointer: diff --git a/torchtune/modules/rlhf/utils/_convert_weights.py b/torchtune/modules/rlhf/utils/_convert_weights.py index 37263c917a..eae4a76ddb 100644 --- a/torchtune/modules/rlhf/utils/_convert_weights.py +++ b/torchtune/modules/rlhf/utils/_convert_weights.py @@ -66,7 +66,10 @@ def _permute(t, n_heads): ) for key, value in state_dict.items(): - if "rotary_emb.inv_freq" not in key: # Skip loading the position embeddings + # Skip loading the position embeddings and output bias + # these are not used in the reward model and some HF pipelines (e.g. 
TRL) + # may save them + if "rotary_emb.inv_freq" not in key and "score.bias" not in key: new_key = get_mapped_key(key, _REWARD) if "q_proj" in key: value = _permute(value, num_heads) From e0223aa1ac32f94a7afcbec3a25c715abd6d9a9b Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Thu, 11 Jul 2024 13:07:12 +0100 Subject: [PATCH 4/5] small bug in hf_to_tune --- torchtune/modules/rlhf/utils/_convert_weights.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/torchtune/modules/rlhf/utils/_convert_weights.py b/torchtune/modules/rlhf/utils/_convert_weights.py index eae4a76ddb..2e50138808 100644 --- a/torchtune/modules/rlhf/utils/_convert_weights.py +++ b/torchtune/modules/rlhf/utils/_convert_weights.py @@ -66,10 +66,12 @@ def _permute(t, n_heads): ) for key, value in state_dict.items(): - # Skip loading the position embeddings and output bias - # these are not used in the reward model and some HF pipelines (e.g. TRL) - # may save them - if "rotary_emb.inv_freq" not in key and "score.bias" not in key: + # ignore output layer bias - these are not used in the reward model + # and some HF pipelines (e.g. TRL) may save them + if key == "score.bias": + continue + # Skip loading the position embeddings + if "rotary_emb.inv_freq" not in key: new_key = get_mapped_key(key, _REWARD) if "q_proj" in key: value = _permute(value, num_heads) From e8e97458ade29a5036f8baa9461efab127ba2555 Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Thu, 18 Jul 2024 09:39:16 +0100 Subject: [PATCH 5/5] fixing typos in model builders --- torchtune/models/llama2/__init__.py | 3 --- torchtune/models/llama2/_model_builders.py | 6 +++--- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/torchtune/models/llama2/__init__.py b/torchtune/models/llama2/__init__.py index 8ab4ee0dea..7f31e9e639 100644 --- a/torchtune/models/llama2/__init__.py +++ b/torchtune/models/llama2/__init__.py @@ -32,11 +32,8 @@ __all__ = [ "Llama2Tokenizer", "llama2", - "llama2_classifier_7b", "llama2_classifier", "lora_llama2_classifier", - "lora_llama2_classifier_7b", - "qlora_llama2_classifier", "llama2_reward_7b", "lora_llama2_reward_7b", "qlora_llama2_reward_7b", diff --git a/torchtune/models/llama2/_model_builders.py b/torchtune/models/llama2/_model_builders.py index e061332f70..334051440c 100644 --- a/torchtune/models/llama2/_model_builders.py +++ b/torchtune/models/llama2/_model_builders.py @@ -6,7 +6,7 @@ from typing import List from functools import partial -from torchtune.models.llama2._component_builders import llama2, lora_llama2 +from torchtune.models.llama2._component_builders import llama2, lora_llama2, llama2_classifier, lora_llama2_classifier from torchtune.modules import TransformerDecoder from torchtune.models.llama2._tokenizer import Llama2Tokenizer @@ -282,7 +282,7 @@ def llama2_reward_7b() -> TransformerDecoder: Returns: TransformerDecoder: Instantiation of Llama2 7B model """ - return llama2_classifier_7b( + return llama2_classifier( num_classes=1, vocab_size=32_000, num_layers=32, @@ -326,7 +326,7 @@ def lora_llama2_reward_7b( Returns: TransformerDecoder: Instantiation of Llama2 7B model with LoRA applied """ - return lora_llama2( + return lora_llama2_classifier( lora_attn_modules=lora_attn_modules, apply_lora_to_mlp=apply_lora_to_mlp, apply_lora_to_output=apply_lora_to_output,
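
Taken together, the series replaces the Mistral-specific reward checkpointing path with a generic one: the builders `llama2_reward_7b` / `mistral_reward_7b` produce a classifier with a single output class, and `FullModelHFCheckpointer` converts HF reward checkpoints through `reward_hf_to_tune` / `reward_tune_to_hf` whenever `model_type="REWARD"`. The sketch below is illustrative only and not part of the patches: the checkpoint directory and shard file names are placeholders, and the calls mirror the APIs exercised in the updated checkpointer test above.

from torchtune import utils
from torchtune.models.llama2 import llama2_reward_7b
from torchtune.utils import FullModelHFCheckpointer

# Placeholder paths/filenames; any HF-format reward checkpoint with a
# single-class "score" head (e.g. one produced by TRL) would follow this shape.
checkpointer = FullModelHFCheckpointer(
    checkpoint_dir="/tmp/llama2_reward_hf",
    checkpoint_files=[
        "pytorch_model-00001-of-00002.bin",
        "pytorch_model-00002-of-00002.bin",
    ],
    model_type="REWARD",  # routes conversion through reward_hf_to_tune / reward_tune_to_hf
    output_dir="/tmp/llama2_reward_out",
)

# Consolidates the shards and converts keys to torchtune's naming,
# skipping rotary_emb.inv_freq and any score.bias saved by the HF pipeline.
ckpt = checkpointer.load_checkpoint()

# The reward builder is a Llama2 classifier projecting to a single class.
model = llama2_reward_7b()
model.load_state_dict(ckpt[utils.MODEL_KEY])

# Saving converts back to HF format via reward_tune_to_hf.
checkpointer.save_checkpoint({utils.MODEL_KEY: model.state_dict()}, epoch=0)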