
CLIP Text Encoder #1969

Merged
21 commits merged into pytorch:main from calvinpelletier:clip_text on Nov 20, 2024

Conversation

calvinpelletier (Contributor) commented Nov 7, 2024

Context

What is the purpose of this PR? Is it to:

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

Please link to any issues this PR addresses.

Changelog

What are the changes made in this PR?

  • Implement the CLIP tokenizer
  • Implement the CLIP text encoder

Test plan

Minimal code to run the CLIP text encoder e2e (first download the CLIP weights):

tune download openai/clip-vit-large-patch14 --output-dir /tmp/clip-vit-large-patch14 --ignore-patterns None

import torch

from torchtune.models.clip._model_builders import clip_tokenizer, clip_text_encoder_large
from torchtune.training.checkpointing._checkpointer import FullModelHFCheckpointer

# build tokenizer and text encoder
tokenizer = clip_tokenizer()
encoder = clip_text_encoder_large()

# load weights
encoder.load_state_dict(FullModelHFCheckpointer(
    "/tmp/clip-vit-large-patch14",  # checkpoint_dir (from `tune download` above)
    ["model.safetensors"],  # checkpoint_files
    "CLIP_TEXT",  # model_type
    "/tmp/torchtune-clip-vit-large-patch14",  # output_dir
).load_checkpoint()["model"])
encoder = encoder.to(torch.bfloat16).cuda().eval()

# run
text = [
    "a cow jumping over the moon",
    "a helpful AI assistant",
]
tokens = tokenizer(text)
encoding = encoder(tokens.cuda())

Checked parity with the HF CLIP tokenizer and text encoder: MSE between the encoder outputs on a batch of 32 test strings = 3.55e-5.
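
For reference, a parity check along these lines might look like the following (a minimal sketch, not the exact script from the PR; it assumes the torchtune encoder returns the EOT-pooled embedding, which corresponds to HF's pooler_output, and it reuses text and encoding from the snippet above):

import torch
from transformers import CLIPTextModel, CLIPTokenizer

# HF reference: tokenize and encode the same strings
hf_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
hf_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").eval()
hf_inputs = hf_tokenizer(text, padding="max_length", return_tensors="pt")
with torch.no_grad():
    hf_encoding = hf_encoder(**hf_inputs).pooler_output  # EOT-pooled, [b, d]

# compare against the torchtune output computed above
mse = torch.mean((encoding.float().cpu() - hf_encoding) ** 2)
print(f"MSE: {mse.item():.2e}")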

Tokenization time for 32 test strings (lower is better):

  • OpenAI: 0.0342
  • HuggingFace: 0.0416
  • TorchTune: 0.0195

Tokenization time for an entire 100k image-generation prompt dataset:

  • HuggingFace: 18.02
  • TorchTune: 6.89

Encoding time for a single batch of 32 test strings:

  • HuggingFace: 1.699
  • TorchTune: 0.0173

Encoding time for 1000 batches of 32 test strings:

  • HuggingFace: 30.498
  • TorchTune: 11.713

Checklist

  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
    • TODO: unit test for encoder
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
  • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

UX

If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example and a tutorial example.

  • I did not change any public API
  • I have added an example to docs or docstrings

pytorch-bot commented Nov 7, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1969

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

✅ No Failures

As of commit c215690 with merge base fcd400f:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

facebook-github-bot added the CLA Signed label on Nov 7, 2024
codecov-commenter commented Nov 7, 2024

Codecov Report

Attention: Patch coverage is 92.30769% with 17 lines in your changes missing coverage. Please review.

Project coverage is 65.08%. Comparing base (1814feb) to head (c215690).
Report is 5 commits behind head on main.

Files with missing lines                            Patch %   Missing
torchtune/models/clip/_convert_weights.py           36.36%    7 ⚠️
torchtune/models/clip/_tokenizer.py                 93.45%    7 ⚠️
torchtune/models/clip/_model_builders.py            87.50%    1 ⚠️
torchtune/models/clip/_text_encoder.py              96.42%    1 ⚠️
torchtune/training/checkpointing/_checkpointer.py   75.00%    1 ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1969      +/-   ##
==========================================
- Coverage   67.29%   65.08%   -2.21%     
==========================================
  Files         318      323       +5     
  Lines       17646    17848     +202     
==========================================
- Hits        11874    11616     -258     
- Misses       5772     6232     +460     


pbontrager (Contributor) left a comment:

This is a great PR! I left some comments around some standard patterns we try to follow, but aside from that, this looks very solid and clean.

pbontrager (Contributor) left a comment:

I'd still like the model type to be unified to just CLIP to keep the list shorter. But I'll approve when @RdoubleA signs off on the tokenizer.


class CLIPTextEncoder(nn.Module):
    """
    Text encoder for CLIP.
A reviewer (Contributor) commented:

mind briefly describing the architecture? probably a normal transformer, but uses a different MLP activation?

calvinpelletier (Author) replied Nov 20, 2024:

it's just a normal transformer. the MLP activation is still GELU, it's just a faster and less-precise version of it (tho not actually faster these days, this is a relic of the ancient year 2021)
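
For context on the activation: the fast, less-precise GELU in question is the sigmoid-based approximation from the original CLIP codebase, often called QuickGELU. A minimal sketch:

import torch
import torch.nn.functional as F

def quick_gelu(x: torch.Tensor) -> torch.Tensor:
    # sigmoid approximation of GELU used by the original CLIP code
    return x * torch.sigmoid(1.702 * x)

x = torch.linspace(-3, 3, 601)
# deviation from exact GELU stays small (roughly 0.02 at worst on this range)
print((quick_gelu(x) - F.gelu(x)).abs().max())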

pbontrager (Contributor) left a comment:

This looks good now, thanks for making all the changes! I think you could add a few-line description of the models in the docstrings where Rafi commented, along with an arXiv link. But I'll approve it now.

  # [b, s, d] -> [b, d]
- # TODO: handle the case when the EOS token is not the highest token ID
- eos_token_positions = tokens.argmax(dim=-1)
+ eos_token_positions = (tokens == self.eot_token).int().argmax(dim=-1)
A reviewer (Contributor) commented:

can do (tokens == self.eot_token).nonzero()

calvinpelletier (Author) replied Nov 20, 2024:

we don't want all positions of the eot token, just the first one (argmax gives the first position where they match)
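
A quick illustration of the difference (hypothetical token IDs; 49407 is CLIP's <|endoftext|>):

import torch

eot = 49407  # CLIP's <|endoftext|> token ID
tokens = torch.tensor([[49406, 320, 2368, eot, eot]])  # hypothetical padded sequence

# argmax over the boolean mask gives the FIRST matching position per sequence
print((tokens == eot).int().argmax(dim=-1))  # tensor([3])

# nonzero gives EVERY matching (row, col) position
print((tokens == eot).nonzero())  # tensor([[0, 3], [0, 4]])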

calvinpelletier merged commit 89f935f into pytorch:main on Nov 20, 2024
17 checks passed
ebsmothers mentioned this pull request on Nov 26, 2024
calvinpelletier deleted the clip_text branch on December 8, 2024