2D RoPE + CLIP updates (#1973)
RdoubleA authored Nov 17, 2024
1 parent 0c31907 commit bce7091
Showing 11 changed files with 472 additions and 100 deletions.
60 changes: 60 additions & 0 deletions tests/torchtune/models/phi3/test_phi3_position_embeddings.py
@@ -0,0 +1,60 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import pytest
import torch

from tests.test_utils import assert_expected, mps_ignored_test
from torch import tensor
from torchtune.models.phi3 import Phi3RotaryPositionalEmbeddings

from torchtune.training.seed import set_seed


@pytest.fixture(autouse=True)
def random():
set_seed(0)


class TestPhi3RotaryPositionalEmbeddings:
"""
Class for testing the Phi3 model's RoPE embeddings. The expected tensors are
computed from the reference implementation here:
https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/modeling_phi3.py
"""

@pytest.fixture
def input_params(self):
bsz = 4
num_heads = 32
embed_dim = 3072
seq_len = 60
max_seq_len = 4096
head_dim = embed_dim // num_heads
return bsz, num_heads, head_dim, seq_len, max_seq_len

@pytest.fixture
def input(self, input_params) -> tensor:
bsz, num_heads, head_dim, seq_len, _ = input_params
return torch.randn(bsz, seq_len, num_heads, head_dim)

@pytest.fixture
def rope_phi3(self, input_params) -> Phi3RotaryPositionalEmbeddings:
_, _, head_dim, _, max_seq_len = input_params
return Phi3RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len)

@mps_ignored_test()
def test_forward(
self, input: tensor, rope_phi3: Phi3RotaryPositionalEmbeddings
) -> None:
x_out = rope_phi3(input)

# check the numerics of the computed tensor
assert_expected(x_out.mean(), tensor(-0.0005), atol=1e-4)
assert_expected(x_out.sum(), tensor(-381.0620))

# check shapes
assert_expected(x_out.shape, input.shape)
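
For reference, a minimal usage sketch of the module under test (not part of the diff), using the same fixture shapes and the same import path as the test above:

import torch
from torchtune.models.phi3 import Phi3RotaryPositionalEmbeddings

bsz, num_heads, embed_dim, seq_len, max_seq_len = 4, 32, 3072, 60, 4096
head_dim = embed_dim // num_heads  # 96

rope = Phi3RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len)
x = torch.randn(bsz, seq_len, num_heads, head_dim)  # [b, s, n_h, h_d]
x_out = rope(x)  # rotary embedding applied along head_dim; shape is unchanged
assert x_out.shape == x.shape
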
130 changes: 95 additions & 35 deletions tests/torchtune/modules/test_position_embeddings.py
@@ -4,16 +4,16 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Tuple

import pytest
import torch

from tests.test_utils import assert_expected, mps_ignored_test
from torch import tensor
from torchtune.models.phi3 import Phi3RotaryPositionalEmbeddings

from torchtune.modules.position_embeddings import RotaryPositionalEmbeddings
from torchtune.modules.position_embeddings import (
RotaryPositionalEmbeddings,
VisionRotaryPositionalEmbeddings,
)
from torchtune.training.seed import set_seed


@@ -35,7 +35,7 @@ class TestRotaryPositionEmbedding:
EXPECTED_X_OUT_MAX = tensor(5.4546)

@pytest.fixture
def input_params(self) -> Tuple[int, int, int, int]:
def input_params(self):
bsz = 4
num_heads = 32
embed_dim = 4096
@@ -45,14 +45,12 @@ def input_params(self) -> Tuple[int, int, int, int]:
return bsz, num_heads, head_dim, seq_len, max_seq_len

@pytest.fixture
def input(self, input_params: Tuple[int, int, int, int]) -> tensor:
def input(self, input_params) -> tensor:
bsz, num_heads, head_dim, seq_len, _ = input_params
return torch.randn(bsz, seq_len, num_heads, head_dim)

@pytest.fixture
def rope(
self, input_params: Tuple[int, int, int, int]
) -> RotaryPositionalEmbeddings:
def rope(self, input_params) -> RotaryPositionalEmbeddings:
_, _, head_dim, _, max_seq_len = input_params
return RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len)

@@ -136,44 +134,106 @@ def test_rope_init_meta_device(self, input_params):
torch.testing.assert_close(p1, p2)


class TestPhi3RotaryPositionalEmbeddings:
"""
Class for testing the Phi3 models RoPE Embeddings. The expected tensors are
computed from the reference implementation here:
https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/modeling_phi3.py
"""
class TestVisionRotaryPositionEmbedding:

EXPECTED_X_OUT_MEAN = tensor(0.0789793)
EXPECTED_X_OUT_SUM = tensor(25.2733822)
EXPECTED_X_OUT_MAX = tensor(3.1225626)

@pytest.fixture
def input_params(self) -> Tuple[int, int, int, int]:
bsz = 4
num_heads = 32
embed_dim = 3072
seq_len = 60
max_seq_len = 4096
def input_params(self):
bsz = 2
num_heads = 8
embed_dim = 32
head_dim = embed_dim // num_heads
return bsz, num_heads, head_dim, seq_len, max_seq_len
seq_len = 5
patch_size = 4
tile_size = 16
return bsz, num_heads, head_dim, seq_len, patch_size, tile_size

@pytest.fixture
def input(self, input_params: Tuple[int, int, int, int]) -> tensor:
bsz, num_heads, head_dim, seq_len, _ = input_params
def input(self, input_params) -> tensor:
bsz, num_heads, head_dim, seq_len, *_ = input_params
return torch.randn(bsz, seq_len, num_heads, head_dim)

@pytest.fixture
def rope_phi3(
self, input_params: Tuple[int, int, int, int]
) -> Phi3RotaryPositionalEmbeddings:
_, _, head_dim, _, max_seq_len = input_params
return Phi3RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len)
def rope(self, input_params):
_, _, head_dim, _, patch_size, tile_size = input_params
return VisionRotaryPositionalEmbeddings(
patch_size=patch_size, tile_size=tile_size, dim=head_dim // 2
)

@mps_ignored_test()
def test_forward(
self, input: tensor, rope_phi3: Phi3RotaryPositionalEmbeddings
) -> None:
x_out = rope_phi3(input)
def test_forward(self, input, rope) -> None:
x_out = rope(input)

# check the numerics of the computed tensor
assert_expected(x_out.mean(), tensor(-0.0005), atol=1e-4)
assert_expected(x_out.sum(), tensor(-381.0620))
assert_expected(x_out.mean(), self.EXPECTED_X_OUT_MEAN)
assert_expected(x_out.sum(), self.EXPECTED_X_OUT_SUM)
assert_expected(x_out.max(), self.EXPECTED_X_OUT_MAX)

# check shapes
assert_expected(x_out.shape, input.shape)

@mps_ignored_test()
def test_forward_with_curr_pos(self, input, rope) -> None:
(
_,
seq_len,
_,
_,
) = input.shape
x_out = rope(input, input_pos=torch.arange(seq_len))

# these values should be exactly the same as test_forward
# since in this case input_pos covers the entire input
# sequence. This tests that input_pos works as expected i.e.
# extracts the embeddings for the relevant positions
assert_expected(x_out.mean(), self.EXPECTED_X_OUT_MEAN, atol=1e-4)
assert_expected(x_out.sum(), self.EXPECTED_X_OUT_SUM)
assert_expected(x_out.max(), self.EXPECTED_X_OUT_MAX)

# check shapes
assert_expected(x_out.shape, input.shape)

@mps_ignored_test()
def test_forward_with_packed_pos(self, input, rope) -> None:
"""
Use input_pos to indicate positions of each token relative to its sequence
when sample is packed.
"""
(
bsz,
seq_len,
_,
_,
) = input.shape
x_out = rope(
input, input_pos=torch.arange(seq_len).unsqueeze(0).expand(bsz, seq_len)
)

# these values should be exactly the same as test_forward
# AND test_forward_with_current_pos. In this case input_pos
# covers the entire batch dim and is defined for each sample separately.
# This tests that input_pos works as expected i.e.
# extracts the embeddings for the relevant positions for each sample
assert_expected(x_out.mean(), self.EXPECTED_X_OUT_MEAN, atol=1e-4)
assert_expected(x_out.sum(), self.EXPECTED_X_OUT_SUM)
assert_expected(x_out.max(), self.EXPECTED_X_OUT_MAX)

# check shapes
assert_expected(x_out.shape, input.shape)

def test_rope_init_meta_device(self, input_params):
_, _, head_dim, _, patch_size, tile_size = input_params
rope_on_device = VisionRotaryPositionalEmbeddings(
dim=head_dim, patch_size=patch_size, tile_size=tile_size
)
with torch.device("meta"):
meta_rope = VisionRotaryPositionalEmbeddings(
dim=head_dim, patch_size=patch_size, tile_size=tile_size
)

meta_rope.rope_init()
for p1, p2 in zip(rope_on_device.buffers(), meta_rope.buffers()):
torch.testing.assert_close(p1, p2)
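
For reference, a minimal sketch (not part of the diff) of how the new VisionRotaryPositionalEmbeddings is exercised by these tests; the shapes and constructor arguments come from the fixtures above, and the three calls mirror the default, explicit-position, and packed-position cases:

import torch
from torchtune.modules.position_embeddings import VisionRotaryPositionalEmbeddings

bsz, num_heads, embed_dim, seq_len = 2, 8, 32, 5
head_dim = embed_dim // num_heads  # 4
rope = VisionRotaryPositionalEmbeddings(patch_size=4, tile_size=16, dim=head_dim // 2)

x = torch.randn(bsz, seq_len, num_heads, head_dim)
out_default = rope(x)                                    # positions default to 0..seq_len-1
out_explicit = rope(x, input_pos=torch.arange(seq_len))  # explicit positions
out_packed = rope(                                       # per-sample positions, as in packing
    x, input_pos=torch.arange(seq_len).unsqueeze(0).expand(bsz, seq_len)
)
# All three calls should match, since every input_pos above covers the full sequence.
assert out_default.shape == x.shape
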
24 changes: 24 additions & 0 deletions tests/torchtune/modules/test_vision_transformer.py
@@ -210,3 +210,27 @@ def test_vision_transformer_single_tile(self, transformer_config):
), f"Expected shape {expected_shape}, but got {output.shape}"

assert_expected(output.mean(), torch.tensor(0.5458), atol=1e-3, rtol=1e-3)

@torch.no_grad()
def test_vision_transformer_append_cls_token(self, transformer_config):
transformer_config = transformer_config.copy()
transformer_config["append_cls_token"] = True

model_append_cls = clip_vision_encoder(**transformer_config).eval()
fixed_init_model(model_append_cls, min_val=-1, max_val=1)
output, _ = model_append_cls(self.image, self.aspect_ratio)

# assertion
expected_shape = (
self.batch_size,
self.n_imgs,
self.num_tiles,
model_append_cls.get_image_tokens_per_tile(),
transformer_config["embed_dim"],
)

assert (
output.shape == expected_shape
), f"Expected shape {expected_shape}, but got {output.shape}"

assert_expected(output.mean(), torch.tensor(1.0172), atol=1e-3, rtol=1e-3)
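
For context, a tensor-level sketch (not part of the diff) of what an appended CLS token plausibly looks like compared to the usual prepended one; the tensors and names below are placeholders, and the test itself only verifies the encoder's output shape and mean:

import torch

embed_dim, patches_per_tile = 32, 16
patch_tokens = torch.randn(2, patches_per_tile, embed_dim)  # [bsz * n_tiles, patches, dim]
cls_token = torch.zeros(2, 1, embed_dim)                    # placeholder CLS embedding

prepended = torch.cat([cls_token, patch_tokens], dim=1)     # default: CLS token comes first
appended = torch.cat([patch_tokens, cls_token], dim=1)      # append_cls_token=True: CLS token comes last
assert prepended.shape == appended.shape == (2, patches_per_tile + 1, embed_dim)
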
6 changes: 4 additions & 2 deletions tests/torchtune/training/test_distributed.py
@@ -360,8 +360,10 @@ def _test_qlora_state_dict(self, enable_activation_checkpointing: bool):
)
# init rope since it's not covered in state dict
for m in fsdp_model_to_load.modules():
if isinstance(m, modules.RotaryPositionalEmbeddings):
m.reset_parameters()
if isinstance(m, modules.RotaryPositionalEmbeddings) or isinstance(
m, modules.VisionRotaryPositionalEmbeddings
):
m.rope_init()
for m in fsdp_model_to_load.modules():
if enable_activation_checkpointing:
if isinstance(m, CheckpointWrapper):
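
For reference, a minimal sketch (not part of the diff) of the re-initialization pattern this hunk adopts: RoPE caches are not part of the state dict, so both rotary-embedding variants are rebuilt via rope_init(), which supersedes the earlier reset_parameters() call. The helper name below is hypothetical, and it assumes modules is imported from torchtune as in the test:

from torchtune import modules

def init_rope_buffers(model):
    # Recreate the RoPE caches for both the text and the new vision rotary embeddings.
    for m in model.modules():
        if isinstance(
            m,
            (modules.RotaryPositionalEmbeddings, modules.VisionRotaryPositionalEmbeddings),
        ):
            m.rope_init()
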
(Diffs for the remaining 7 changed files are not shown here.)
