Fixing recompiles in KV-cache + compile (#1663)
1 parent fc0249d, commit bae4b27
Showing 2 changed files with 50 additions and 16 deletions.
@psoulos commented on bae4b27:
Hi @SalmanMohammadi, do you have any suggestions for where I can learn more about why this PR fixes recompiles? My guess is that calling `cache_pos = torch.arange(self.size, self.size + seq_len, device=k_val.device)` moves a new tensor to the device, which disrupts compilation, but I would like to learn more!
@SalmanMohammadi commented on bae4b27:
Hey @psoulos! Did you check out the PR description? There's a high-level summary of the issue and the solution there: #1663.
@psoulos commented on bae4b27:
Thank you @SalmanMohammadi! The PR description and discussion are very helpful. I have a more specific question about this line and why it causes an issue:

`cache_pos = torch.arange(self.size, self.size + seq_len, device=k_val.device)`

On the first pass, where you cache the prompt, and the second pass, where you cache the first generated token, there is a size mismatch and the recompilation is expected. Why is there a recompilation between the second and third passes? The arange changes by an index of 1, but in both cases the resulting `cache_pos` should have a tensor size of [1]. Thanks for helping me understand compilation better!
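In case it helps make the question concrete, here is the kind of stripped-down toy I have in mind - just a hypothetical stand-in for `KVCache.update`, not the real torchtune code - with recompile logging switched on:

```python
import torch


class ToyCache(torch.nn.Module):
    """Toy stand-in for a KV cache, only to watch recompiles across passes."""

    def __init__(self, max_seq_len: int, dim: int):
        super().__init__()
        self.register_buffer("k_cache", torch.zeros(max_seq_len, dim))
        self.size = 0  # current fill level, kept as a plain Python int

    def update(self, k_val: torch.Tensor) -> torch.Tensor:
        seq_len = k_val.shape[0]
        # The line in question: cache positions rebuilt from self.size on every call.
        cache_pos = torch.arange(self.size, self.size + seq_len, device=k_val.device)
        self.k_cache[cache_pos] = k_val
        self.size += seq_len
        return self.k_cache


# Print the failing guard behind each recompile (equivalently: TORCH_LOGS="recompiles").
torch._logging.set_logs(recompiles=True)

cache = ToyCache(max_seq_len=32, dim=4)
update = torch.compile(cache.update)

update(torch.randn(8, 4))  # pass 1: cache the prompt, first compile
update(torch.randn(1, 4))  # pass 2: first generated token, recompile expected (seq_len 8 -> 1)
update(torch.randn(1, 4))  # pass 3: same shapes as pass 2, yet it recompiles again
```

It's that third call I don't follow: `cache_pos` has size [1] on both of the last two passes, yet the guards still fail.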
@SalmanMohammadi commented on bae4b27:
Excellent question. Feel free to ping me on the torchtune Discord btw if you'd ever like to chat about this more!
I'll try to explain as best I can - most of my compile knowledge is pretty hacky and I kind of stumbled upon this solution.
So you're absolutely right here, and this was my intuition as well. The part of the torch.compile stack that is responsible for tracing the code and generating intermediate graph representations is torch dynamo. Now, dynamo will use `FakeTensor`s to trace the graph and guard on properties of the tensor such as size, dtype, etc. You're right about the first recompilation, where dynamo will replace the size of the cache with a `SymInt`.

The next recompiles happen not on the tensor, but on `KVCache.size`, the property of the object itself. You can see that we're actually setting up guards on each specific `KVCache.size`, not on the `cache_pos` tensor. I think what's going on here is that dynamo needs to resolve `KVCache.size` to actually infer the size of `cache_pos`, and generally torch.compile isn't great at any kind of data-dependent dynamism - it seems to set up guards on every concrete value of `self.size` and re-generate graphs each time.

EDIT: The above is specifically related to tensors with a size of 1, I think (see https://docs.google.com/document/d/1y5CRfMLdwEoF1nTk9q8qEu1mgMUuUtvhklPKJ2emLU8/edit?pli=1&tab=t.0#heading=h.usdkciy5xgk0)
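To make that concrete against the toy above - and this is just my mental model of the direction of the fix, not a copy of the actual `KVCache` change, so treat the diff in #1663 as the source of truth - keeping `cache_pos` as a pre-allocated tensor buffer and advancing it in place means dynamo never has a changing Python int to specialize on; it only sees a tensor whose values it doesn't need at trace time:

```python
import torch


class ToyCacheFixed(torch.nn.Module):
    """Toy sketch of the 'no Python int' direction; not the real torchtune KVCache."""

    def __init__(self, max_seq_len: int, dim: int):
        super().__init__()
        self.register_buffer("k_cache", torch.zeros(max_seq_len, dim))
        # All valid positions are allocated up front; the buffer's *values* track
        # where we are, so there is no integer attribute for dynamo to specialize on.
        self.register_buffer("cache_pos", torch.arange(0, max_seq_len))

    def update(self, k_val: torch.Tensor) -> torch.Tensor:
        seq_len = k_val.shape[0]
        self.k_cache[self.cache_pos[:seq_len]] = k_val
        self.cache_pos.add_(seq_len)  # advance in place instead of rebuilding from an int
        return self.k_cache
```

With the same three calls as in the toy repro, I'd expect only the prompt-to-single-token recompile, since the remaining guards are on shapes and dtypes rather than on the concrete value of `self.size`.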
Sorry that my answer isn't super put-together (and probably not even accurate in many places!), but this is my line of thought here.