Adding MM eval tests / attention bugfixes #1989
Conversation
Helpful links: see artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1989. Note: links to docs will display an error until the docs builds have been completed.
✅ No failures as of commit d5c3ece with merge base e1caa9f. This comment was automatically generated by Dr. CI and updates every 15 minutes.
Codecov Report
Attention: Patch coverage is …
Additional details and impacted files:
@@            Coverage Diff             @@
##             main    #1989      +/-   ##
==========================================
- Coverage   67.61%   67.51%   -0.11%
==========================================
  Files         318      318
  Lines       17597    17674      +77
==========================================
+ Hits        11899    11933      +34
- Misses       5698     5741      +43
☔ View full report in Codecov by Sentry.
torchtune/modules/attention.py (outdated)
    # If needed, expand the key and value tensors to have the same shape
    # as the query tensor by copying values across the relevant dim
    # k,v shape: [b, n_kv, s, h_d] -> [b, n_h, s, h_d]
    if self.num_heads != self.num_kv_heads:
Is there a test that can catch the error you noticed before you made this change?
The tests in this PR will catch it (it's how I found it), since we perform generation with a vision model with KV-caches enabled. The most atomic test that should catch it is setting up a cross-attention layer with KV-caches and kv_heads < num_heads, somewhere like here.
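A minimal sketch of what that atomic test could look like. This assumes the `MultiHeadAttention` constructor arguments and a `setup_cache(batch_size, dtype, max_seq_len)` signature as written below; treat the exact names, the forward-call pattern, and the shapes as assumptions rather than the final test:

```python
import torch
from torch import nn
from torchtune.modules import MultiHeadAttention


def test_cross_attention_kv_cache_with_fewer_kv_heads():
    # num_kv_heads < num_heads exercises the GQA expansion path under discussion.
    embed_dim, num_heads, num_kv_heads, head_dim = 64, 8, 2, 8
    attn = MultiHeadAttention(
        embed_dim=embed_dim,
        num_heads=num_heads,
        num_kv_heads=num_kv_heads,
        head_dim=head_dim,
        q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False),
        k_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
        v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
        output_proj=nn.Linear(num_heads * head_dim, embed_dim, bias=False),
        is_causal=False,  # cross-attention layer
        max_seq_len=16,
    )
    # Assumed signature: sets up the KV-cache and enables caching.
    attn.setup_cache(batch_size=2, dtype=torch.float32, max_seq_len=16)

    x = torch.randn(2, 4, embed_dim)   # decoder (query) tokens
    y = torch.randn(2, 16, embed_dim)  # encoder (key/value) tokens

    out_first = attn(x, y)   # populates the encoder KV-cache
    out_cached = attn(x)     # must read k/v back from the cache without error
    assert out_first.shape == out_cached.shape == (2, 4, embed_dim)
```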
torchtune/modules/attention.py (outdated)
    # as the query tensor by copying values across the relevant dim
    # k,v shape: [b, n_kv, s, h_d] -> [b, n_h, s, h_d]
    if self.num_heads != self.num_kv_heads:
        expand_shape = (-1, -1, q_per_kv, -1, -1)
Any chance you can use variable names here instead of -1, for explicitness?
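One way to make that concrete (a sketch only; the variable names follow the shape comment above, and the values should be read off `k`'s actual shape at this point, which with KV-caches enabled is the cache length rather than the current input length):

```python
# k, v: [b, n_kv, s, h_d] -> [b, n_h, s, h_d], where n_h = n_kv * q_per_kv
b, n_kv, s, h_d = k.shape
q_per_kv = self.num_heads // self.num_kv_heads
expand_shape = (b, n_kv, q_per_kv, s, h_d)
k = k.unsqueeze(2).expand(expand_shape).flatten(1, 2)
v = v.unsqueeze(2).expand(expand_shape).flatten(1, 2)
```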
@@ -781,7 +781,7 @@ def setup_caches(
            isinstance(l, TransformerCrossAttentionLayer) for l in self.modules()
        )
        has_decoder_layers = any(
-           isinstance(l, TransformerSelfAttentionLayer) for l in self.layers
+           isinstance(l, TransformerSelfAttentionLayer) for l in self.modules()
:O
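For anyone else reading: the practical difference is that `self.modules()` recurses into nested submodules, while iterating `self.layers` only visits the top-level entries of that `ModuleList`, so a decoder layer wrapped inside another module (e.g. a fusion layer in the multimodal models) would previously be missed. A standalone PyTorch illustration; the `Wrapper` class is just a stand-in:

```python
import torch.nn as nn


class Wrapper(nn.Module):
    """Stand-in for something like a fusion layer that wraps a decoder layer."""

    def __init__(self, inner: nn.Module):
        super().__init__()
        self.inner = inner


layers = nn.ModuleList([Wrapper(nn.Linear(4, 4)), nn.Linear(4, 4)])

top_level = [type(m).__name__ for m in layers]            # ['Wrapper', 'Linear']
recursive = [type(m).__name__ for m in layers.modules()]  # ['ModuleList', 'Wrapper', 'Linear', 'Linear']
# The nested Linear only shows up when recursing with .modules().
```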
@@ -249,7 +249,7 @@ def forward(
        q = self.q_norm(q)

        if y is None:
-           if self.kv_cache is None:
+           if self.kv_cache is None or not self.cache_enabled:
This should have been added in #1763
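Paraphrasing the guarded branch this change produces (a sketch, not the exact torchtune code; the error type and message and the `k_cache`/`v_cache` attribute names are assumptions):

```python
if y is None:
    # No new encoder output: we can only fall back to previously cached k/v,
    # and only if a cache exists *and* caching is currently enabled.
    if self.kv_cache is None or not self.cache_enabled:
        raise ValueError(
            "No encoder input was provided and there is no enabled KV-cache to read from."
        )
    k = self.kv_cache.k_cache
    v = self.kv_cache.v_cache
else:
    # Fresh encoder output: project k/v, normalize, and (if caching) update the cache.
    ...
```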
    @mps_ignored_test()
    def test_forward(
        self,
        input: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
-       transformer_layer: TransformerSelfAttentionLayer,
+       transformer_layer: TransformerCrossAttentionLayer,
idek why we're typing tests
Looks good to go! Thanks for moving the k norm back to the original place.
Context
What is the purpose of this PR? Is it to add a new feature, fix a bug, or update tests and/or documentation?
Please link to any issues this PR addresses.
Changelog
What are the changes made in this PR?
- Fixed a bug where `k_norm` was applied after the KV-cache update, rather than before, which caused issues when retrieving encoder KV-caches which were not normalized (see the sketch after this list).
- Check `cache_enabled` when attempting to retrieve KV-caches for encoder inputs.
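Schematically, the `k_norm` fix in the first item amounts to the following reordering (pseudocode in torchtune's style; `self.kv_cache.update` and the surrounding names are assumptions, not a quote of the actual diff):

```python
# Before (buggy): the cache stored un-normalized keys, so later reads of the
# encoder KV-cache returned keys that had never gone through k_norm.
#   k, v = self.kv_cache.update(k, v)
#   k = self.k_norm(k)

# After (fixed): normalize first, then cache, so cached encoder keys are
# already normalized when they are read back on subsequent decoding steps.
if self.k_norm is not None:
    k = self.k_norm(k)
if self.kv_cache is not None and self.cache_enabled:
    k, v = self.kv_cache.update(k, v)
```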
Test plan
Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these, just ask and we will happily help. We also have a contributing page for some guidance on contributing.
- run pre-commit hooks and linters (make sure you've first installed via `pre-commit install`)
- run unit tests via `pytest tests`
- run integration tests via `pytest tests -m integration_test`
UX
If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example and a tutorial example.