Update KV Cache to use num_kv_heads instead of num_heads #1961
Conversation
Dr. CI: ✅ No failures as of commit 5f0d2b7 with merge base 08efaed. See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1961. Note: links to docs will display an error until the docs builds have been completed.
Force-pushed from 88334e6 to c06bd7f.
```python
# [b, n_h, s, h_d]
k = k.transpose(1, 2)
v = v.transpose(1, 2)
expand_shape = (-1, -1, q_per_kv, -1, -1)
```
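The transpose/expand pattern in the diff above can be sketched end-to-end. This is a minimal, self-contained illustration of grouped-query attention (GQA) KV expansion, not the exact torchtune implementation; the shapes and head counts are hypothetical, and `q_per_kv` follows the diff's naming:

```python
import torch

# Hypothetical GQA shapes: queries have num_heads attention heads,
# while keys/values are cached with only num_kv_heads heads.
b, s, h_d = 2, 16, 64                  # batch, sequence length, head dim
num_heads, num_kv_heads = 32, 8
q_per_kv = num_heads // num_kv_heads   # query heads sharing each kv head

k = torch.randn(b, s, num_kv_heads, h_d)  # [b, s, n_kv, h_d]
k = k.transpose(1, 2)                     # [b, n_kv, s, h_d]

# Expand each kv head across its group of query heads, then flatten
# the (n_kv, q_per_kv) pair back into a single head dimension.
expand_shape = (-1, -1, q_per_kv, -1, -1)
k = k.unsqueeze(2).expand(expand_shape)   # [b, n_kv, q_per_kv, s, h_d]
k = k.reshape(b, num_heads, s, h_d)       # [b, n_h, s, h_d]

print(k.shape)  # torch.Size([2, 32, 16, 64])
```

Because `expand` creates stride-0 views, the per-kv-head data is shared rather than copied until the final `reshape` materializes it.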
I like this
Addressed these and also updated the generate benchmark for single device.
Codecov Report: All modified and coverable lines are covered by tests ✅

```
@@            Coverage Diff             @@
##             main    #1961      +/-   ##
==========================================
- Coverage   68.40%   67.26%   -1.14%
==========================================
  Files         311      316       +5
  Lines       16973    17342     +369
==========================================
+ Hits        11610    11665      +55
- Misses       5363     5677     +314
==========================================
```
Force-pushed from 65f7498 to 5f0d2b7.
Thanks so much for this, and for your patience in helping test it.
Context
What is the purpose of this PR?
Please link to any issues this PR addresses.
Changelog
What are the changes made in this PR?
Test plan
Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.
- run pre-commit hooks (`pre-commit install`)
- run unit tests (`pytest tests`)
- run recipe tests (`pytest tests -m integration_test`)
UX
If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example and a tutorial example.
If one changes `generation.yaml` to use Llama 3.1 8B (`num_kv_heads < num_heads`) and adjusts the prompt/`max_new_toks`, this is what I get with the new KV cache (on an RTX A6000):

Old KV cache:

I'm not sure how relevant the tok/sec change is, given that this is not compiled and `batch_size=1`.
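The main win from sizing the cache with `num_kv_heads` is memory, not necessarily tok/sec. A minimal sketch of the saving, assuming hypothetical Llama-3.1-8B-like head counts (32 query heads, 8 kv heads) and an illustrative cache layout:

```python
import torch

# Hypothetical per-layer KV cache shapes: one tensor each for k and v.
b, max_seq_len, h_d = 1, 4096, 128
num_heads, num_kv_heads = 32, 8

# Cache allocated with the full query head count (old behavior).
old_cache = torch.zeros(b, num_heads, max_seq_len, h_d)
# Cache allocated with only the kv head count (this PR's behavior).
new_cache = torch.zeros(b, num_kv_heads, max_seq_len, h_d)

# The cache shrinks by a factor of num_heads / num_kv_heads.
print(old_cache.numel() // new_cache.numel())  # 4
```

With `num_kv_heads == num_heads` (standard multi-head attention) the two layouts are identical, so the change only affects GQA/MQA models.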