
Adding MM eval tests / attention bugfixes #1989

Merged: 4 commits merged into pytorch:main from the actual_mm_tests branch on Nov 13, 2024

Conversation

@SalmanMohammadi (Collaborator) commented on Nov 12, 2024

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

Please link to any issues this PR addresses.

Changelog

What are the changes made in this PR?

  • Adds Eleuther eval tests for dummy MM models. These tests run on GPU, since they are very slow on CPU.
  • Fixes a bug introduced by #1961 (Update KV Cache to use num_kv_heads instead of num_heads), which missed expanding KV-cache tensors for cross-attention layers.
  • Fixes a bug introduced in the same PR which applied k_norm after the KV-cache update rather than before, so encoder KV-caches were retrieved unnormalized (see the sketch after this list).
  • Adds a unit test which catches both of the above cases.
  • Fixes a bug where caches_enabled was not checked when attempting to retrieve KV-caches for encoder inputs.
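
A minimal sketch of the corrected k_norm ordering referenced above (the module names follow torchtune's attention layers, but this is an illustration, not the exact diff):

```python
# Normalize k *before* the cache update so that cached encoder keys are
# already normalized when they are retrieved on later decoding steps.
k = self.k_proj(y)
v = self.v_proj(y)
if self.k_norm is not None:
    k = self.k_norm(k)                 # normalize first...
if self.kv_cache is not None and self.cache_enabled:
    k, v = self.kv_cache.update(k, v)  # ...then cache the normalized k
```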

Test plan

Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.

  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
  • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

UX

If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example and a tutorial example.

  • I did not change any public API
  • I have added an example to docs or docstrings

pytorch-bot commented Nov 12, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1989

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit d5c3ece with merge base e1caa9f:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the "CLA Signed" label on Nov 12, 2024
@codecov-commenter commented Nov 12, 2024

Codecov Report

Attention: Patch coverage is 37.68116% with 43 lines in your changes missing coverage. Please review.

Project coverage is 67.51%. Comparing base (e1caa9f) to head (b1c283d).
Report is 5 commits behind head on main.

Files with missing lines              Patch %   Missing lines
tests/recipes/test_eleuther_eval.py   37.25%    32 ⚠️
tests/recipes/utils.py                21.42%    11 ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1989      +/-   ##
==========================================
- Coverage   67.61%   67.51%   -0.11%     
==========================================
  Files         318      318              
  Lines       17597    17674      +77     
==========================================
+ Hits        11899    11933      +34     
- Misses       5698     5741      +43     


# If needed, expand the key and value tensors to have the same shape
# as the query tensor by copying values across the relevant dim
# k,v shape: [b, n_kv, s, h_d] -> [b, n_h, s, h_d]
if self.num_heads != self.num_kv_heads:
Contributor:
Is there a test that can catch the error you noticed before you made this change?

@SalmanMohammadi (Author):
The tests in this PR will catch it (it's how I found it), since we perform generation with a vision model with KV-caches enabled. The most atomic test that would catch it is setting up a cross-attention layer with KV-caches and kv_heads < num_heads, somewhere like here.
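
A rough sketch of what such an atomic test could look like (the constructor and cache-setup arguments here are assumptions for illustration, not the exact test added in this PR):

```python
import torch
from torchtune.modules import MultiHeadAttention

# Hypothetical test sketch: kv_heads < num_heads forces the cached k/v
# (stored with num_kv_heads) to be expanded on retrieval.
def test_cross_attention_cache_with_fewer_kv_heads():
    embed_dim, num_heads, num_kv_heads, head_dim = 64, 8, 2, 8
    attn = MultiHeadAttention(
        embed_dim=embed_dim,
        num_heads=num_heads,
        num_kv_heads=num_kv_heads,  # fewer KV heads than query heads (GQA)
        head_dim=head_dim,
        q_proj=torch.nn.Linear(embed_dim, num_heads * head_dim, bias=False),
        k_proj=torch.nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
        v_proj=torch.nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
        output_proj=torch.nn.Linear(embed_dim, embed_dim, bias=False),
        is_causal=False,
    )
    attn.setup_cache(batch_size=2, dtype=torch.float32, max_seq_len=6)
    x = torch.randn(2, 4, embed_dim)  # decoder hidden states
    y = torch.randn(2, 6, embed_dim)  # encoder output, cached on first call
    out_first = attn(x, y)            # populates the KV-cache
    out_cached = attn(x, None)        # must read and expand the cached k/v
    torch.testing.assert_close(out_first, out_cached)
```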

    # as the query tensor by copying values across the relevant dim
    # k,v shape: [b, n_kv, s, h_d] -> [b, n_h, s, h_d]
    if self.num_heads != self.num_kv_heads:
        expand_shape = (-1, -1, q_per_kv, -1, -1)
Contributor:
Any chance you can use variable names here instead of -1, for explicitness?
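
For instance, something like this (a sketch of the suggestion; the unpacked variable names are hypothetical, with the shape layout taken from the comment in the diff above):

```python
# After k.unsqueeze(2), k has shape [b, n_kv, 1, s, h_d]; expanding the
# singleton dim shares each KV head across its group of query heads.
b, n_kv, _, s, h_d = k.shape
expand_shape = (b, n_kv, q_per_kv, s, h_d)
k = k.expand(expand_shape)
v = v.expand(expand_shape)
```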

@@ -781,7 +781,7 @@ def setup_caches(
         isinstance(l, TransformerCrossAttentionLayer) for l in self.modules()
     )
     has_decoder_layers = any(
-        isinstance(l, TransformerSelfAttentionLayer) for l in self.layers
+        isinstance(l, TransformerSelfAttentionLayer) for l in self.modules()
Contributor:
:O
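
The bug in miniature: iterating self.layers only yields the list's direct children, while self.modules() recurses into nested modules, so a self-attention layer wrapped inside another module (e.g. a fusion layer) was missed by the old check. A standalone sketch of the difference:

```python
import torch.nn as nn

class Wrapper(nn.Module):
    # Stands in for e.g. a fusion layer that wraps an inner decoder layer.
    def __init__(self, inner: nn.Module):
        super().__init__()
        self.inner = inner

layers = nn.ModuleList([Wrapper(nn.Linear(2, 2))])

print(any(isinstance(l, nn.Linear) for l in layers))            # False: sees only Wrapper
print(any(isinstance(m, nn.Linear) for m in layers.modules()))  # True: recurses to Linear
```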

@@ -249,7 +249,7 @@ def forward(
             q = self.q_norm(q)

         if y is None:
-            if self.kv_cache is None:
+            if self.kv_cache is None or not self.cache_enabled:
@SalmanMohammadi (Author):
This should have been added in #1763
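
In context, the guard roughly means the following when no encoder input is passed (a sketch; the error message and cache attribute names are assumptions for illustration):

```python
if y is None:
    # With no encoder input, we can only cross-attend to cached k/v, which
    # requires a cache that both exists and is currently enabled.
    if self.kv_cache is None or not self.cache_enabled:
        raise ValueError("No encoder input and no enabled KV-cache to read from")
    k = self.kv_cache.k_cache
    v = self.kv_cache.v_cache
```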

     @mps_ignored_test()
     def test_forward(
         self,
         input: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
-        transformer_layer: TransformerSelfAttentionLayer,
+        transformer_layer: TransformerCrossAttentionLayer,
@SalmanMohammadi (Author):
idek why we're typing tests

@SalmanMohammadi changed the title from "Adding MM eval tests" to "Adding MM eval tests / attention bugfixes" on Nov 12, 2024
@RdoubleA (Contributor) left a comment:
Looks good to go! Thanks for moving the k_norm back to its original place.

@SalmanMohammadi merged commit 18d97f0 into pytorch:main on Nov 13, 2024
17 checks passed
@SalmanMohammadi deleted the actual_mm_tests branch on November 13, 2024 at 08:48
@SalmanMohammadi mentioned this pull request on Nov 14, 2024
@ebsmothers mentioned this pull request on Nov 26, 2024