From bce70917c3d0d1f7693c9ae8b59cd72ee55b659d Mon Sep 17 00:00:00 2001 From: Rafi Ayub <33648637+RdoubleA@users.noreply.github.com> Date: Sun, 17 Nov 2024 07:23:35 -0800 Subject: [PATCH] 2D RoPE + CLIP updates (#1973) --- .../phi3/test_phi3_position_embeddings.py | 60 +++++++ .../modules/test_position_embeddings.py | 130 +++++++++++----- .../modules/test_vision_transformer.py | 24 +++ tests/torchtune/training/test_distributed.py | 6 +- torchtune/models/clip/_component_builders.py | 69 +++++--- .../models/llama3/_component_builders.py | 58 +++++-- torchtune/models/llama3/_model_builders.py | 46 ++++-- .../models/llama3_2_vision/_model_builders.py | 8 +- torchtune/modules/__init__.py | 6 +- torchtune/modules/position_embeddings.py | 147 +++++++++++++++++- torchtune/modules/vision_transformer.py | 18 ++- 11 files changed, 472 insertions(+), 100 deletions(-) create mode 100644 tests/torchtune/models/phi3/test_phi3_position_embeddings.py diff --git a/tests/torchtune/models/phi3/test_phi3_position_embeddings.py b/tests/torchtune/models/phi3/test_phi3_position_embeddings.py new file mode 100644 index 0000000000..487850a4ce --- /dev/null +++ b/tests/torchtune/models/phi3/test_phi3_position_embeddings.py @@ -0,0 +1,60 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import pytest +import torch + +from tests.test_utils import assert_expected, mps_ignored_test +from torch import tensor +from torchtune.models.phi3 import Phi3RotaryPositionalEmbeddings + +from torchtune.training.seed import set_seed + + +@pytest.fixture(autouse=True) +def random(): + set_seed(0) + + +class TestPhi3RotaryPositionalEmbeddings: + """ + Class for testing the Phi3 models RoPE Embeddings. The expected tensors are + computed from the reference implementation here: + https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/modeling_phi3.py + """ + + @pytest.fixture + def input_params(self): + bsz = 4 + num_heads = 32 + embed_dim = 3072 + seq_len = 60 + max_seq_len = 4096 + head_dim = embed_dim // num_heads + return bsz, num_heads, head_dim, seq_len, max_seq_len + + @pytest.fixture + def input(self, input_params) -> tensor: + bsz, num_heads, head_dim, seq_len, _ = input_params + return torch.randn(bsz, seq_len, num_heads, head_dim) + + @pytest.fixture + def rope_phi3(self, input_params) -> Phi3RotaryPositionalEmbeddings: + _, _, head_dim, _, max_seq_len = input_params + return Phi3RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len) + + @mps_ignored_test() + def test_forward( + self, input: tensor, rope_phi3: Phi3RotaryPositionalEmbeddings + ) -> None: + x_out = rope_phi3(input) + + # check the numerics of the computed tensor + assert_expected(x_out.mean(), tensor(-0.0005), atol=1e-4) + assert_expected(x_out.sum(), tensor(-381.0620)) + + # check shapes + assert_expected(x_out.shape, input.shape) diff --git a/tests/torchtune/modules/test_position_embeddings.py b/tests/torchtune/modules/test_position_embeddings.py index 2f0dcb9a4e..282fce085c 100644 --- a/tests/torchtune/modules/test_position_embeddings.py +++ b/tests/torchtune/modules/test_position_embeddings.py @@ -4,16 +4,16 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-from typing import Tuple - import pytest import torch from tests.test_utils import assert_expected, mps_ignored_test from torch import tensor -from torchtune.models.phi3 import Phi3RotaryPositionalEmbeddings -from torchtune.modules.position_embeddings import RotaryPositionalEmbeddings +from torchtune.modules.position_embeddings import ( + RotaryPositionalEmbeddings, + VisionRotaryPositionalEmbeddings, +) from torchtune.training.seed import set_seed @@ -35,7 +35,7 @@ class TestRotaryPositionEmbedding: EXPECTED_X_OUT_MAX = tensor(5.4546) @pytest.fixture - def input_params(self) -> Tuple[int, int, int, int]: + def input_params(self): bsz = 4 num_heads = 32 embed_dim = 4096 @@ -45,14 +45,12 @@ def input_params(self) -> Tuple[int, int, int, int]: return bsz, num_heads, head_dim, seq_len, max_seq_len @pytest.fixture - def input(self, input_params: Tuple[int, int, int, int]) -> tensor: + def input(self, input_params) -> tensor: bsz, num_heads, head_dim, seq_len, _ = input_params return torch.randn(bsz, seq_len, num_heads, head_dim) @pytest.fixture - def rope( - self, input_params: Tuple[int, int, int, int] - ) -> RotaryPositionalEmbeddings: + def rope(self, input_params) -> RotaryPositionalEmbeddings: _, _, head_dim, _, max_seq_len = input_params return RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len) @@ -136,44 +134,106 @@ def test_rope_init_meta_device(self, input_params): torch.testing.assert_close(p1, p2) -class TestPhi3RotaryPositionalEmbeddings: - """ - Class for testing the Phi3 models RoPE Embeddings. The expected tensors are - computed from the reference implementation here: - https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/modeling_phi3.py - """ +class TestVisionRotaryPositionEmbedding: + + EXPECTED_X_OUT_MEAN = tensor(0.0789793) + EXPECTED_X_OUT_SUM = tensor(25.2733822) + EXPECTED_X_OUT_MAX = tensor(3.1225626) @pytest.fixture - def input_params(self) -> Tuple[int, int, int, int]: - bsz = 4 - num_heads = 32 - embed_dim = 3072 - seq_len = 60 - max_seq_len = 4096 + def input_params(self): + bsz = 2 + num_heads = 8 + embed_dim = 32 head_dim = embed_dim // num_heads - return bsz, num_heads, head_dim, seq_len, max_seq_len + seq_len = 5 + patch_size = 4 + tile_size = 16 + return bsz, num_heads, head_dim, seq_len, patch_size, tile_size @pytest.fixture - def input(self, input_params: Tuple[int, int, int, int]) -> tensor: - bsz, num_heads, head_dim, seq_len, _ = input_params + def input(self, input_params) -> tensor: + bsz, num_heads, head_dim, seq_len, *_ = input_params return torch.randn(bsz, seq_len, num_heads, head_dim) @pytest.fixture - def rope_phi3( - self, input_params: Tuple[int, int, int, int] - ) -> Phi3RotaryPositionalEmbeddings: - _, _, head_dim, _, max_seq_len = input_params - return Phi3RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len) + def rope(self, input_params): + _, _, head_dim, _, patch_size, tile_size = input_params + return VisionRotaryPositionalEmbeddings( + patch_size=patch_size, tile_size=tile_size, dim=head_dim // 2 + ) @mps_ignored_test() - def test_forward( - self, input: tensor, rope_phi3: Phi3RotaryPositionalEmbeddings - ) -> None: - x_out = rope_phi3(input) + def test_forward(self, input, rope) -> None: + x_out = rope(input) # check the numerics of the computed tensor - assert_expected(x_out.mean(), tensor(-0.0005), atol=1e-4) - assert_expected(x_out.sum(), tensor(-381.0620)) + assert_expected(x_out.mean(), self.EXPECTED_X_OUT_MEAN) + assert_expected(x_out.sum(), self.EXPECTED_X_OUT_SUM) + 
assert_expected(x_out.max(), self.EXPECTED_X_OUT_MAX) + + # check shapes + assert_expected(x_out.shape, input.shape) + + @mps_ignored_test() + def test_forward_with_curr_pos(self, input, rope) -> None: + ( + _, + seq_len, + _, + _, + ) = input.shape + x_out = rope(input, input_pos=torch.arange(seq_len)) + + # these values should be exactly the same as test_forward + # since in this case input_pos covers the entire input + # sequence. This tests that input_pos works as expected i.e. + # extracts the embeddings for the relevant positions + assert_expected(x_out.mean(), self.EXPECTED_X_OUT_MEAN, atol=1e-4) + assert_expected(x_out.sum(), self.EXPECTED_X_OUT_SUM) + assert_expected(x_out.max(), self.EXPECTED_X_OUT_MAX) # check shapes assert_expected(x_out.shape, input.shape) + + @mps_ignored_test() + def test_forward_with_packed_pos(self, input, rope) -> None: + """ + Use input_pos to indicate positions of each token relative to its sequence + when sample is packed. + """ + ( + bsz, + seq_len, + _, + _, + ) = input.shape + x_out = rope( + input, input_pos=torch.arange(seq_len).unsqueeze(0).expand(bsz, seq_len) + ) + + # these values should be exactly the same as test_forward + # AND test_forward_with_current_pos. In this case input_pos + # covers the entire batch dim and is defined for each sample separately. + # This tests that input_pos works as expected i.e. + # extracts the embeddings for the relevant positions for each sample + assert_expected(x_out.mean(), self.EXPECTED_X_OUT_MEAN, atol=1e-4) + assert_expected(x_out.sum(), self.EXPECTED_X_OUT_SUM) + assert_expected(x_out.max(), self.EXPECTED_X_OUT_MAX) + + # check shapes + assert_expected(x_out.shape, input.shape) + + def test_rope_init_meta_device(self, input_params): + _, _, head_dim, _, patch_size, tile_size = input_params + rope_on_device = VisionRotaryPositionalEmbeddings( + dim=head_dim, patch_size=patch_size, tile_size=tile_size + ) + with torch.device("meta"): + meta_rope = VisionRotaryPositionalEmbeddings( + dim=head_dim, patch_size=patch_size, tile_size=tile_size + ) + + meta_rope.rope_init() + for p1, p2 in zip(rope_on_device.buffers(), meta_rope.buffers()): + torch.testing.assert_close(p1, p2) diff --git a/tests/torchtune/modules/test_vision_transformer.py b/tests/torchtune/modules/test_vision_transformer.py index 23bacad22f..4f4330c5b0 100644 --- a/tests/torchtune/modules/test_vision_transformer.py +++ b/tests/torchtune/modules/test_vision_transformer.py @@ -210,3 +210,27 @@ def test_vision_transformer_single_tile(self, transformer_config): ), f"Expected shape {expected_shape}, but got {output.shape}" assert_expected(output.mean(), torch.tensor(0.5458), atol=1e-3, rtol=1e-3) + + @torch.no_grad() + def test_vision_transformer_append_cls_token(self, transformer_config): + transformer_config = transformer_config.copy() + transformer_config["append_cls_token"] = True + + model_append_cls = clip_vision_encoder(**transformer_config).eval() + fixed_init_model(model_append_cls, min_val=-1, max_val=1) + output, _ = model_append_cls(self.image, self.aspect_ratio) + + # assertion + expected_shape = ( + self.batch_size, + self.n_imgs, + self.num_tiles, + model_append_cls.get_image_tokens_per_tile(), + transformer_config["embed_dim"], + ) + + assert ( + output.shape == expected_shape + ), f"Expected shape {expected_shape}, but got {output.shape}" + + assert_expected(output.mean(), torch.tensor(1.0172), atol=1e-3, rtol=1e-3) diff --git a/tests/torchtune/training/test_distributed.py b/tests/torchtune/training/test_distributed.py index 
334a67487e..638e7799a3 100644 --- a/tests/torchtune/training/test_distributed.py +++ b/tests/torchtune/training/test_distributed.py @@ -360,8 +360,10 @@ def _test_qlora_state_dict(self, enable_activation_checkpointing: bool): ) # init rope since it's not covered in state dict for m in fsdp_model_to_load.modules(): - if isinstance(m, modules.RotaryPositionalEmbeddings): - m.reset_parameters() + if isinstance(m, modules.RotaryPositionalEmbeddings) or isinstance( + m, modules.VisionRotaryPositionalEmbeddings + ): + m.rope_init() for m in fsdp_model_to_load.modules(): if enable_activation_checkpointing: if isinstance(m, CheckpointWrapper): diff --git a/torchtune/models/clip/_component_builders.py b/torchtune/models/clip/_component_builders.py index c1b694342c..4bbe0ab7f7 100644 --- a/torchtune/models/clip/_component_builders.py +++ b/torchtune/models/clip/_component_builders.py @@ -7,7 +7,6 @@ from functools import partial from typing import Callable, List, Optional -import torch from torch import nn from torchtune.models.clip._position_embeddings import ( TiledTokenPositionalEmbedding, @@ -21,6 +20,7 @@ FrozenNF4Linear, MultiHeadAttention, TransformerSelfAttentionLayer, + VisionRotaryPositionalEmbeddings, ) from torchtune.modules.common_utils import reparametrize_as_dtype_state_dict_post_hook @@ -39,11 +39,12 @@ def clip_vision_encoder( activation: Callable = nn.SiLU, cls_output_dim: int = 512, attn_bias: bool = True, + use_rope: bool = False, out_indices: Optional[List[int]] = None, output_cls_projection: bool = False, max_num_tiles: int = 4, in_channels: int = 3, - intermediate_act: torch.nn.Module = torch.nn.SiLU(), + append_cls_token: bool = False, ) -> VisionTransformer: """ Builds the vision encoder associated with the clip model. This includes: @@ -67,6 +68,7 @@ def clip_vision_encoder( activation (Callable): The activation function to use in the MLP layer. cls_output_dim (int): The dimensionality of the output tensor from the CLS projection module. attn_bias (bool): Boolean for if to use bias in the attention module. Default True. + use_rope (bool): If True, include 2D rope in attention in each transformer layer. Default: False out_indices (Optional[List[int]]): The indices of hidden layers to return. If provided, it will return the intermediate results of the transformer layers before they go through a next layer. For example, ``out_indices=[0,3]`` will @@ -76,7 +78,8 @@ def clip_vision_encoder( max_num_tiles (int): The maximum number of tiles that can be processed. This is used to determine the size of the positional embeddings. in_channels (int): The number of image input channels. - intermediate_act (torch.nn.Module): The activation function used in the intermediate layers in the transformer encoder. + append_cls_token (bool): If True, adds CLS token embedding to the end of the sequence in the vision transformer. + Default is False, which adds CLS token to the beginning of the sequence. Returns: A `VisionTransformer` object. @@ -84,25 +87,45 @@ def clip_vision_encoder( Raises: AssertionError: If ``embed_dim`` is not divisible by ``num_heads``. 
""" - assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads" + if embed_dim % num_heads != 0: + raise ValueError( + f"embed_dim must be divisible by num_heads, got {embed_dim} and {num_heads}" + ) + if use_rope and max_num_tiles != 1: + raise ValueError( + f"2D RoPE is only supported for max_num_tiles = 1, got {max_num_tiles}" + ) + + head_dim = embed_dim // num_heads cls_projection = ( CLSProjection(embed_dim=embed_dim, cls_output_dim=cls_output_dim) if output_cls_projection else None ) + rope = ( + VisionRotaryPositionalEmbeddings( + patch_size=patch_size, + tile_size=tile_size, + dim=head_dim // 2, + base=10_000, + append_cls_token=append_cls_token, + ) + if use_rope + else None + ) # transformer layer self_attn = MultiHeadAttention( embed_dim=embed_dim, num_heads=num_heads, num_kv_heads=num_heads, - head_dim=embed_dim // num_heads, + head_dim=head_dim, q_proj=nn.Linear(embed_dim, embed_dim, bias=attn_bias), k_proj=nn.Linear(embed_dim, embed_dim, bias=attn_bias), v_proj=nn.Linear(embed_dim, embed_dim, bias=attn_bias), output_proj=nn.Linear(embed_dim, embed_dim, bias=attn_bias), - pos_embeddings=None, + pos_embeddings=rope, attn_dropout=0.0, is_causal=False, ) @@ -154,6 +177,7 @@ def clip_vision_encoder( patch_size=patch_size, embed_dim=embed_dim, in_channels=in_channels, + append_cls_token=append_cls_token, ) @@ -189,7 +213,6 @@ def clip_mlp( def lora_clip_vision_encoder( lora_modules: List[LORA_ATTN_MODULES], apply_lora_to_mlp: bool = False, - apply_lora_to_output: bool = False, *, # clip encoder parameters tile_size: int, @@ -199,12 +222,11 @@ def lora_clip_vision_encoder( num_heads: int, activation: Callable = nn.SiLU, cls_output_dim: int = 512, - attn_bias: bool = True, + attn_bias: bool = False, out_indices: Optional[List[int]] = None, output_cls_projection: bool = False, max_num_tiles: int = 4, in_channels: int = 3, - intermediate_act: torch.nn.Module = torch.nn.SiLU(), # LoRA parameters lora_rank: int = 8, lora_alpha: float = 16, @@ -222,8 +244,6 @@ def lora_clip_vision_encoder( ``{"q_proj", "k_proj", "v_proj", "output_proj"}``. apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer. Default: False - apply_lora_to_output (bool): whether to apply LoRA to the model's final output projection. - Default: False tile_size (int): The size of your image tiles, if the image was tile-cropped in advance. Otherwise, the size of the input image. In this case, the function will consider your image as a single tile. patch_size (int): The size of each patch. Used to divide the tiles into patches. @@ -234,7 +254,7 @@ def lora_clip_vision_encoder( num_heads (int): The number of attention heads in each transformer layer. activation (Callable): The activation function to use in the MLP layer. cls_output_dim (int): The dimensionality of the output tensor from the CLS projection module. - attn_bias (bool): Boolean for if to use bias in the attention module. Default True. + attn_bias (bool): Boolean for if to use bias in the attention module. Default False. out_indices (Optional[List[int]]): The indices of hidden layers to return. If provided, it will return the intermediate results of the transformer layers before they go through a next layer. For example, ``out_indices=[0,3]`` will @@ -244,7 +264,6 @@ def lora_clip_vision_encoder( max_num_tiles (int): The maximum number of tiles that can be processed. This is used to determine the size of the positional embeddings. in_channels (int): The number of image input channels. 
- intermediate_act (torch.nn.Module): The activation function used in the intermediate layers in the transformer encoder. lora_rank (int): rank of each low-rank approximation lora_alpha (float): scaling factor for the low-rank approximation lora_dropout (float): LoRA dropout probability. Default: 0.0 @@ -279,6 +298,7 @@ def lora_clip_vision_encoder( lora_dropout=lora_dropout, use_dora=use_dora, quantize_base=quantize_base, + attn_bias=attn_bias, **quantization_kwargs, ) if apply_lora_to_mlp: @@ -366,6 +386,7 @@ def lora_clip_attention( num_heads: int, num_kv_heads: int, attn_dropout: float = 0.0, + attn_bias: bool = False, # LoRA args lora_rank: int, lora_alpha: float, @@ -424,10 +445,10 @@ def lora_clip_attention( ) if "q_proj" in lora_modules else ( - nn.Linear(embed_dim, num_heads * head_dim, bias=False) + nn.Linear(embed_dim, num_heads * head_dim, bias=attn_bias) if not quantize_base else FrozenNF4Linear( - embed_dim, num_heads * head_dim, bias=False, **quantization_kwargs + embed_dim, num_heads * head_dim, bias=attn_bias, **quantization_kwargs ) ) ) @@ -443,10 +464,13 @@ def lora_clip_attention( ) if "k_proj" in lora_modules else ( - nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False) + nn.Linear(embed_dim, num_kv_heads * head_dim, bias=attn_bias) if not quantize_base else FrozenNF4Linear( - embed_dim, num_kv_heads * head_dim, bias=False, **quantization_kwargs + embed_dim, + num_kv_heads * head_dim, + bias=attn_bias, + **quantization_kwargs, ) ) ) @@ -462,10 +486,13 @@ def lora_clip_attention( ) if "v_proj" in lora_modules else ( - nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False) + nn.Linear(embed_dim, num_kv_heads * head_dim, bias=attn_bias) if not quantize_base else FrozenNF4Linear( - embed_dim, num_kv_heads * head_dim, bias=False, **quantization_kwargs + embed_dim, + num_kv_heads * head_dim, + bias=attn_bias, + **quantization_kwargs, ) ) ) @@ -481,10 +508,10 @@ def lora_clip_attention( ) if "output_proj" in lora_modules else ( - nn.Linear(embed_dim, embed_dim, bias=False) + nn.Linear(embed_dim, embed_dim, bias=attn_bias) if not quantize_base else FrozenNF4Linear( - embed_dim, embed_dim, bias=False, **quantization_kwargs + embed_dim, embed_dim, bias=attn_bias, **quantization_kwargs ) ) ) diff --git a/torchtune/models/llama3/_component_builders.py b/torchtune/models/llama3/_component_builders.py index 49ea3ed764..ca3c1c34bc 100644 --- a/torchtune/models/llama3/_component_builders.py +++ b/torchtune/models/llama3/_component_builders.py @@ -12,9 +12,9 @@ from torchtune.models.llama3._model_utils import scale_hidden_dim_for_mlp from torchtune.modules import ( - MultiHeadAttention, FeedForward, FrozenNF4Linear, + MultiHeadAttention, RMSNorm, RotaryPositionalEmbeddings, TransformerDecoder, @@ -40,6 +40,7 @@ # ------------------ Vanilla Llama3 ------------------ + def llama3( vocab_size: int, num_layers: int, @@ -48,7 +49,7 @@ def llama3( embed_dim: int, max_seq_len: int, attn_dropout: float = 0.0, - rope_base: int = 500000.0, + rope_base: int = 500_000, intermediate_dim: Optional[int] = None, norm_eps: float = 1e-5, ) -> TransformerDecoder: @@ -72,6 +73,7 @@ def llama3( by :func:`~torchtune.modules.KVCache` attn_dropout (float): dropout value passed onto scaled_dot_product_attention. Default: 0.0 + rope_base (int): base for the rotary positional embeddings. Default: 500_000 intermediate_dim (Optional[int]): intermediate dimension for MLP. If not specified, this is computed using :func:`~torchtune.modules.scale_hidden_dim_for_mlp` norm_eps (float): epsilon in RMS norms. 
@@ -81,7 +83,9 @@ def llama3( """ head_dim = embed_dim // num_heads num_kv_heads = num_kv_heads if num_kv_heads else num_heads - rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len, base=rope_base) + rope = RotaryPositionalEmbeddings( + dim=head_dim, max_seq_len=max_seq_len, base=rope_base + ) self_attn = MultiHeadAttention( embed_dim=embed_dim, num_heads=num_heads, @@ -95,7 +99,9 @@ def llama3( max_seq_len=max_seq_len, attn_dropout=attn_dropout, ) - hidden_dim = intermediate_dim if intermediate_dim else scale_hidden_dim_for_mlp(embed_dim) + hidden_dim = ( + intermediate_dim if intermediate_dim else scale_hidden_dim_for_mlp(embed_dim) + ) mlp = llama3_mlp(dim=embed_dim, hidden_dim=hidden_dim) layer = TransformerSelfAttentionLayer( attn=self_attn, @@ -116,17 +122,29 @@ def llama3( output=output_proj, ) + def llama3_mlp(dim: int, hidden_dim: int, quantize_base: bool = False) -> FeedForward: """ Build the MLP layer associated with the Llama model. """ - gate_proj = nn.Linear(dim, hidden_dim, bias=False) if not quantize_base else FrozenNF4Linear(dim, hidden_dim, bias=False) - down_proj = nn.Linear(hidden_dim, dim, bias=False) if not quantize_base else FrozenNF4Linear(hidden_dim, dim, bias=False) - up_proj = nn.Linear(dim, hidden_dim, bias=False) if not quantize_base else FrozenNF4Linear(dim, hidden_dim, bias=False) + gate_proj = ( + nn.Linear(dim, hidden_dim, bias=False) + if not quantize_base + else FrozenNF4Linear(dim, hidden_dim, bias=False) + ) + down_proj = ( + nn.Linear(hidden_dim, dim, bias=False) + if not quantize_base + else FrozenNF4Linear(hidden_dim, dim, bias=False) + ) + up_proj = ( + nn.Linear(dim, hidden_dim, bias=False) + if not quantize_base + else FrozenNF4Linear(dim, hidden_dim, bias=False) + ) return FeedForward(gate_proj=gate_proj, down_proj=down_proj, up_proj=up_proj) - # ------------------ LoRA Llama3 ------------------ @@ -145,7 +163,7 @@ def lora_llama3( intermediate_dim: Optional[int] = None, attn_dropout: float = 0.0, norm_eps: float = 1e-5, - rope_base: float = 500000.0, + rope_base: int = 500_000, # LoRA args lora_rank: int, lora_alpha: float, @@ -211,7 +229,9 @@ def lora_llama3( use_dora=use_dora, ) - hidden_dim = intermediate_dim if intermediate_dim else scale_hidden_dim_for_mlp(embed_dim) + hidden_dim = ( + intermediate_dim if intermediate_dim else scale_hidden_dim_for_mlp(embed_dim) + ) if apply_lora_to_mlp: mlp = lora_llama3_mlp( dim=embed_dim, @@ -223,7 +243,9 @@ def lora_llama3( use_dora=use_dora, ) else: - mlp = llama3_mlp(dim=embed_dim, hidden_dim=hidden_dim, quantize_base=quantize_base) + mlp = llama3_mlp( + dim=embed_dim, hidden_dim=hidden_dim, quantize_base=quantize_base + ) layer = TransformerSelfAttentionLayer( attn=self_attn, @@ -237,7 +259,13 @@ def lora_llama3( # TODO: quantize_base is not applied to final output_proj currently. 
adapter_cls = DoRALinear if use_dora else LoRALinear output_proj = ( - adapter_cls(embed_dim, vocab_size, rank=lora_rank, alpha=lora_alpha, dropout=lora_dropout) + adapter_cls( + embed_dim, + vocab_size, + rank=lora_rank, + alpha=lora_alpha, + dropout=lora_dropout, + ) if apply_lora_to_output else nn.Linear(embed_dim, vocab_size, bias=False) ) @@ -269,7 +297,7 @@ def lora_llama3_self_attention( num_kv_heads: int, max_seq_len: int, attn_dropout: float = 0.0, - rope_base: float = 500000.0, + rope_base: int = 500_000, # LoRA args lora_rank: int, lora_alpha: float, @@ -382,7 +410,9 @@ def lora_llama3_self_attention( else FrozenNF4Linear(embed_dim, embed_dim, bias=False) ) ) - rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len, base=rope_base) + rope = RotaryPositionalEmbeddings( + dim=head_dim, max_seq_len=max_seq_len, base=rope_base + ) self_attn = MultiHeadAttention( embed_dim=embed_dim, num_heads=num_heads, diff --git a/torchtune/models/llama3/_model_builders.py b/torchtune/models/llama3/_model_builders.py index cf4525824e..0ddca90189 100644 --- a/torchtune/models/llama3/_model_builders.py +++ b/torchtune/models/llama3/_model_builders.py @@ -3,17 +3,17 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import List, Optional from functools import partial +from typing import List, Optional + +from torchtune.data._prompt_templates import _get_prompt_template, _TemplateType from torchtune.models.llama3._component_builders import llama3, lora_llama3 +from torchtune.models.llama3._tokenizer import Llama3Tokenizer from torchtune.modules import TransformerDecoder -from torchtune.models.llama3._tokenizer import Llama3Tokenizer from torchtune.modules.peft import LORA_ATTN_MODULES from torchtune.modules.tokenizers import parse_hf_tokenizer_json -from torchtune.data._prompt_templates import _TemplateType -from torchtune.data._prompt_templates import _get_prompt_template """ @@ -40,7 +40,7 @@ def llama3_8b() -> TransformerDecoder: intermediate_dim=14336, attn_dropout=0.0, norm_eps=1e-5, - rope_base=500000.0, + rope_base=500_000, ) @@ -61,18 +61,23 @@ def llama3_70b() -> TransformerDecoder: intermediate_dim=28672, attn_dropout=0.0, norm_eps=1e-5, - rope_base=500000.0, + rope_base=500_000, ) - -def llama3_tokenizer(path: str, special_tokens_path: Optional[str] = None, max_seq_len: Optional[int] = None, prompt_template: Optional[_TemplateType] = None) -> Llama3Tokenizer: + +def llama3_tokenizer( + path: str, + special_tokens_path: Optional[str] = None, + max_seq_len: Optional[int] = None, + prompt_template: Optional[_TemplateType] = None, +) -> Llama3Tokenizer: """ Tokenizer for Llama3. Args: path (str): path to the tokenizer special_tokens_path (Optional[str]): Path to ``tokenizer.json`` from Hugging Face - model files that contains all registered special tokens, or a local json file + model files that contains all registered special tokens, or a local json file structured similarly. Default is None to use the canonical Llama3 special tokens. max_seq_len (Optional[int]): maximum sequence length for tokenizing a single list of messages, after which the input will be truncated. Default is None. @@ -80,13 +85,24 @@ def llama3_tokenizer(path: str, special_tokens_path: Optional[str] = None, max_s If a string, it is assumed to be the dotpath of a :class:`~torchtune.data.PromptTemplateInterface` class. 
If a dictionary, it is assumed to be a custom prompt template mapping role to the prepend/append tags. - + Returns: Llama3Tokenizer: Instantiation of the Llama3 tokenizer """ - special_tokens = parse_hf_tokenizer_json(special_tokens_path) if special_tokens_path is not None else None - template = _get_prompt_template(prompt_template) if prompt_template is not None else None - return Llama3Tokenizer(path=path, special_tokens=special_tokens, max_seq_len=max_seq_len, prompt_template=template) + special_tokens = ( + parse_hf_tokenizer_json(special_tokens_path) + if special_tokens_path is not None + else None + ) + template = ( + _get_prompt_template(prompt_template) if prompt_template is not None else None + ) + return Llama3Tokenizer( + path=path, + special_tokens=special_tokens, + max_seq_len=max_seq_len, + prompt_template=template, + ) def lora_llama3_8b( @@ -137,7 +153,7 @@ def lora_llama3_8b( intermediate_dim=14336, attn_dropout=0.0, norm_eps=1e-5, - rope_base=500000.0, + rope_base=500_000, lora_rank=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout, @@ -194,7 +210,7 @@ def lora_llama3_70b( intermediate_dim=28672, attn_dropout=0.0, norm_eps=1e-5, - rope_base=500000.0, + rope_base=500_000, lora_rank=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout, diff --git a/torchtune/models/llama3_2_vision/_model_builders.py b/torchtune/models/llama3_2_vision/_model_builders.py index 02e78b53f3..d13ff2dcc4 100644 --- a/torchtune/models/llama3_2_vision/_model_builders.py +++ b/torchtune/models/llama3_2_vision/_model_builders.py @@ -106,7 +106,7 @@ def llama3_2_vision_11b( embed_dim=4096, max_seq_len=131_072, encoder_max_seq_len=128_080, # 20*6404 - rope_base=500000.0, + rope_base=500_000, intermediate_dim=14336, ) return DeepFusionModel( @@ -207,7 +207,7 @@ def lora_llama3_2_vision_11b( embed_dim=4096, max_seq_len=131_072, encoder_max_seq_len=128_080, # 20*6404 - rope_base=500000.0, + rope_base=500_000, intermediate_dim=14336, lora_rank=lora_rank, lora_alpha=lora_alpha, @@ -264,7 +264,7 @@ def llama3_2_vision_90b( embed_dim=8192, max_seq_len=131_072, encoder_max_seq_len=128_080, # 20*6404 - rope_base=500000.0, + rope_base=500_000, intermediate_dim=28672, ) return DeepFusionModel( @@ -365,7 +365,7 @@ def lora_llama3_2_vision_90b( embed_dim=8192, max_seq_len=131_072, encoder_max_seq_len=128_080, # 20*6404 - rope_base=500000.0, + rope_base=500_000, intermediate_dim=28672, lora_rank=lora_rank, lora_alpha=lora_alpha, diff --git a/torchtune/modules/__init__.py b/torchtune/modules/__init__.py index c9547cf25b..8c4bed6e21 100644 --- a/torchtune/modules/__init__.py +++ b/torchtune/modules/__init__.py @@ -16,7 +16,10 @@ from .kv_cache import KVCache # noqa from .layer_norm import Fp32LayerNorm # noqa from .low_precision import FrozenNF4Linear # noqa -from .position_embeddings import RotaryPositionalEmbeddings # noqa +from .position_embeddings import ( # noqa + RotaryPositionalEmbeddings, + VisionRotaryPositionalEmbeddings, +) from .rms_norm import RMSNorm # noqa from .tanh_gate import TanhGate # noqa from .tied_linear import TiedLinear # noqa @@ -35,6 +38,7 @@ "FrozenNF4Linear", "KVCache", "RotaryPositionalEmbeddings", + "VisionRotaryPositionalEmbeddings", "RMSNorm", "TiedLinear", "Fp32LayerNorm", diff --git a/torchtune/modules/position_embeddings.py b/torchtune/modules/position_embeddings.py index a4f2036933..5f07772d82 100644 --- a/torchtune/modules/position_embeddings.py +++ b/torchtune/modules/position_embeddings.py @@ -43,11 +43,6 @@ def __init__( self.max_seq_len = max_seq_len 
self.rope_init() - # TODO: delete this once all our recipes are moved off of FSDP1 since we - # no longer need to explicitly name our param init method reset_parameters - def reset_parameters(self): - self.rope_init() - def rope_init(self): theta = 1.0 / ( self.base @@ -125,3 +120,145 @@ def forward( # tensor has shape [b, s, n_h, h_d] x_out = x_out.flatten(3) return x_out.type_as(x) + + +class VisionRotaryPositionalEmbeddings(nn.Module): + """ + This class implements two-dimensional Rotary Positional Embeddings (RoPE) for images + based on the axial frequency 2D RoPE described in https://arxiv.org/pdf/2403.13298. + + The position embedding is simply applied to the x-axis and y-axis separately. + + Note: This module assumes the CLS token embedding is appended at the end of the sequence. + + Args: + patch_size (int): The size of each patch. Used to divide the tiles into patches. + E.g. for ``patch_size=40``, a tile of shape (400, 400) will have 10x10 grid of patches. + tile_size (int): The size of your image tiles, if the image was tile-cropped in advance. Otherwise, + the size of the full input image. In this case, the function will consider your image as a single tile. + dim (int): Embedding dimension. Unlike :class:`~torchtune.modules.RotaryPositionalEmbeddings`, this is + usually set to the dim of each head in the attention module divided by 2, computed as + ``embed_dim // num_heads // 2``. The divide by 2 accounts for x and y positions. + base (int): The base for the geometric progression used to compute + the rotation angles + append_cls_token (bool): Set to True if CLS token embedding is at the end of the sequence in the vision transformer, + False if is in the beginning of the sequence. RoPE is zeroed out for the CLS token. Default is True. + """ + + def __init__( + self, + patch_size: int, + tile_size: int, + dim: int, + base: int = 10_000, + append_cls_token: bool = True, + ) -> None: + super().__init__() + self.patch_grid_size = tile_size // patch_size + self.dim = dim + self.base = base + self.append_cls_token = append_cls_token + self.rope_init() + + def rope_init(self): + theta = 1.0 / ( + self.base + ** (torch.arange(0, self.dim, 2)[: (self.dim // 2)].float() / self.dim) + ) + self.register_buffer("theta", theta, persistent=False) + self.build_rope_cache() + + def build_rope_cache(self) -> None: + # Create position indices for each patch in the tile + patches_per_tile = self.patch_grid_size**2 + patch_idx = torch.arange( + patches_per_tile, dtype=self.theta.dtype, device=self.theta.device + ) + # Add a placeholder index for CLS token - will not be used in RoPE + if self.append_cls_token: + patch_idx = torch.cat( + [ + patch_idx, + -1 * torch.ones(1, dtype=patch_idx.dtype, device=patch_idx.device), + ] + ) + else: + patch_idx = torch.cat( + [ + -1 * torch.ones(1, dtype=patch_idx.dtype, device=patch_idx.device), + patch_idx, + ] + ) + # Encode x and y positions of each patch in the tile + patch_x_pos = patch_idx % self.patch_grid_size + patch_y_pos = patch_idx // self.patch_grid_size + + # Outer product of theta and position index; output tensor has + # a shape of [patches_per_tile + 1, dim // 2] + x_theta = torch.einsum("i, j -> ij", patch_x_pos + 1, self.theta).float() + y_theta = torch.einsum("i, j -> ij", patch_y_pos + 1, self.theta).float() + + # Shape: [patches_per_tile + 1, dim] + freqs = torch.cat([x_theta, y_theta], dim=-1) + # Zero out CLS token position frequencies + freqs = freqs.masked_fill(patch_idx.unsqueeze(-1) < 0, 0) + + # cache includes both the cos and sin 
components and so the output shape is + # [patches_per_tile + 1, dim, 2] + cache = torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1) + self.register_buffer("cache", cache, persistent=False) + + def forward( + self, x: torch.Tensor, *, input_pos: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """ + Args: + x (torch.Tensor): input tensor with shape + ``[b, s, n_h, h_d]`` + input_pos (Optional[torch.Tensor]): Optional tensor which contains the position ids + of each token. During training, this is used to indicate the positions + of each token relative to its sample when packed, shape [b, s]. + During inference, this indicates the position of the current token. + If none, assume the index of the token is its position id. Default is None. + + Returns: + torch.Tensor: output tensor with shape ``[b, s, n_h, h_d]`` + + Notation used for tensor shapes: + - b: batch size + - s: sequence length + - n_h: num heads + - h_d: head dim + """ + # input tensor has shape [b, s, n_h, h_d] + seq_len = x.size(1) + + # extract the values based on whether input_pos is set or not + rope_cache = ( + self.cache[:seq_len] if input_pos is None else self.cache[input_pos] + ) + + # reshape input; the last dimension is used for computing the output. + # Cast to float to match the reference implementation + # tensor has shape [b, s, n_h, h_d // 2, 2] + xshaped = x.float().reshape(*x.shape[:-1], -1, 2) + + # reshape the cache for broadcasting + # tensor has shape [b, s, 1, h_d // 2, 2] if packed samples, + # otherwise has shape [1, s, 1, h_d // 2, 2] + rope_cache = rope_cache.view(-1, xshaped.size(1), 1, xshaped.size(3), 2) + + # tensor has shape [b, s, n_h, h_d // 2, 2] + x_out = torch.stack( + [ + xshaped[..., 0] * rope_cache[..., 0] + - xshaped[..., 1] * rope_cache[..., 1], + xshaped[..., 1] * rope_cache[..., 0] + + xshaped[..., 0] * rope_cache[..., 1], + ], + -1, + ) + + # tensor has shape [b, s, n_h, h_d] + x_out = x_out.flatten(3) + return x_out.type_as(x) diff --git a/torchtune/modules/vision_transformer.py b/torchtune/modules/vision_transformer.py index ea5c19f97d..d44d0c930f 100644 --- a/torchtune/modules/vision_transformer.py +++ b/torchtune/modules/vision_transformer.py @@ -186,6 +186,8 @@ class VisionTransformer(nn.Module): before they go through a next layer. For example, ``out_indices=[0,3]`` will return the tokens before they go through the first and fourth layers. in_channels (int): The number of image input channels. + append_cls_token (bool): If True, adds CLS token to the end of the sequence. + Default is False, which adds CLS token to the beginning of the sequence. Raises: ValueError: If `tile_size` is not greater than 0. @@ -206,6 +208,7 @@ def __init__( cls_projection: Optional[nn.Module] = None, out_indices: Optional[List[int]] = None, in_channels: int = 3, + append_cls_token: bool = False, ) -> None: super().__init__() @@ -245,7 +248,9 @@ def __init__( self.ln_post = Fp32LayerNorm(embed_dim) self.ln_pre = Fp32LayerNorm(embed_dim) - self.cls_token_embedding = CLSEmbedding(embed_dim) + self.cls_token_embedding = CLSEmbedding( + embed_dim, append_cls_token=append_cls_token + ) def get_image_tokens_per_tile(self): return self.patches_per_tile + 1 # +1 for CLS token @@ -415,20 +420,27 @@ class CLSEmbedding(nn.Module): Args: embed_dim (int): The dimensionality of the input patch embedding. + append_cls_token (bool): If True, adds CLS token to the end of the sequence. + Default is False, which adds CLS token to the beginning of the sequence. 
""" - def __init__(self, embed_dim: int) -> None: + def __init__(self, embed_dim: int, append_cls_token: bool = False) -> None: super().__init__() scale = embed_dim**-0.5 self.weight = nn.Parameter(scale * torch.randn(embed_dim)) + self.append_cls_token = append_cls_token def forward(self, x: torch.Tensor) -> torch.Tensor: # add 1 CLS token to every tile bsz_and_n_imgs, n_tiles, n_tokens, embed_dim = x.shape cls_emb = self.weight.broadcast_to(bsz_and_n_imgs, n_tiles, 1, embed_dim) - return torch.cat([cls_emb, x], dim=2) + return ( + torch.cat([x, cls_emb], dim=2) + if self.append_cls_token + else torch.cat([cls_emb, x], dim=2) + ) class CLSProjection(nn.Module):