CLIP Text Encoder #1969
Conversation
Dr. CI: ✅ No failures as of commit c215690 with merge base fcd400f.
Codecov Report
Attention: Patch coverage is

```
@@            Coverage Diff             @@
##             main    #1969      +/-   ##
==========================================
- Coverage   67.29%   65.08%    -2.21%
==========================================
  Files         318      323        +5
  Lines       17646    17848      +202
==========================================
- Hits        11874    11616      -258
- Misses       5772     6232      +460
```

View full report in Codecov by Sentry.
This is a great PR! I left some comments around some standard patterns we try to follow but aside from that, this looks very solid and clean.
I'd still like the model type to be unified to just CLIP to keep the list shorter. But I'll approve when @RdoubleA signs off on the tokenizer.
```python
class CLIPTextEncoder(nn.Module):
    """
    Text encoder for CLIP.
```
mind briefly describing the architecture? probably a normal transformer, but uses a different MLP activation?
it's just a normal transformer. the MLP activation is still GELU, it's just a faster and less-precise version of it (tho not actually faster these days, this is a relic of the ancient year 2021)
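For reference, the variant being described is commonly known as QuickGELU, the sigmoid-based approximation of GELU used by the original CLIP codebase; a minimal sketch:

```python
import torch
from torch import nn

class QuickGELU(nn.Module):
    """Sigmoid approximation of GELU from the original CLIP: x * sigmoid(1.702 * x)."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * torch.sigmoid(1.702 * x)
```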
This looks good now, thanks for doing all the changes! I think you could add a few-line description of the models in the docstrings where Rafi commented, along with an arxiv link. But I'll approve it now.
```diff
  # [b, s, d] -> [b, d]
- # TODO: handle the case when the EOS token is not the highest token ID
- eos_token_positions = tokens.argmax(dim=-1)
+ eos_token_positions = (tokens == self.eot_token).int().argmax(dim=-1)
```
can do (tokens == self.eot_token).nonzero()
we don't want all positions of the eot token, just the first one (argmax gives the first position where they match)
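To illustrate the distinction (a standalone sketch, not the PR's code; on ties, torch.argmax returns the index of the first maximal value):

```python
import torch

eot_token = 2  # hypothetical token ID, for illustration only
tokens = torch.tensor([[5, 7, 2, 0],
                       [9, 2, 2, 0]])

# argmax over the boolean mask -> FIRST occurrence of the EOT token per row
first_eot = (tokens == eot_token).int().argmax(dim=-1)  # tensor([2, 1])

# nonzero() would return EVERY matching (row, col) position instead
all_eot = (tokens == eot_token).nonzero()  # tensor([[0, 2], [1, 1], [1, 2]])

# the first-occurrence indices are what's needed to pool [b, s, d] -> [b, d]
x = torch.randn(2, 4, 8)
pooled = x[torch.arange(2), first_eot]  # shape [2, 8]
```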
Context
What is the purpose of this PR? Is it to
Please link to any issues this PR addresses.
Changelog
What are the changes made in this PR?
Test plan
Minimal code to run the CLIP text encoder e2e (first download the CLIP weights):

tune download openai/clip-vit-large-patch14 --output-dir /tmp/clip-vit-large-patch14 --ignore-patterns None
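The original snippet is not reproduced in this scrape; below is only a rough sketch of what such an e2e run could look like. The builder names clip_tokenizer and clip_text_encoder_large are assumptions for illustration, not the PR's confirmed API:

```python
import torch
from torchtune.models.clip import clip_tokenizer, clip_text_encoder_large  # names are assumptions

tokenizer = clip_tokenizer("/tmp/clip-vit-large-patch14")  # hypothetical entry point and path layout
encoder = clip_text_encoder_large()  # hypothetical builder; checkpoint loading omitted

tokens = torch.tensor([tokenizer.encode("a photo of a cat")])
with torch.no_grad():
    embedding = encoder(tokens)  # expected shape: [batch, embed_dim]
```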
Checked parity with the HF CLIP tokenizer and text encoder: MSE between the encoder outputs on a batch of 32 test strings = 3.55e-5.
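For reference, the Hugging Face side of such a parity check might look like this (real transformers API; the final comparison line assumes a torchtune_out tensor produced by this PR's encoder):

```python
import torch
from transformers import CLIPTokenizer, CLIPTextModel

hf_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
hf_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").eval()

texts = ["a photo of a cat"] * 32  # stand-in for the 32 test strings
batch = hf_tokenizer(texts, padding="max_length", max_length=77, return_tensors="pt")
with torch.no_grad():
    hf_out = hf_model(**batch).last_hidden_state  # [32, 77, 768]

# mse = torch.mean((torchtune_out - hf_out) ** 2)  # torchtune_out: output of this PR's encoder
```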
- Tokenization speed for 32 test strings:
- Tokenization speed for an entire 100k img gen prompt dataset:
- Encoding speed for a single batch of 32 test strings:
- Encoding speed for 1000 batches of 32 test strings:
Checklist
- Run pre-commit hooks and linters (install them first via pre-commit install)
- Run unit tests via pytest tests
- Run recipe tests via pytest tests -m integration_test
UX
If your function changed a public API, please add a dummy example of what the user experience will look like when calling it. Here is a docstring example and a tutorial example.