Generalizing reward models #1160

Merged
9 changes: 6 additions & 3 deletions docs/source/api_ref_models.rst
@@ -101,6 +101,9 @@ To download the Llama2-70B model:
llama2.qlora_llama2_70b
llama2.llama2_tokenizer
llama2.Llama2Tokenizer
llama2.llama2_reward_7b
llama2.lora_llama2_reward_7b
llama2.qlora_llama2_reward_7b


code llama
@@ -179,9 +182,9 @@ To download the Mistral 7B v0.1 model:
mistral.mistral_7b
mistral.lora_mistral_7b
mistral.qlora_mistral_7b
mistral.mistral_classifier_7b
mistral.lora_mistral_classifier_7b
mistral.qlora_mistral_classifier_7b
mistral.mistral_reward_7b
mistral.lora_mistral_reward_7b
mistral.qlora_mistral_reward_7b
mistral.mistral_tokenizer
mistral.MistralTokenizer

12 changes: 7 additions & 5 deletions tests/torchtune/utils/test_checkpointer.py
@@ -506,6 +506,8 @@ def state_dict(self, weight_dtype):
),
"model.norm.weight": randn(_DIM, dtype=weight_dtype),
"score.weight": randn(1, _DIM, dtype=weight_dtype),
# adding bias to ensure it doesn't cause an unexpected key
"score.bias": randn(1, _DIM, dtype=weight_dtype),
}
return state_dict

@@ -554,7 +556,7 @@ def single_file_checkpointer(
return FullModelHFCheckpointer(
checkpoint_dir=tmp_path,
checkpoint_files=[checkpoint_file],
model_type="MISTRAL_REWARD",
model_type="REWARD",
output_dir=tmp_path,
)

@@ -580,12 +582,12 @@ def test_load_save_checkpoint_single_file(

# Converted state dict from the checkpointer
state_dict = single_file_checkpointer.load_checkpoint()
# Check that we've loaded all the keys
assert len(state_dict["model"].keys()) == len(orig_state_dict.keys())
# Check that we've loaded all the keys minus the output bias
assert len(state_dict["model"].keys()) == len(orig_state_dict.keys()) - 1

# the keys in original state dict should match up with the keys in the weight_map
for key in orig_state_dict.keys():
if "inv_freq" in key:
if "inv_freq" in key or "output.bias" in key:
continue
assert key in single_file_checkpointer._weight_map

@@ -610,7 +612,7 @@ def test_load_save_checkpoint_single_file(
output_file = Path.joinpath(checkpoint_file.parent, "hf_model_0001_1.pt")
output_state_dict = safe_torch_load(output_file)

assert len(output_state_dict.keys()) == len(orig_state_dict.keys())
assert len(output_state_dict.keys()) == len(orig_state_dict.keys()) - 1


class TestHFGemmaFullModelCheckpointer:
15 changes: 14 additions & 1 deletion torchtune/models/llama2/__init__.py
@@ -4,25 +4,38 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from ._component_builders import llama2, lora_llama2
from ._component_builders import (
llama2,
llama2_classifier,
lora_llama2,
lora_llama2_classifier,
)

from ._model_builders import ( # noqa
llama2_13b,
llama2_70b,
llama2_7b,
llama2_reward_7b,
llama2_tokenizer,
lora_llama2_13b,
lora_llama2_70b,
lora_llama2_7b,
lora_llama2_reward_7b,
qlora_llama2_13b,
qlora_llama2_70b,
qlora_llama2_7b,
qlora_llama2_reward_7b,
)
from ._tokenizer import Llama2Tokenizer

__all__ = [
Contributor commented:
Seeing this huge list scares me a little bit, because every model size would have 6 variants: normal, classifier, lora, qlora, classifier_lora, classifier_qlora. If we add DoRA, that would mean 3 more builders. I think you followed the pattern correctly; maybe a question for @kartikayk and @ebsmothers. I guess we don't like hooks in torchtune, but something like replace_with_reward_head(model=llama3) seems convenient instead of rebuilding llama3 completely (a rough sketch of what such a helper could look like follows this file's diff).

"Llama2Tokenizer",
"llama2",
"llama2_classifier",
"lora_llama2_classifier",
"llama2_reward_7b",
"lora_llama2_reward_7b",
"qlora_llama2_reward_7b",
"lora_llama2",
"llama2_13b",
"llama2_70b",
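To make the suggestion in the comment above concrete, here is a rough, hypothetical sketch of what a replace_with_reward_head-style helper could look like. It is not part of this PR; it assumes the torchtune TransformerDecoder exposes its final projection as model.output (an nn.Linear) and that a base builder such as llama2_7b is used as the starting point.

import torch.nn as nn

from torchtune.models.llama2 import llama2_7b


def replace_with_reward_head(model: nn.Module, num_classes: int = 1) -> nn.Module:
    # Hypothetical helper (not in torchtune): swap the vocab-sized output projection
    # for a small classification/reward head, reusing the existing embedding width
    # instead of rebuilding the model through a dedicated builder.
    embed_dim = model.output.in_features
    model.output = nn.Linear(embed_dim, num_classes, bias=False)
    return model


# e.g. turn a base Llama2-7B into a single-score reward model
reward_model = replace_with_reward_head(llama2_7b(), num_classes=1)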
225 changes: 216 additions & 9 deletions torchtune/models/llama2/_component_builders.py
@@ -96,9 +96,7 @@ def llama2(
max_seq_len=max_seq_len,
attn_dropout=attn_dropout,
)
hidden_dim = (
intermediate_dim if intermediate_dim else scale_hidden_dim_for_mlp(embed_dim)
)
hidden_dim = intermediate_dim if intermediate_dim else scale_hidden_dim_for_mlp(embed_dim)
mlp = llama2_mlp(dim=embed_dim, hidden_dim=hidden_dim)
layer = TransformerDecoderLayer(
attn=self_attn,
@@ -208,9 +206,7 @@ def lora_llama2(
quantize_base=quantize_base,
)

hidden_dim = (
intermediate_dim if intermediate_dim else scale_hidden_dim_for_mlp(embed_dim)
)
hidden_dim = intermediate_dim if intermediate_dim else scale_hidden_dim_for_mlp(embed_dim)
if apply_lora_to_mlp:
mlp = lora_llama2_mlp(
dim=embed_dim,
@@ -312,9 +308,7 @@ def lora_llama2_self_attention(
ValueError: If lora_modules arg is an empty list
"""
if not lora_modules:
raise ValueError(
f"Must pass one or more of {LORA_ATTN_MODULES} as lora_modules"
)
raise ValueError(f"Must pass one or more of {LORA_ATTN_MODULES} as lora_modules")

head_dim = embed_dim // num_heads
num_kv_heads = num_kv_heads if num_kv_heads else num_heads
@@ -422,3 +416,216 @@ def lora_llama2_mlp(
down_proj=down_proj,
up_proj=up_proj,
)


# ------------------ Llama2 Classifier ------------------


def llama2_classifier(
Contributor commented:
I don't want to overcomplicate things, but is "classifier" the right word?

My knowledge of RL is limited, but can't the output of the LLM just be a scalar, e.g. scoring how "polite", "trustworthy", "helpful", or "friendly" the model is?

In that case, would "classifier" still be the right name? I have seen it in other places too, so I guess this may be the standard. (A short usage sketch for the scalar case, with num_classes=1, follows this builder.)

num_classes: int,
*,
vocab_size: int,
num_layers: int,
num_heads: int,
num_kv_heads: int,
embed_dim: int,
max_seq_len: int,
attn_dropout: float = 0.0,
intermediate_dim: Optional[int] = None,
norm_eps: float = 1e-5,
) -> TransformerDecoder:
"""
Build a base Llama2 model with the final projection replaced with a classification layer.

Args:
num_classes (int): number of classes for classification.
vocab_size (int): number of tokens in vocabulary.
num_layers (int): number of layers in the transformer decoder.
num_heads (int): number of query heads. For MHA this is also the
number of heads for key and value
num_kv_heads (int): number of key and value heads. If specified,
user should ensure `num_heads` % `num_kv_heads` == 0. Default value is
`None`, in which case this is the same as MHA
embed_dim (int): embedding dimension for self-attention
max_seq_len (int): maximum sequence length the model will be run with, as used
by :func:`~torchtune.modules.KVCache`
attn_dropout (float): dropout value passed onto scaled_dot_product_attention.
Default: 0.0
intermediate_dim (Optional[int]): intermediate dimension for MLP. If not specified,
this is computed using :func:`~torchtune.modules.scale_hidden_dim_for_mlp`
norm_eps (float): epsilon in RMS norms.

Returns:
TransformerDecoder: Instantiation of Llama2 model.
"""
head_dim = embed_dim // num_heads
num_kv_heads = num_kv_heads if num_kv_heads else num_heads

rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len)
self_attn = CausalSelfAttention(
embed_dim=embed_dim,
num_heads=num_heads,
num_kv_heads=num_kv_heads,
head_dim=head_dim,
q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False),
k_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
output_proj=nn.Linear(embed_dim, embed_dim, bias=False),
pos_embeddings=rope,
kv_cache=None,
max_seq_len=max_seq_len,
attn_dropout=attn_dropout,
)
hidden_dim = intermediate_dim if intermediate_dim else scale_hidden_dim_for_mlp(embed_dim)
mlp = llama2_mlp(dim=embed_dim, hidden_dim=hidden_dim)
layer = TransformerDecoderLayer(
attn=self_attn,
mlp=mlp,
sa_norm=RMSNorm(dim=embed_dim, eps=norm_eps),
mlp_norm=RMSNorm(dim=embed_dim, eps=norm_eps),
)
tok_embeddings = nn.Embedding(vocab_size, embed_dim)
output_proj = nn.Linear(embed_dim, num_classes, bias=False)
return TransformerDecoder(
tok_embeddings=tok_embeddings,
layer=layer,
num_layers=num_layers,
max_seq_len=max_seq_len,
num_heads=num_heads,
head_dim=head_dim,
norm=RMSNorm(embed_dim, eps=norm_eps),
output=output_proj,
)
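Following up on the naming question above: with num_classes=1 the "classifier" head produces a single scalar per position, which is exactly the reward-model case. The snippet below is an illustrative sketch only; the configuration values are toy numbers rather than the real Llama2-7B settings (see llama2_reward_7b in _model_builders.py for those), and taking the last token's score is just one common way to pool a sequence-level reward.

import torch

from torchtune.models.llama2 import llama2_classifier

model = llama2_classifier(
    num_classes=1,      # a single "class" -> a scalar score per token
    vocab_size=32_000,
    num_layers=4,
    num_heads=8,
    num_kv_heads=8,
    embed_dim=512,
    max_seq_len=2048,
)

tokens = torch.randint(0, 32_000, (2, 16))  # [batch_size, seq_len]
scores = model(tokens)                      # [batch_size, seq_len, 1]
reward = scores[:, -1, :]                   # one scalar per sequence, read from the last position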


def lora_llama2_classifier(
lora_attn_modules: List[LORA_ATTN_MODULES],
apply_lora_to_mlp: bool = False,
apply_lora_to_output: bool = False,
*,
# llama2 classifier args,
num_classes: int,
# llama2 args
vocab_size: int,
num_layers: int,
num_heads: int,
num_kv_heads: int,
embed_dim: int,
max_seq_len: int,
intermediate_dim: Optional[int] = None,
attn_dropout: float = 0.0,
norm_eps: float = 1e-5,
# LoRA args
lora_rank: int,
lora_alpha: float,
lora_dropout: float = 0.0,
# Quantization args
quantize_base: bool = False,
) -> TransformerDecoder:
"""
Return a version of Llama2 (an instance of :func:`~torchtune.modules.TransformerDecoder`)
with LoRA applied based on the passed in configuration.

Args:
lora_attn_modules (List[LORA_ATTN_MODULES]): list of which linear layers
LoRA should be applied to in each self-attention block. Options are
``{"q_proj", "k_proj", "v_proj", "output_proj"}``.
apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer.
Default: False
apply_lora_to_output (bool): whether to apply LoRA to the model's final output projection.
Default: False
num_classes (int): number of classes for classification.
vocab_size (int): number of tokens in vocabulary.
num_layers (int): number of layers in the transformer decoder.
num_heads (int): number of query heads. For MHA this is also the
number of heads for key and value
num_kv_heads (int): number of key and value heads. User should ensure
`num_heads` % `num_kv_heads` == 0. For standard MHA set `num_kv_heads` == `num_heads`,
for GQA `num_kv_heads` < `num_heads`, and for MQA set `num_kv_heads` == 1.
embed_dim (int): embedding dimension for self-attention
max_seq_len (int): maximum sequence length the model will be run with, as used
by :func:`~torchtune.modules.KVCache`
attn_dropout (float): dropout value passed onto scaled_dot_product_attention.
Default: 0.0
intermediate_dim (Optional[int]): intermediate dimension for MLP. If not specified,
this is computed using :func:`~torchtune.modules.scale_hidden_dim_for_mlp`
norm_eps (float): epsilon in RMS norms.
lora_rank (int): rank of each low-rank approximation
lora_alpha (float): scaling factor for the low-rank approximation
lora_dropout (float): LoRA dropout probability. Default: 0.0
quantize_base (bool): Whether to quantize base model weights or not. Only applied to base
weights within linear layers LoRA is applied to. The final output linear projection is not
supported for quantization currently.

Returns:
TransformerDecoder: Instantiation of Llama2 model with LoRA applied to
a subset of the attention projections in each layer.

"""

self_attn = lora_llama2_self_attention(
lora_modules=lora_attn_modules,
embed_dim=embed_dim,
num_heads=num_heads,
num_kv_heads=num_kv_heads,
max_seq_len=max_seq_len,
attn_dropout=attn_dropout,
lora_rank=lora_rank,
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
quantize_base=quantize_base,
)

hidden_dim = intermediate_dim if intermediate_dim else scale_hidden_dim_for_mlp(embed_dim)
if apply_lora_to_mlp:
mlp = lora_llama2_mlp(
dim=embed_dim,
hidden_dim=hidden_dim,
lora_rank=lora_rank,
lora_alpha=lora_alpha,
quantize_base=quantize_base,
lora_dropout=lora_dropout,
)
else:
mlp = llama2_mlp(dim=embed_dim, hidden_dim=hidden_dim)

layer = TransformerDecoderLayer(
attn=self_attn,
mlp=mlp,
sa_norm=RMSNorm(dim=embed_dim, eps=norm_eps),
mlp_norm=RMSNorm(dim=embed_dim, eps=norm_eps),
)

tok_embeddings = nn.Embedding(vocab_size, embed_dim)

# TODO: quantize_base is not applied to final output_proj currently.
output_proj = (
LoRALinear(embed_dim, num_classes, rank=lora_rank, alpha=lora_alpha, dropout=lora_dropout)
if apply_lora_to_output
else nn.Linear(embed_dim, num_classes, bias=False)
)
model = TransformerDecoder(
tok_embeddings=tok_embeddings,
layer=layer,
num_layers=num_layers,
max_seq_len=max_seq_len,
num_heads=num_heads,
head_dim=(embed_dim // num_heads),
norm=RMSNorm(embed_dim, eps=norm_eps),
output=output_proj,
)

if quantize_base:
# For QLoRA, we reparametrize 4-bit tensors to higher precision, and offload to CPU on the fly
# so as to not increase peak memory
model._register_state_dict_hook(
partial(
reparametrize_as_dtype_state_dict_post_hook,
# TODO this is clowny, figure out a better way to get what precision the rest
# of the model is in
dtype=tok_embeddings.weight.dtype,
offload_to_cpu=True,
)
)

return model
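For completeness, a minimal usage sketch for the LoRA variant defined above, assuming toy hyperparameters and torchtune's PEFT utilities (get_adapter_params / set_trainable_params); it is illustrative only, not a released configuration.

from torchtune.models.llama2 import lora_llama2_classifier
from torchtune.modules.peft import get_adapter_params, set_trainable_params  # assumed import path

# Toy-sized LoRA classifier/reward model: LoRA on the q/v projections and on the 1-class output head.
model = lora_llama2_classifier(
    lora_attn_modules=["q_proj", "v_proj"],
    apply_lora_to_output=True,
    num_classes=1,
    vocab_size=32_000,
    num_layers=4,
    num_heads=8,
    num_kv_heads=8,
    embed_dim=512,
    max_seq_len=2048,
    lora_rank=8,
    lora_alpha=16.0,
)

# Freeze the base weights and train only the LoRA adapter parameters.
set_trainable_params(model, get_adapter_params(model))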