expose rope for clip
RdoubleA committed Nov 8, 2024
1 parent 96d649d commit a36dde5
Showing 2 changed files with 69 additions and 15 deletions.
30 changes: 27 additions & 3 deletions torchtune/models/clip/_component_builders.py
@@ -39,6 +39,8 @@ def clip_vision_encoder(
activation: Callable = nn.SiLU,
cls_output_dim: int = 512,
attn_bias: bool = True,
rope_base: Optional[int] = None,
encoder_max_seq_len: Optional[int] = None,
out_indices: Optional[List[int]] = None,
output_cls_projection: bool = False,
max_num_tiles: int = 4,
@@ -67,6 +69,11 @@ def clip_vision_encoder(
activation (Callable): The activation function to use in the MLP layer.
cls_output_dim (int): The dimensionality of the output tensor from the CLS projection module.
        attn_bias (bool): Whether to use bias in the attention module. Default: True.
        rope_base (Optional[int]): base for the rotary positional embeddings. CLIP does not include RoPE by default;
            if a value is passed in, RoPE is added to the multihead attention module. Default: None
encoder_max_seq_len (Optional[int]): maximum sequence length the encoder will be run with, as used
by :func:`~torchtune.modules.RotaryPositionalEmbeddings`. This is required if ``rope_base``
is specified. Default: None.
out_indices (Optional[List[int]]): The indices of hidden layers to return.
If provided, it will return the intermediate results of the transformer layers
            before they go through the next layer. For example, ``out_indices=[0,3]`` will
@@ -85,25 +92,42 @@ def clip_vision_encoder(
Raises:
        ValueError: If ``embed_dim`` is not divisible by ``num_heads``.
        ValueError: If ``rope_base`` is specified but ``encoder_max_seq_len`` is None.
"""
assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
if embed_dim % num_heads != 0:
raise ValueError(
f"embed_dim must be divisible by num_heads, got {embed_dim} and {num_heads}"
)
if rope_base is not None and encoder_max_seq_len is None:
raise ValueError(
"encoder_max_seq_len must be provided if rope_base is specified. "
"This is used to determine the maximum sequence length for the rotary positional embeddings."
)

head_dim = embed_dim // num_heads

cls_projection = (
CLSProjection(embed_dim=embed_dim, cls_output_dim=cls_output_dim)
if output_cls_projection
else None
)
rope = (
RotaryPositionalEmbeddings(
dim=head_dim, max_seq_len=encoder_max_seq_len, base=rope_base
)
if rope_base is not None
else None
)

# transformer layer
self_attn = MultiHeadAttention(
embed_dim=embed_dim,
num_heads=num_heads,
num_kv_heads=num_heads,
head_dim=embed_dim // num_heads,
head_dim=head_dim,
q_proj=nn.Linear(embed_dim, embed_dim, bias=attn_bias),
k_proj=nn.Linear(embed_dim, embed_dim, bias=attn_bias),
v_proj=nn.Linear(embed_dim, embed_dim, bias=attn_bias),
output_proj=nn.Linear(embed_dim, embed_dim, bias=attn_bias),
pos_embeddings=None,
pos_embeddings=rope,
attn_dropout=0.0,
is_causal=False,
)
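For context, a minimal usage sketch of the new arguments (not part of the diff). It assumes the builder's other arguments (tile_size, patch_size, num_layers) keep the names used elsewhere in the torchtune source at this commit, and the sizes are illustrative only.

# Hypothetical sketch: build a CLIP vision encoder with RoPE enabled.
# rope_base and encoder_max_seq_len are the arguments added in this commit;
# the remaining argument names are assumed from the existing builder signature.
from torchtune.models.clip._component_builders import clip_vision_encoder

encoder = clip_vision_encoder(
    tile_size=224,
    patch_size=14,
    embed_dim=1280,             # must be divisible by num_heads
    num_layers=32,
    num_heads=16,
    rope_base=10_000,           # adds RotaryPositionalEmbeddings to attention
    encoder_max_seq_len=1024,   # required whenever rope_base is set
)

Passing rope_base without encoder_max_seq_len now raises ValueError, as does an embed_dim that is not divisible by num_heads.
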
54 changes: 42 additions & 12 deletions torchtune/models/llama3/_component_builders.py
@@ -12,9 +12,9 @@
from torchtune.models.llama3._model_utils import scale_hidden_dim_for_mlp

from torchtune.modules import (
MultiHeadAttention,
FeedForward,
FrozenNF4Linear,
MultiHeadAttention,
RMSNorm,
RotaryPositionalEmbeddings,
TransformerDecoder,
@@ -40,6 +40,7 @@

# ------------------ Vanilla Llama3 ------------------


def llama3(
vocab_size: int,
num_layers: int,
@@ -48,7 +49,7 @@ def llama3(
embed_dim: int,
max_seq_len: int,
attn_dropout: float = 0.0,
rope_base: int = 500000.0,
rope_base: int = 500_000,
intermediate_dim: Optional[int] = None,
norm_eps: float = 1e-5,
) -> TransformerDecoder:
@@ -72,6 +73,7 @@ def llama3(
by :func:`~torchtune.modules.KVCache`
attn_dropout (float): dropout value passed onto scaled_dot_product_attention.
Default: 0.0
rope_base (int): base for the rotary positional embeddings. Default: 500_000
intermediate_dim (Optional[int]): intermediate dimension for MLP. If not specified,
this is computed using :func:`~torchtune.modules.scale_hidden_dim_for_mlp`
norm_eps (float): epsilon in RMS norms.
@@ -81,7 +83,9 @@
"""
head_dim = embed_dim // num_heads
num_kv_heads = num_kv_heads if num_kv_heads else num_heads
rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len, base=rope_base)
rope = RotaryPositionalEmbeddings(
dim=head_dim, max_seq_len=max_seq_len, base=rope_base
)
self_attn = MultiHeadAttention(
embed_dim=embed_dim,
num_heads=num_heads,
@@ -95,7 +99,9 @@ def llama3(
max_seq_len=max_seq_len,
attn_dropout=attn_dropout,
)
hidden_dim = intermediate_dim if intermediate_dim else scale_hidden_dim_for_mlp(embed_dim)
hidden_dim = (
intermediate_dim if intermediate_dim else scale_hidden_dim_for_mlp(embed_dim)
)
mlp = llama3_mlp(dim=embed_dim, hidden_dim=hidden_dim)
layer = TransformerSelfAttentionLayer(
attn=self_attn,
@@ -116,17 +122,29 @@ def llama3(
output=output_proj,
)


def llama3_mlp(dim: int, hidden_dim: int, quantize_base: bool = False) -> FeedForward:
"""
Build the MLP layer associated with the Llama model.
"""
gate_proj = nn.Linear(dim, hidden_dim, bias=False) if not quantize_base else FrozenNF4Linear(dim, hidden_dim, bias=False)
down_proj = nn.Linear(hidden_dim, dim, bias=False) if not quantize_base else FrozenNF4Linear(hidden_dim, dim, bias=False)
up_proj = nn.Linear(dim, hidden_dim, bias=False) if not quantize_base else FrozenNF4Linear(dim, hidden_dim, bias=False)
gate_proj = (
nn.Linear(dim, hidden_dim, bias=False)
if not quantize_base
else FrozenNF4Linear(dim, hidden_dim, bias=False)
)
down_proj = (
nn.Linear(hidden_dim, dim, bias=False)
if not quantize_base
else FrozenNF4Linear(hidden_dim, dim, bias=False)
)
up_proj = (
nn.Linear(dim, hidden_dim, bias=False)
if not quantize_base
else FrozenNF4Linear(dim, hidden_dim, bias=False)
)
return FeedForward(gate_proj=gate_proj, down_proj=down_proj, up_proj=up_proj)



# ------------------ LoRA Llama3 ------------------


@@ -211,7 +229,9 @@ def lora_llama3(
use_dora=use_dora,
)

hidden_dim = intermediate_dim if intermediate_dim else scale_hidden_dim_for_mlp(embed_dim)
hidden_dim = (
intermediate_dim if intermediate_dim else scale_hidden_dim_for_mlp(embed_dim)
)
if apply_lora_to_mlp:
mlp = lora_llama3_mlp(
dim=embed_dim,
@@ -223,7 +243,9 @@ def lora_llama3(
use_dora=use_dora,
)
else:
mlp = llama3_mlp(dim=embed_dim, hidden_dim=hidden_dim, quantize_base=quantize_base)
mlp = llama3_mlp(
dim=embed_dim, hidden_dim=hidden_dim, quantize_base=quantize_base
)

layer = TransformerSelfAttentionLayer(
attn=self_attn,
@@ -237,7 +259,13 @@ def lora_llama3(
# TODO: quantize_base is not applied to final output_proj currently.
adapter_cls = DoRALinear if use_dora else LoRALinear
output_proj = (
adapter_cls(embed_dim, vocab_size, rank=lora_rank, alpha=lora_alpha, dropout=lora_dropout)
adapter_cls(
embed_dim,
vocab_size,
rank=lora_rank,
alpha=lora_alpha,
dropout=lora_dropout,
)
if apply_lora_to_output
else nn.Linear(embed_dim, vocab_size, bias=False)
)
@@ -382,7 +410,9 @@ def lora_llama3_self_attention(
else FrozenNF4Linear(embed_dim, embed_dim, bias=False)
)
)
rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len, base=rope_base)
rope = RotaryPositionalEmbeddings(
dim=head_dim, max_seq_len=max_seq_len, base=rope_base
)
self_attn = MultiHeadAttention(
embed_dim=embed_dim,
num_heads=num_heads,
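Similarly, a hedged sketch of the Llama3 builder after the reformatting; the sizes below are illustrative, the import path follows the file shown above, and omitting rope_base falls back to the new integer default of 500_000.

# Hypothetical sketch: instantiate a small Llama3 decoder with the builder above.
# All dimensions are illustrative; rope_base is passed explicitly but matches the default.
from torchtune.models.llama3._component_builders import llama3

model = llama3(
    vocab_size=128_256,
    num_layers=4,
    num_heads=8,
    num_kv_heads=2,
    embed_dim=512,       # head_dim = embed_dim // num_heads = 64
    max_seq_len=2048,
    rope_base=500_000,   # same value as the new default
)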
