diff --git a/torchtune/models/clip/_component_builders.py b/torchtune/models/clip/_component_builders.py index 2b05a26ed8..35c697f241 100644 --- a/torchtune/models/clip/_component_builders.py +++ b/torchtune/models/clip/_component_builders.py @@ -39,6 +39,8 @@ def clip_vision_encoder( activation: Callable = nn.SiLU, cls_output_dim: int = 512, attn_bias: bool = True, + rope_base: Optional[int] = None, + encoder_max_seq_len: Optional[int] = None, out_indices: Optional[List[int]] = None, output_cls_projection: bool = False, max_num_tiles: int = 4, @@ -67,6 +69,11 @@ def clip_vision_encoder( activation (Callable): The activation function to use in the MLP layer. cls_output_dim (int): The dimensionality of the output tensor from the CLS projection module. attn_bias (bool): Boolean for if to use bias in the attention module. Default True. + rope_base (Optional[int]): base for the rotary positional embeddings. CLIP does not include rope by default, + if a value is passed in then rope will be added to multihead attention. Default: None + encoder_max_seq_len (Optional[int]): maximum sequence length the encoder will be run with, as used + by :func:`~torchtune.modules.RotaryPositionalEmbeddings`. This is required if ``rope_base`` + is specified. Default: None. out_indices (Optional[List[int]]): The indices of hidden layers to return. If provided, it will return the intermediate results of the transformer layers before they go through a next layer. For example, ``out_indices=[0,3]`` will @@ -85,25 +92,42 @@ def clip_vision_encoder( Raises: AssertionError: If ``embed_dim`` is not divisible by ``num_heads``. """ - assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads" + if embed_dim % num_heads != 0: + raise ValueError( + f"embed_dim must be divisible by num_heads, got {embed_dim} and {num_heads}" + ) + if rope_base is not None and encoder_max_seq_len is None: + raise ValueError( + "encoder_max_seq_len must be provided if rope_base is specified. " + "This is used to determine the maximum sequence length for the rotary positional embeddings." + ) + + head_dim = embed_dim // num_heads cls_projection = ( CLSProjection(embed_dim=embed_dim, cls_output_dim=cls_output_dim) if output_cls_projection else None ) + rope = ( + RotaryPositionalEmbeddings( + dim=head_dim, max_seq_len=encoder_max_seq_len, base=rope_base + ) + if rope_base is not None + else None + ) # transformer layer self_attn = MultiHeadAttention( embed_dim=embed_dim, num_heads=num_heads, num_kv_heads=num_heads, - head_dim=embed_dim // num_heads, + head_dim=head_dim, q_proj=nn.Linear(embed_dim, embed_dim, bias=attn_bias), k_proj=nn.Linear(embed_dim, embed_dim, bias=attn_bias), v_proj=nn.Linear(embed_dim, embed_dim, bias=attn_bias), output_proj=nn.Linear(embed_dim, embed_dim, bias=attn_bias), - pos_embeddings=None, + pos_embeddings=rope, attn_dropout=0.0, is_causal=False, ) diff --git a/torchtune/models/llama3/_component_builders.py b/torchtune/models/llama3/_component_builders.py index 49ea3ed764..0a1f72865f 100644 --- a/torchtune/models/llama3/_component_builders.py +++ b/torchtune/models/llama3/_component_builders.py @@ -12,9 +12,9 @@ from torchtune.models.llama3._model_utils import scale_hidden_dim_for_mlp from torchtune.modules import ( - MultiHeadAttention, FeedForward, FrozenNF4Linear, + MultiHeadAttention, RMSNorm, RotaryPositionalEmbeddings, TransformerDecoder, @@ -40,6 +40,7 @@ # ------------------ Vanilla Llama3 ------------------ + def llama3( vocab_size: int, num_layers: int, @@ -48,7 +49,7 @@ def llama3( embed_dim: int, max_seq_len: int, attn_dropout: float = 0.0, - rope_base: int = 500000.0, + rope_base: int = 500_000, intermediate_dim: Optional[int] = None, norm_eps: float = 1e-5, ) -> TransformerDecoder: @@ -72,6 +73,7 @@ def llama3( by :func:`~torchtune.modules.KVCache` attn_dropout (float): dropout value passed onto scaled_dot_product_attention. Default: 0.0 + rope_base (int): base for the rotary positional embeddings. Default: 500_000 intermediate_dim (Optional[int]): intermediate dimension for MLP. If not specified, this is computed using :func:`~torchtune.modules.scale_hidden_dim_for_mlp` norm_eps (float): epsilon in RMS norms. @@ -81,7 +83,9 @@ def llama3( """ head_dim = embed_dim // num_heads num_kv_heads = num_kv_heads if num_kv_heads else num_heads - rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len, base=rope_base) + rope = RotaryPositionalEmbeddings( + dim=head_dim, max_seq_len=max_seq_len, base=rope_base + ) self_attn = MultiHeadAttention( embed_dim=embed_dim, num_heads=num_heads, @@ -95,7 +99,9 @@ def llama3( max_seq_len=max_seq_len, attn_dropout=attn_dropout, ) - hidden_dim = intermediate_dim if intermediate_dim else scale_hidden_dim_for_mlp(embed_dim) + hidden_dim = ( + intermediate_dim if intermediate_dim else scale_hidden_dim_for_mlp(embed_dim) + ) mlp = llama3_mlp(dim=embed_dim, hidden_dim=hidden_dim) layer = TransformerSelfAttentionLayer( attn=self_attn, @@ -116,17 +122,29 @@ def llama3( output=output_proj, ) + def llama3_mlp(dim: int, hidden_dim: int, quantize_base: bool = False) -> FeedForward: """ Build the MLP layer associated with the Llama model. """ - gate_proj = nn.Linear(dim, hidden_dim, bias=False) if not quantize_base else FrozenNF4Linear(dim, hidden_dim, bias=False) - down_proj = nn.Linear(hidden_dim, dim, bias=False) if not quantize_base else FrozenNF4Linear(hidden_dim, dim, bias=False) - up_proj = nn.Linear(dim, hidden_dim, bias=False) if not quantize_base else FrozenNF4Linear(dim, hidden_dim, bias=False) + gate_proj = ( + nn.Linear(dim, hidden_dim, bias=False) + if not quantize_base + else FrozenNF4Linear(dim, hidden_dim, bias=False) + ) + down_proj = ( + nn.Linear(hidden_dim, dim, bias=False) + if not quantize_base + else FrozenNF4Linear(hidden_dim, dim, bias=False) + ) + up_proj = ( + nn.Linear(dim, hidden_dim, bias=False) + if not quantize_base + else FrozenNF4Linear(dim, hidden_dim, bias=False) + ) return FeedForward(gate_proj=gate_proj, down_proj=down_proj, up_proj=up_proj) - # ------------------ LoRA Llama3 ------------------ @@ -211,7 +229,9 @@ def lora_llama3( use_dora=use_dora, ) - hidden_dim = intermediate_dim if intermediate_dim else scale_hidden_dim_for_mlp(embed_dim) + hidden_dim = ( + intermediate_dim if intermediate_dim else scale_hidden_dim_for_mlp(embed_dim) + ) if apply_lora_to_mlp: mlp = lora_llama3_mlp( dim=embed_dim, @@ -223,7 +243,9 @@ def lora_llama3( use_dora=use_dora, ) else: - mlp = llama3_mlp(dim=embed_dim, hidden_dim=hidden_dim, quantize_base=quantize_base) + mlp = llama3_mlp( + dim=embed_dim, hidden_dim=hidden_dim, quantize_base=quantize_base + ) layer = TransformerSelfAttentionLayer( attn=self_attn, @@ -237,7 +259,13 @@ def lora_llama3( # TODO: quantize_base is not applied to final output_proj currently. adapter_cls = DoRALinear if use_dora else LoRALinear output_proj = ( - adapter_cls(embed_dim, vocab_size, rank=lora_rank, alpha=lora_alpha, dropout=lora_dropout) + adapter_cls( + embed_dim, + vocab_size, + rank=lora_rank, + alpha=lora_alpha, + dropout=lora_dropout, + ) if apply_lora_to_output else nn.Linear(embed_dim, vocab_size, bias=False) ) @@ -382,7 +410,9 @@ def lora_llama3_self_attention( else FrozenNF4Linear(embed_dim, embed_dim, bias=False) ) ) - rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len, base=rope_base) + rope = RotaryPositionalEmbeddings( + dim=head_dim, max_seq_len=max_seq_len, base=rope_base + ) self_attn = MultiHeadAttention( embed_dim=embed_dim, num_heads=num_heads,