CLIP Text Encoder #1969

Merged: 21 commits, Nov 20, 2024
55 changes: 55 additions & 0 deletions tests/torchtune/models/clip/test_clip_text_encoder.py
@@ -0,0 +1,55 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import pytest
import torch

from torchtune.models.clip._component_builders import clip_text_encoder
from torchtune.training.seed import set_seed

VOCAB_SIZE = 512
MAX_SEQ_LEN = 77
BSZ = 2
EMBED_DIM = 4


@pytest.fixture(autouse=True)
def random():
set_seed(0)


class TestClipTextEncoder:
@pytest.fixture
def model(self):
model = clip_text_encoder(
vocab_size=VOCAB_SIZE,
max_seq_len=MAX_SEQ_LEN,
embed_dim=EMBED_DIM,
num_heads=2,
num_layers=2,
)

for param in model.parameters():
param.data.uniform_(0, 1)

return model

@pytest.fixture
def inputs(self):
return torch.randint(0, VOCAB_SIZE, (BSZ, MAX_SEQ_LEN))

def test_forward(self, model, inputs):
actual = model(inputs)
expected = torch.tensor(
[[0.1915, 1.3982, 0.6298, -0.0966], [0.2276, 1.3785, 0.6309, -0.1066]]
)
assert actual.shape == (BSZ, EMBED_DIM)
torch.testing.assert_close(actual, expected, atol=1e-4, rtol=1e-4)

def test_backward(self, model, inputs):
y = model(inputs)
loss = y.mean()
loss.backward()
58 changes: 58 additions & 0 deletions tests/torchtune/models/clip/test_clip_tokenizer.py
@@ -0,0 +1,58 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import pytest

from tests.common import ASSETS
from torchtune.models.clip._model_builders import clip_tokenizer


class TestCLIPTokenizer:
@pytest.fixture
def tokenizer(self):
return clip_tokenizer(ASSETS / "tiny_bpe_merges.txt")

def test_encoding(self, tokenizer):
texts = [
"a cow jumping over the moon",
"a helpful AI assistant",
]
correct_tokens = [
[
2416,
320,
66,
78,
342,
73,
669,
79,
515,
326,
1190,
337,
673,
324,
76,
819,
333,
2417,
],
[2416, 320, 516, 75, 79, 69, 84, 331, 64, 328, 813, 667, 540, 339, 2417],
]
for text, correct in zip(texts, correct_tokens):
tokens = tokenizer.encode(text)
assert tokens == correct

def test_decoding(self, tokenizer):
text = "this is torchtune"
decoded_text = "<|startoftext|>this is torchtune <|endoftext|>"
assert decoded_text == tokenizer.decode(tokenizer.encode(text))

def test_call(self, tokenizer):
sample = {"text": "hello world"}
sample = tokenizer(sample)
assert "text" not in sample
assert "tokens" in sample
7 changes: 5 additions & 2 deletions torchtune/models/clip/__init__.py
@@ -4,8 +4,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from ._component_builders import clip_mlp, clip_vision_encoder

from ._component_builders import clip_mlp, clip_text_encoder, clip_vision_encoder
from ._model_builders import clip_text_vit_large_patch14, clip_tokenizer
from ._position_embeddings import (
TiledTokenPositionalEmbedding,
TilePositionalEmbedding,
@@ -15,7 +15,10 @@

__all__ = [
"clip_mlp",
"clip_text_encoder",
"clip_vision_encoder",
"clip_text_vit_large_patch14",
"clip_tokenizer",
"CLIPImageTransform",
"TokenPositionalEmbedding",
"TiledTokenPositionalEmbedding",
67 changes: 62 additions & 5 deletions torchtune/models/clip/_component_builders.py
@@ -8,12 +8,13 @@
from typing import Callable, List, Optional

from torch import nn

from torchtune.models.clip._position_embeddings import (
TiledTokenPositionalEmbedding,
TilePositionalEmbedding,
TokenPositionalEmbedding,
)

from torchtune.models.clip._text_encoder import CLIPTextEncoder, QuickGELU
from torchtune.modules import (
FeedForward,
Fp32LayerNorm,
@@ -22,11 +23,8 @@
TransformerSelfAttentionLayer,
VisionRotaryPositionalEmbeddings,
)

from torchtune.modules.common_utils import reparametrize_as_dtype_state_dict_post_hook

from torchtune.modules.peft import DoRALinear, LORA_ATTN_MODULES, LoRALinear

from torchtune.modules.peft import LORA_ATTN_MODULES, DoRALinear, LoRALinear
from torchtune.modules.vision_transformer import CLSProjection, VisionTransformer


@@ -181,6 +179,65 @@ def clip_vision_encoder(
)


def clip_text_encoder(
embed_dim: int,
num_heads: int,
num_layers: int,
vocab_size: int = 49408,
max_seq_len: int = 77,
norm_eps: float = 1e-5,
):
"""
Text encoder for CLIP.

CLIP is a model that encodes text and images into a shared vector space.
Blog post: https://openai.com/index/clip/
Paper: https://arxiv.org/abs/2103.00020

Args:
embed_dim (int): embedding/model dimension size
num_heads (int): number of attention heads
num_layers (int): number of transformer layers
vocab_size (int): size of the vocabulary, default 49408
max_seq_len (int): context size, default 77
norm_eps (float): small value added to denominator for numerical stability, default 1e-5

Returns:
CLIPTextEncoder
"""
attn = MultiHeadAttention(
embed_dim=embed_dim,
num_heads=num_heads,
num_kv_heads=num_heads,
head_dim=embed_dim // num_heads,
q_proj=nn.Linear(embed_dim, embed_dim),
k_proj=nn.Linear(embed_dim, embed_dim),
v_proj=nn.Linear(embed_dim, embed_dim),
output_proj=nn.Linear(embed_dim, embed_dim),
)
mlp = clip_mlp(
in_dim=embed_dim,
out_dim=embed_dim,
hidden_dim=embed_dim * 4,
activation=QuickGELU(),
)
encoder_layer = TransformerSelfAttentionLayer(
attn=attn,
mlp=mlp,
sa_norm=nn.LayerNorm(embed_dim, eps=norm_eps),
mlp_norm=nn.LayerNorm(embed_dim, eps=norm_eps),
)
final_norm = nn.LayerNorm(embed_dim, eps=norm_eps)
return CLIPTextEncoder(
layer=encoder_layer,
final_norm=final_norm,
vocab_size=vocab_size,
max_seq_len=max_seq_len,
embed_dim=embed_dim,
num_layers=num_layers,
)


def clip_mlp(
in_dim: int,
out_dim: int,
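
For reference, a minimal usage sketch of the new clip_text_encoder builder; the small sizes below are illustrative only, not a released configuration:

import torch

from torchtune.models.clip import clip_text_encoder

# Build a small text encoder; output embeddings have size embed_dim.
encoder = clip_text_encoder(
    embed_dim=64,
    num_heads=4,
    num_layers=2,
    vocab_size=512,
    max_seq_len=77,
)

# Token IDs of shape (batch_size, seq_len); the encoder returns one
# embedding per sequence, i.e. shape (batch_size, embed_dim).
tokens = torch.randint(0, 512, (2, 77))
embeddings = encoder(tokens)
print(embeddings.shape)  # torch.Size([2, 64])
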
48 changes: 48 additions & 0 deletions torchtune/models/clip/_convert_weights.py
@@ -0,0 +1,48 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torchtune.models.convert_weights import get_mapped_key

# state dict key mappings from HF's format to torchtune's format
_FROM_HF = {
"text_model.embeddings.token_embedding.weight": "token_embedding.weight",
"text_model.embeddings.position_embedding.weight": "position_embedding",
"text_model.encoder.layers.{}.layer_norm1.weight": "layers.{}.sa_norm.weight",
"text_model.encoder.layers.{}.layer_norm1.bias": "layers.{}.sa_norm.bias",
"text_model.encoder.layers.{}.layer_norm2.weight": "layers.{}.mlp_norm.weight",
"text_model.encoder.layers.{}.layer_norm2.bias": "layers.{}.mlp_norm.bias",
"text_model.encoder.layers.{}.mlp.fc1.weight": "layers.{}.mlp.w1.weight",
"text_model.encoder.layers.{}.mlp.fc1.bias": "layers.{}.mlp.w1.bias",
"text_model.encoder.layers.{}.mlp.fc2.weight": "layers.{}.mlp.w2.weight",
"text_model.encoder.layers.{}.mlp.fc2.bias": "layers.{}.mlp.w2.bias",
"text_model.encoder.layers.{}.self_attn.q_proj.weight": "layers.{}.attn.q_proj.weight",
"text_model.encoder.layers.{}.self_attn.q_proj.bias": "layers.{}.attn.q_proj.bias",
"text_model.encoder.layers.{}.self_attn.k_proj.weight": "layers.{}.attn.k_proj.weight",
"text_model.encoder.layers.{}.self_attn.k_proj.bias": "layers.{}.attn.k_proj.bias",
"text_model.encoder.layers.{}.self_attn.v_proj.weight": "layers.{}.attn.v_proj.weight",
"text_model.encoder.layers.{}.self_attn.v_proj.bias": "layers.{}.attn.v_proj.bias",
"text_model.encoder.layers.{}.self_attn.out_proj.bias": "layers.{}.attn.output_proj.bias",
"text_model.encoder.layers.{}.self_attn.out_proj.weight": "layers.{}.attn.output_proj.weight",
"text_model.final_layer_norm.weight": "final_norm.weight",
"text_model.final_layer_norm.bias": "final_norm.bias",
}

_IGNORE = {
"logit_scale",
"text_model.embeddings.position_ids",
"text_projection.weight",
"visual_projection.weight",
}


def clip_text_hf_to_tune(state_dict):
converted_state_dict = {}
for key, value in state_dict.items():
if key.startswith("vision_model.") or key in _IGNORE:
continue
new_key = get_mapped_key(key, _FROM_HF)
converted_state_dict[new_key] = value
return converted_state_dict
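
As a usage note, a hedged sketch of how clip_text_hf_to_tune could be wired up to load Hugging Face weights into the torchtune encoder; the transformers dependency and the checkpoint name are assumptions for illustration, not part of this PR:

from transformers import CLIPTextModel  # assumed available; not a torchtune dependency

from torchtune.models.clip import clip_text_vit_large_patch14
from torchtune.models.clip._convert_weights import clip_text_hf_to_tune

# Load an HF CLIP text checkpoint (name is illustrative) and take its state dict.
hf_state_dict = CLIPTextModel.from_pretrained(
    "openai/clip-vit-large-patch14"
).state_dict()

# Remap HF keys to torchtune keys; vision-tower and ignored keys are dropped.
tune_state_dict = clip_text_hf_to_tune(hf_state_dict)

# Load the converted weights into the torchtune CLIP-ViT-L/14 text encoder.
model = clip_text_vit_large_patch14()
model.load_state_dict(tune_state_dict)
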
52 changes: 51 additions & 1 deletion torchtune/models/clip/_model_builders.py
@@ -1,4 +1,54 @@
from torchtune.models.clip._transforms import CLIPImageTransform
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from torchtune.models.clip._component_builders import clip_text_encoder
from torchtune.models.clip._text_encoder import CLIPTextEncoder
from torchtune.models.clip._tokenizer import CLIPTokenizer
from torchtune.models.clip._transform import CLIPImageTransform


def clip_tokenizer(
path: str,
max_seq_len: int = 77,
truncate: bool = True,
) -> CLIPTokenizer:
"""
Builder for the CLIP text tokenizer.

Args:
path (str): Path to the CLIP merges file
max_seq_len (int): Context length. Default: 77
truncate (bool): Truncate the token sequence if it exceeds max_seq_len (otherwise raises AssertionError)
Default: True

Returns:
CLIPTokenizer: Instantiation of the CLIP text tokenizer
"""
return CLIPTokenizer(path, max_seq_len=max_seq_len, truncate=truncate)


def clip_text_vit_large_patch14() -> CLIPTextEncoder:
"""
Builder for the CLIP text encoder for CLIP-ViT-L/14.

CLIP is a model that encodes text and images into a shared vector space.
Blog post: https://openai.com/index/clip/
Paper: https://arxiv.org/abs/2103.00020

Returns:
CLIPTextEncoder: Instantiation of the CLIP text encoder
"""
return clip_text_encoder(
embed_dim=768,
num_heads=12,
num_layers=12,
vocab_size=49408,
max_seq_len=77,
norm_eps=1e-5,
)


def clip_vit_224_transform():
image_transform = CLIPImageTransform(
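
Finally, a minimal end-to-end sketch combining the new builders; the merges-file path is a placeholder, and the encoder here is randomly initialized:

import torch

from torchtune.models.clip import clip_text_vit_large_patch14, clip_tokenizer

# Tokenize a prompt; the sample dict gains a "tokens" key, mirroring the unit tests.
tokenizer = clip_tokenizer("/path/to/clip_merges.txt")
sample = tokenizer({"text": "a cow jumping over the moon"})
print(sample["tokens"])

# Run a batch of full-length (77-token) sequences through the ViT-L/14 text encoder.
model = clip_text_vit_large_patch14()
tokens = torch.randint(0, 49408, (2, 77))
embeddings = model(tokens)
print(embeddings.shape)  # torch.Size([2, 768])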