Generalizing reward models #1160
@@ -96,9 +96,7 @@ def llama2(
         max_seq_len=max_seq_len,
         attn_dropout=attn_dropout,
     )
-    hidden_dim = (
-        intermediate_dim if intermediate_dim else scale_hidden_dim_for_mlp(embed_dim)
-    )
+    hidden_dim = intermediate_dim if intermediate_dim else scale_hidden_dim_for_mlp(embed_dim)
     mlp = llama2_mlp(dim=embed_dim, hidden_dim=hidden_dim)
     layer = TransformerDecoderLayer(
         attn=self_attn,
@@ -208,9 +206,7 @@ def lora_llama2(
         quantize_base=quantize_base,
     )

-    hidden_dim = (
-        intermediate_dim if intermediate_dim else scale_hidden_dim_for_mlp(embed_dim)
-    )
+    hidden_dim = intermediate_dim if intermediate_dim else scale_hidden_dim_for_mlp(embed_dim)
     if apply_lora_to_mlp:
         mlp = lora_llama2_mlp(
             dim=embed_dim,
@@ -312,9 +308,7 @@ def lora_llama2_self_attention(
         ValueError: If lora_modules arg is an empty list
     """
     if not lora_modules:
-        raise ValueError(
-            f"Must pass one or more of {LORA_ATTN_MODULES} as lora_modules"
-        )
+        raise ValueError(f"Must pass one or more of {LORA_ATTN_MODULES} as lora_modules")

     head_dim = embed_dim // num_heads
     num_kv_heads = num_kv_heads if num_kv_heads else num_heads
@@ -422,3 +416,216 @@ def lora_llama2_mlp(
         down_proj=down_proj,
         up_proj=up_proj,
     )
+
+
+# ------------------ Llama2 Classifier ------------------
+
+
+def llama2_classifier(

[Inline review comment] I don't want to overcomplicate things, but is "classifier" the right word? My knowledge of RL is limited, but can the output of the LLM just be a scalar, e.g. how "polite", "trustworthy", "helpful", or "friendly" the model is? In that case, would "classifier" still be the right name? I have seen it in other places too, so I guess this may be the standard.

+    num_classes: int,
+    *,
+    vocab_size: int,
+    num_layers: int,
+    num_heads: int,
+    num_kv_heads: int,
+    embed_dim: int,
+    max_seq_len: int,
+    attn_dropout: float = 0.0,
+    intermediate_dim: Optional[int] = None,
+    norm_eps: float = 1e-5,
+) -> TransformerDecoder:
+    """
+    Build a base Llama2 model with the final projection replaced with a classification layer.
+
+    Args:
+        num_classes (int): number of classes for classification.
+        vocab_size (int): number of tokens in vocabulary.
+        num_layers (int): number of layers in the transformer decoder.
+        num_heads (int): number of query heads. For MHA this is also the
+            number of heads for key and value.
+        num_kv_heads (int): number of key and value heads. User should ensure
+            `num_heads` % `num_kv_heads` == 0. For standard MHA set `num_kv_heads` == `num_heads`,
+            for GQA `num_kv_heads` < `num_heads`, and for MQA set `num_kv_heads` == 1.
+        embed_dim (int): embedding dimension for self-attention.
+        max_seq_len (int): maximum sequence length the model will be run with, as used
+            by :func:`~torchtune.modules.KVCache`.
+        attn_dropout (float): dropout value passed onto scaled_dot_product_attention.
+            Default: 0.0
+        intermediate_dim (Optional[int]): intermediate dimension for MLP. If not specified,
+            this is computed using :func:`~torchtune.modules.scale_hidden_dim_for_mlp`.
+        norm_eps (float): epsilon in RMS norms.
+
+    Returns:
+        TransformerDecoder: Instantiation of the Llama2 classifier model.
+    """
+    head_dim = embed_dim // num_heads
+    num_kv_heads = num_kv_heads if num_kv_heads else num_heads
+
+    rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len)
+    self_attn = CausalSelfAttention(
+        embed_dim=embed_dim,
+        num_heads=num_heads,
+        num_kv_heads=num_kv_heads,
+        head_dim=head_dim,
+        q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False),
+        k_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
+        v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
+        output_proj=nn.Linear(embed_dim, embed_dim, bias=False),
+        pos_embeddings=rope,
+        kv_cache=None,
+        max_seq_len=max_seq_len,
+        attn_dropout=attn_dropout,
+    )
+    hidden_dim = intermediate_dim if intermediate_dim else scale_hidden_dim_for_mlp(embed_dim)
+    mlp = llama2_mlp(dim=embed_dim, hidden_dim=hidden_dim)
+    layer = TransformerDecoderLayer(
+        attn=self_attn,
+        mlp=mlp,
+        sa_norm=RMSNorm(dim=embed_dim, eps=norm_eps),
+        mlp_norm=RMSNorm(dim=embed_dim, eps=norm_eps),
+    )
+    tok_embeddings = nn.Embedding(vocab_size, embed_dim)
+    output_proj = nn.Linear(embed_dim, num_classes, bias=False)
+    return TransformerDecoder(
+        tok_embeddings=tok_embeddings,
+        layer=layer,
+        num_layers=num_layers,
+        max_seq_len=max_seq_len,
+        num_heads=num_heads,
+        head_dim=head_dim,
+        norm=RMSNorm(embed_dim, eps=norm_eps),
+        output=output_proj,
+    )
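
To make the inline question above concrete: a reward model that scores a sequence with a single scalar is just this builder with num_classes=1, so the final projection maps each position's hidden state to one value instead of a vocabulary distribution. The snippet below is a usage sketch rather than part of the diff; it assumes the builders added in this file are in scope, and the model sizes are illustrative rather than tied to any released checkpoint.

import torch

reward_model = llama2_classifier(
    num_classes=1,  # single scalar score per position -> usable as a reward model
    vocab_size=32_000,
    num_layers=22,
    num_heads=32,
    num_kv_heads=4,
    embed_dim=2048,
    max_seq_len=2048,
)

tokens = torch.randint(0, 32_000, (2, 128))  # [batch_size=2, seq_len=128]
scores = reward_model(tokens)                # [2, 128, 1]: one value per token position
last_token_reward = scores[:, -1, :]         # a common choice: take the score at the final token
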
+
+
+def lora_llama2_classifier(
+    lora_attn_modules: List[LORA_ATTN_MODULES],
+    apply_lora_to_mlp: bool = False,
+    apply_lora_to_output: bool = False,
+    *,
+    # llama2 classifier args
+    num_classes: int,
+    # llama2 args
+    vocab_size: int,
+    num_layers: int,
+    num_heads: int,
+    num_kv_heads: int,
+    embed_dim: int,
+    max_seq_len: int,
+    intermediate_dim: Optional[int] = None,
+    attn_dropout: float = 0.0,
+    norm_eps: float = 1e-5,
+    # LoRA args
+    lora_rank: int,
+    lora_alpha: float,
+    lora_dropout: float = 0.0,
+    # Quantization args
+    quantize_base: bool = False,
+) -> TransformerDecoder:
+    """
+    Return a version of the Llama2 classifier (an instance of :func:`~torchtune.modules.TransformerDecoder`)
+    with LoRA applied based on the passed-in configuration.
+
+    Args:
+        lora_attn_modules (List[LORA_ATTN_MODULES]): list of which linear layers
+            LoRA should be applied to in each self-attention block. Options are
+            ``{"q_proj", "k_proj", "v_proj", "output_proj"}``.
+        apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer.
+            Default: False
+        apply_lora_to_output (bool): whether to apply LoRA to the model's final output projection.
+            Default: False
+        num_classes (int): number of classes for classification.
+        vocab_size (int): number of tokens in vocabulary.
+        num_layers (int): number of layers in the transformer decoder.
+        num_heads (int): number of query heads. For MHA this is also the
+            number of heads for key and value.
+        num_kv_heads (int): number of key and value heads. User should ensure
+            `num_heads` % `num_kv_heads` == 0. For standard MHA set `num_kv_heads` == `num_heads`,
+            for GQA `num_kv_heads` < `num_heads`, and for MQA set `num_kv_heads` == 1.
+        embed_dim (int): embedding dimension for self-attention.
+        max_seq_len (int): maximum sequence length the model will be run with, as used
+            by :func:`~torchtune.modules.KVCache`.
+        attn_dropout (float): dropout value passed onto scaled_dot_product_attention.
+            Default: 0.0
+        intermediate_dim (Optional[int]): intermediate dimension for MLP. If not specified,
+            this is computed using :func:`~torchtune.modules.scale_hidden_dim_for_mlp`.
+        norm_eps (float): epsilon in RMS norms.
+        lora_rank (int): rank of each low-rank approximation.
+        lora_alpha (float): scaling factor for the low-rank approximation.
+        lora_dropout (float): LoRA dropout probability. Default: 0.0
+        quantize_base (bool): Whether to quantize base model weights or not. Only applied to base
+            weights within linear layers LoRA is applied to. The final output linear projection is not
+            currently supported for quantization.
+
+    Returns:
+        TransformerDecoder: Instantiation of the Llama2 classifier model with LoRA applied to
+            a subset of the attention projections in each layer.
+    """
+
+    self_attn = lora_llama2_self_attention(
+        lora_modules=lora_attn_modules,
+        embed_dim=embed_dim,
+        num_heads=num_heads,
+        num_kv_heads=num_kv_heads,
+        max_seq_len=max_seq_len,
+        attn_dropout=attn_dropout,
+        lora_rank=lora_rank,
+        lora_alpha=lora_alpha,
+        lora_dropout=lora_dropout,
+        quantize_base=quantize_base,
+    )
+
+    hidden_dim = intermediate_dim if intermediate_dim else scale_hidden_dim_for_mlp(embed_dim)
+    if apply_lora_to_mlp:
+        mlp = lora_llama2_mlp(
+            dim=embed_dim,
+            hidden_dim=hidden_dim,
+            lora_rank=lora_rank,
+            lora_alpha=lora_alpha,
+            quantize_base=quantize_base,
+            lora_dropout=lora_dropout,
+        )
+    else:
+        mlp = llama2_mlp(dim=embed_dim, hidden_dim=hidden_dim)
+
+    layer = TransformerDecoderLayer(
+        attn=self_attn,
+        mlp=mlp,
+        sa_norm=RMSNorm(dim=embed_dim, eps=norm_eps),
+        mlp_norm=RMSNorm(dim=embed_dim, eps=norm_eps),
+    )
+
+    tok_embeddings = nn.Embedding(vocab_size, embed_dim)
+
+    # TODO: quantize_base is not applied to final output_proj currently.
+    output_proj = (
+        LoRALinear(embed_dim, num_classes, rank=lora_rank, alpha=lora_alpha, dropout=lora_dropout)
+        if apply_lora_to_output
+        else nn.Linear(embed_dim, num_classes, bias=False)
+    )
+    model = TransformerDecoder(
+        tok_embeddings=tok_embeddings,
+        layer=layer,
+        num_layers=num_layers,
+        max_seq_len=max_seq_len,
+        num_heads=num_heads,
+        head_dim=(embed_dim // num_heads),
+        norm=RMSNorm(embed_dim, eps=norm_eps),
+        output=output_proj,
+    )
+
+    if quantize_base:
+        # For QLoRA, we reparametrize 4-bit tensors to higher precision and offload them to CPU
+        # on the fly so as to not increase peak memory.
+        model._register_state_dict_hook(
+            partial(
+                reparametrize_as_dtype_state_dict_post_hook,
+                # TODO: this is clunky; figure out a better way to get what precision the rest
+                # of the model is in.
+                dtype=tok_embeddings.weight.dtype,
+                offload_to_cpu=True,
+            )
+        )
+
+    return model
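
For the LoRA variant, usage mirrors the base builder: adapters are injected into the attention projections named in lora_attn_modules (and optionally the MLP and output head), while everything else matches llama2_classifier. Again, this is a sketch rather than part of the diff, with illustrative sizes and LoRA settings.

lora_reward_model = lora_llama2_classifier(
    lora_attn_modules=["q_proj", "v_proj"],
    apply_lora_to_mlp=False,
    apply_lora_to_output=False,  # keep the new scalar head as a plain nn.Linear
    num_classes=1,
    vocab_size=32_000,
    num_layers=22,
    num_heads=32,
    num_kv_heads=4,
    embed_dim=2048,
    max_seq_len=2048,
    lora_rank=8,
    lora_alpha=16.0,
)
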
[Inline review comment] Seeing this huge list scares me a little bit, because every model size would have 6 variants: normal, classifier, lora, qlora, classifier_lora, classifier_qlora. If we add DoRA, that would add 3 more builders. I think you followed the pattern correctly; maybe this is a question for @kartikayk and @ebsmothers. I guess we don't like hooks in torchtune, but something like replace_with_reward_head(model=llama3) seems convenient, instead of rebuilding llama3 completely.
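
For concreteness, the helper suggested above might look roughly like the sketch below. This is hypothetical and not part of the PR: it assumes the TransformerDecoder produced by an existing builder exposes its final projection as model.output (as implied by the output=output_proj constructor argument in this diff) and that the existing head is a plain nn.Linear.

import torch.nn as nn

from torchtune.modules import TransformerDecoder


def replace_with_reward_head(model: TransformerDecoder, num_classes: int = 1) -> TransformerDecoder:
    """Hypothetical helper: swap the vocab projection for a classification/reward head in place."""
    # Assumes the current head is nn.Linear(embed_dim, vocab_size, bias=False).
    embed_dim = model.output.in_features
    model.output = nn.Linear(embed_dim, num_classes, bias=False)
    return model

Whether a mutate-in-place helper like this fits torchtune's builder-centric design is exactly the trade-off raised in the comment above.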