Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

T5 Encoder #2069

Merged
merged 12 commits into from
Jan 7, 2025
Binary file added tests/assets/sentencepiece.model
Binary file not shown.
5 changes: 5 additions & 0 deletions tests/torchtune/models/t5/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
81 changes: 81 additions & 0 deletions tests/torchtune/models/t5/test_t5_encoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import pytest
import torch

from torchtune.models.t5._component_builders import t5_encoder
from torchtune.training.seed import set_seed

# Model/input sizes shared by the fixtures and assertions in this module.
VOCAB_SIZE = 512  # exclusive upper bound for the random token ids; also the model's vocab size
MAX_SEQ_LEN = 8  # sequence length of the test inputs; also passed as the model's max_seq_len
BSZ = 2  # batch size of the test inputs
EMBED_DIM = 2  # model dimension; also the last dimension of the expected output


@pytest.fixture(autouse=True)
def random():
    """Call ``set_seed(0)`` before every test in this module (autouse fixture)."""
    set_seed(0)


class TestT5Encoder:
    """Regression tests for the encoder produced by the ``t5_encoder`` builder."""

    @pytest.fixture
    def model(self):
        """Build a tiny two-layer encoder whose parameters are all drawn from U(0, 1)."""
        # Keep the head count in one place so ``head_dim`` cannot drift away
        # from ``embed_dim // num_heads``.
        num_heads = 2

        model = t5_encoder(
            embed_dim=EMBED_DIM,
            mlp_dim=4,
            num_heads=num_heads,
            head_dim=EMBED_DIM // num_heads,
            num_layers=2,
            rel_pos_num_buckets=4,
            rel_pos_max_dist=4,
            vocab_size=VOCAB_SIZE,
            norm_eps=1e-6,
            max_seq_len=MAX_SEQ_LEN,
        )

        # Overwrite whatever initialisation the modules use; the expected
        # values in ``test_forward`` were recorded with these weights.
        for param in model.parameters():
            param.data.uniform_(0, 1)

        return model

    @pytest.fixture
    def inputs(self):
        """Random token ids of shape ``(BSZ, MAX_SEQ_LEN)``."""
        return torch.randint(0, VOCAB_SIZE, (BSZ, MAX_SEQ_LEN))

    def test_forward(self, model, inputs):
        """The forward pass reproduces the recorded output for the seeded setup."""
        actual = model(inputs)
        expected = torch.tensor(
            [
                [
                    [0.3670, 0.2938],
                    [0.3692, 0.2921],
                    [0.3611, 0.2984],
                    [0.4207, 0.2437],
                    [0.3447, 0.3106],
                    [0.3383, 0.3150],
                    [0.3727, 0.2892],
                    [0.3996, 0.2653],
                ],
                [
                    [0.3855, 0.2783],
                    [0.2627, 0.3581],
                    [0.3601, 0.2992],
                    [0.3473, 0.3087],
                    [0.3549, 0.3032],
                    [0.2871, 0.3459],
                    [0.2753, 0.3520],
                    [0.2285, 0.3728],
                ],
            ]
        )
        assert actual.shape == (BSZ, MAX_SEQ_LEN, EMBED_DIM)
        torch.testing.assert_close(actual, expected, atol=1e-4, rtol=1e-4)

    def test_backward(self, model, inputs):
        """A backward pass runs and yields finite gradients."""
        y = model(inputs)
        loss = y.mean()
        loss.backward()

        # Previously this test only checked that backward() did not raise;
        # also make sure it actually produced usable gradients.
        assert torch.isfinite(loss)
        grads = [p.grad for p in model.parameters() if p.grad is not None]
        assert grads, "backward() did not populate any parameter gradients"
        assert all(torch.isfinite(g).all() for g in grads)
40 changes: 40 additions & 0 deletions tests/torchtune/models/t5/test_t5_tokenizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import pytest

from tests.common import ASSETS
from torchtune.models.t5._model_builders import t5_tokenizer


class TestT5Tokenizer:
    """Tests for the tokenizer returned by ``t5_tokenizer``."""

    @pytest.fixture
    def tokenizer(self):
        """Tokenizer loaded from the ``sentencepiece.model`` test asset."""
        return t5_tokenizer(str(ASSETS / "sentencepiece.model"))

    def test_encoding(self, tokenizer):
        """``encode`` returns the recorded token ids for known inputs."""
        texts = [
            "a cow jumping over the moon",
            "a helpful AI assistant",
        ]
        correct_tokens = [
            [3, 9, 9321, 15539, 147, 8, 8114, 1],
            [3, 9, 2690, 7833, 6165, 1],
        ]
        for text, correct in zip(texts, correct_tokens):
            assert tokenizer.encode(text) == correct

    def test_decoding(self, tokenizer):
        """Encoding followed by decoding returns the original text."""
        text = "this is torchtune"
        assert text == tokenizer.decode(tokenizer.encode(text))

    def test_call(self, tokenizer):
        """Calling the tokenizer on a sample replaces ``"text"`` with ``"tokens"``."""
        sample = {"text": "hello world"}
        sample = tokenizer(sample)
        assert "text" not in sample
        assert "tokens" in sample
12 changes: 12 additions & 0 deletions torchtune/models/t5/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from ._model_builders import t5_tokenizer, t5_v1p1_xxl_encoder
calvinpelletier marked this conversation as resolved.
Show resolved Hide resolved

# Names re-exported from ``._model_builders`` as this package's public API.
__all__ = [
    "t5_tokenizer",
    "t5_v1p1_xxl_encoder",
]
89 changes: 89 additions & 0 deletions torchtune/models/t5/_component_builders.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from torch import nn

from torchtune.models.t5._encoder import (
T5Encoder,
T5EncoderLayer,
T5EncoderSelfAttention,
)
from torchtune.modules.feed_forward import FeedForward
from torchtune.modules.rms_norm import RMSNorm


def t5_encoder(
    embed_dim: int,
    mlp_dim: int,
    num_heads: int,
    head_dim: int,
    num_layers: int,
    rel_pos_num_buckets: int,
    rel_pos_max_dist: int,
    vocab_size: int,
    norm_eps: float,
    max_seq_len: int,
) -> T5Encoder:
    """
    Assemble a T5 encoder from its component modules.

    T5 paper: https://arxiv.org/abs/1910.10683

    Args:
        embed_dim (int): The model dimension.
        mlp_dim (int): The inner dimension of the feed forward layers.
        num_heads (int): The number of attention heads.
        head_dim (int): The dimension of the attention heads (should equal `embed_dim // num_heads`)
        num_layers (int): Number of encoder layers.
        rel_pos_num_buckets (int): Number of discrete buckets to divide the relative positions into.
            See: :class:`~torchtune.models.t5._encoder.T5EncoderRelativePositionBias`
        rel_pos_max_dist (int): Maximum distance for relative positions.
            Distances beyond this are grouped into the last bucket.
            See: :class:`~torchtune.models.t5._encoder.T5EncoderRelativePositionBias`
        vocab_size (int): Vocab size of the model's tokenizer.
        norm_eps (float): Small value added to denominator for numerical stability.
        max_seq_len (int): The maximum sequence length (context length) of the model.

    Returns:
        T5Encoder
    """

    def _linear_no_bias(in_dim: int, out_dim: int) -> nn.Linear:
        # Every projection in this builder is a bias-free linear layer.
        return nn.Linear(in_dim, out_dim, bias=False)

    # NOTE: modules are instantiated in the order embedding -> attention ->
    # MLP -> norms; keep it that way so seeded initialisation stays stable.
    embedding = nn.Embedding(vocab_size, embed_dim)

    self_attention = T5EncoderSelfAttention(
        embed_dim=embed_dim,
        num_heads=num_heads,
        head_dim=head_dim,
        q_proj=_linear_no_bias(embed_dim, embed_dim),
        k_proj=_linear_no_bias(embed_dim, embed_dim),
        v_proj=_linear_no_bias(embed_dim, embed_dim),
        output_proj=_linear_no_bias(embed_dim, embed_dim),
    )

    feed_forward = FeedForward(
        gate_proj=_linear_no_bias(embed_dim, mlp_dim),
        down_proj=_linear_no_bias(mlp_dim, embed_dim),
        up_proj=_linear_no_bias(embed_dim, mlp_dim),
        activation=nn.GELU(),
    )

    encoder_layer = T5EncoderLayer(
        attn=self_attention,
        mlp=feed_forward,
        sa_norm=RMSNorm(embed_dim, eps=norm_eps),
        mlp_norm=RMSNorm(embed_dim, eps=norm_eps),
    )

    return T5Encoder(
        token_embedding=embedding,
        layer=encoder_layer,
        final_norm=RMSNorm(embed_dim, eps=norm_eps),
        num_layers=num_layers,
        num_heads=num_heads,
        rel_pos_num_buckets=rel_pos_num_buckets,
        rel_pos_max_dist=rel_pos_max_dist,
        max_seq_len=max_seq_len,
    )
49 changes: 49 additions & 0 deletions torchtune/models/t5/_convert_weights.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torchtune.models.convert_weights import get_mapped_key

# state dict key mappings from HF's format to torchtune's format
# NOTE: the "_0" / "_1" parts are not literal HF key parts. `t5_encoder_hf_to_tune`
# rewrites "layer.0." / "layer.1." to "layer._0." / "layer._1." before the lookup so
# that these sub-layer indices are not treated as dynamically mapped "{}" parts.
_FROM_HF = {
    # emb
    "encoder.embed_tokens.weight": "token_embedding.weight",
    "encoder.block.{}.layer._0.SelfAttention.relative_attention_bias.weight": "relative_position_bias.embedding.weight",
    # attn
    "encoder.block.{}.layer._0.SelfAttention.q.weight": "layers.{}.attn.q_proj.weight",
    "encoder.block.{}.layer._0.SelfAttention.k.weight": "layers.{}.attn.k_proj.weight",
    "encoder.block.{}.layer._0.SelfAttention.v.weight": "layers.{}.attn.v_proj.weight",
    "encoder.block.{}.layer._0.SelfAttention.o.weight": "layers.{}.attn.output_proj.weight",
    # ff
    "encoder.block.{}.layer._1.DenseReluDense.wi_0.weight": "layers.{}.mlp.w1.weight",
    "encoder.block.{}.layer._1.DenseReluDense.wo.weight": "layers.{}.mlp.w2.weight",
    "encoder.block.{}.layer._1.DenseReluDense.wi_1.weight": "layers.{}.mlp.w3.weight",
    # norm
    "encoder.block.{}.layer._0.layer_norm.weight": "layers.{}.sa_norm.scale",
    "encoder.block.{}.layer._1.layer_norm.weight": "layers.{}.mlp_norm.scale",
    "encoder.final_layer_norm.weight": "final_norm.scale",
}

# HF keys that are dropped during conversion (keys starting with "decoder." are
# dropped as well, see `t5_encoder_hf_to_tune`).
_IGNORE = {
    "shared.weight",
    "lm_head.weight",
}


def t5_encoder_hf_to_tune(state_dict):
    """
    Convert an HF-format T5 state dict into the key layout described by ``_FROM_HF``.

    Keys that start with ``"decoder."`` and keys listed in ``_IGNORE`` are dropped.
    Every remaining key is renamed via ``get_mapped_key``; values are passed
    through unchanged.

    Args:
        state_dict: Mapping from HF parameter names to their values.

    Returns:
        A new dict holding the kept entries under their mapped keys.
    """
    tune_state_dict = {}
    for hf_key, weight in state_dict.items():
        keep = not hf_key.startswith("decoder.") and hf_key not in _IGNORE
        if keep:
            # NOTE: HF's T5 has ".<integer>." parts that we do NOT want to be dynamically
            # mapped to corresponding ".<integer>." parts in our converted state dict.
            # This breaks the `get_mapped_key` implementation, so as a temporary hack,
            # we add leading underscores to these parts here and in the `_FROM_HF` map.
            escaped_key = hf_key.replace("layer.0.", "layer._0.")
            escaped_key = escaped_key.replace("layer.1.", "layer._1.")

            tune_state_dict[get_mapped_key(escaped_key, _FROM_HF)] = weight
    return tune_state_dict
Loading
Loading