Fixing recompiles in KV-cache + compile (#1663)
SalmanMohammadi authored Oct 2, 2024
1 parent fc0249d commit bae4b27
Showing 2 changed files with 50 additions and 16 deletions.
32 changes: 29 additions & 3 deletions tests/torchtune/modules/test_kv_cache.py
@@ -6,6 +6,7 @@

import pytest
import torch
import torch._dynamo.testing
from torchtune.modules import KVCache

BSZ = 2
@@ -52,6 +53,7 @@ def test_kv_cache_init(self, kv_cache):
def test_kv_cache_reset(self, kv_cache, k_vals_full, v_vals_full):
kv_cache.update(k_vals_full, v_vals_full)
kv_cache.reset()

assert (kv_cache.k_cache == 0).all() and (kv_cache.v_cache == 0).all()
assert kv_cache.size == 0

@@ -62,7 +64,7 @@ def test_kv_cache_error_when_bsz_exceeded(self, kv_cache, k_vals_full, v_vals_fu
def test_kv_cache_error_when_seq_len_exceeded(
self, kv_cache, k_vals_full, v_vals_full
):
-        with pytest.raises(ValueError):
+        with pytest.raises(AssertionError):
kv_cache.update(k_vals_full.repeat(1, 1, 4, 1), v_vals_full)

def test_kv_cache_error_when_seq_len_exceeded_after_update(
@@ -75,8 +77,7 @@ def test_kv_cache_error_when_seq_len_exceeded_after_update(
v_vals_full[:, :, : (MAX_SEQ_LEN // 2)],
)
with pytest.raises(
-            ValueError,
-            match=f"cache has reached a sequence length of {MAX_SEQ_LEN + MAX_SEQ_LEN // 2}",
+            AssertionError,
):
# now an invalid update exceeding the cache
kv_cache.update(k_vals_full, v_vals_full)
@@ -151,3 +152,28 @@ def test_kv_cache_multiple_updates(self, kv_cache, k_vals_full, v_vals_full):

assert torch.equal(expected_k_out, k_out)
assert torch.equal(expected_v_out, v_out)

    def test_kv_cache_no_recompiles(self, kv_cache, k_vals_full, v_vals_full):
        def fn(k_val, v_val):
            return kv_cache.update(k_val, v_val)

        cnts = torch._dynamo.testing.CompileCounter()
        # this effectively does torch.compile(fn)
        fn = torch._dynamo.optimize(cnts, nopython=True)(fn)

        # make an update filling half the cache - like a prefill
        # fills position 0 through to (MAX_SEQ_LEN // 2) - 1
        kv_cache.update(
            k_vals_full[:, :, : (MAX_SEQ_LEN // 2)],
            v_vals_full[:, :, : (MAX_SEQ_LEN // 2)],
        )

        # now make successive updates for one token position at a time
        # and ensure there are no recompiles
        for i in range(MAX_SEQ_LEN // 2):
            fn(
                k_vals_full[:, :, (MAX_SEQ_LEN // 2) + i].unsqueeze(2),
                v_vals_full[:, :, (MAX_SEQ_LEN // 2) + i].unsqueeze(2),
            )

        assert cnts.frame_count == 1
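
(For readers less familiar with the dynamo testing utilities: torch._dynamo.optimize(cnts, nopython=True)(fn) is roughly the older spelling of torch.compile(fn, backend=cnts, fullgraph=True). The CompileCounter backend counts how many graphs dynamo compiles, so the final assertion checks that the single-token updates after the prefill never trigger a recompile.)
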
34 changes: 21 additions & 13 deletions torchtune/modules/kv_cache.py
@@ -40,14 +40,20 @@ def __init__(
self.register_buffer(
"v_cache", torch.zeros(cache_shape, dtype=dtype), persistent=False
)
-        self.size = 0
+        self.register_buffer(
+            "cache_pos", torch.arange(0, cache_shape[2]), persistent=False
+        )
self.batch_size = batch_size

def reset(self) -> None:
"""Reset the cache to zero."""
self.k_cache.zero_()
self.v_cache.zero_()
-        self.size = 0
+        self.cache_pos -= self.size

+    @property
+    def size(self) -> int:
+        return self.cache_pos[0].item()

def update(
self, k_val: torch.Tensor, v_val: torch.Tensor
@@ -80,7 +86,7 @@ def update(
Tuple[torch.Tensor, torch.Tensor]: Updated key and value cache tensors, respectively.
Raises:
-            ValueError: if the sequence length of ``k_val`` is longer than the maximum cache sequence length.
+            AssertionError: if the sequence length of ``k_val`` is longer than the maximum cache sequence length.
ValueError: if the batch size of the new key (or value) tensor is greater than the batch size
used during cache setup.
"""
@@ -91,18 +97,20 @@
f", but found new key tensors with batch size {k_val.shape[0]}!"
)

-        if (self.size + seq_len) > self.k_cache.shape[2]:
-            raise ValueError(
-                f"The current cache has been setup with a sequence length of {self.k_cache.shape[2]}"
-                f", but the cache has reached a sequence length of {(self.size + seq_len)}!"
-            )
-        cache_pos = torch.arange(self.size, self.size + seq_len, device=k_val.device)
-        self.size += seq_len

+        assert (self.cache_pos[0] + seq_len) <= self.k_cache.shape[2]
k_out = self.k_cache
v_out = self.v_cache

-        k_out[:, :, cache_pos] = k_val
-        v_out[:, :, cache_pos] = v_val
+        k_out[:, :, self.cache_pos[:seq_len]] = k_val
+        v_out[:, :, self.cache_pos[:seq_len]] = v_val

+        # forward cache_pos seq_len positions along
+        # cache_pos starts at (0, 1, 2, 3, 4, 5, ...)
+        # an update of seq_len = 5 tokens brings it to
+        # (5, 6, 7, 8, 9, ...)
+        # this allows us to track the current position in the cache
+        # after the last update in a compile-friendly way without any dynamism
+        # e.g. relying on an int size tracker, or re-creating cache_pos every time
+        self.cache_pos += seq_len

return k_out, v_out
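
To make the new indexing concrete, here is a small standalone sketch (not part of the commit) of the cache_pos mechanics. TinyCache and its 1-D cache are made up for illustration; the real KVCache buffers have shape [batch_size, num_heads, max_seq_len, head_dim].

import torch

class TinyCache:
    """Simplified 1-D stand-in for the KVCache above."""

    def __init__(self, max_seq_len: int = 8):
        self.cache = torch.zeros(max_seq_len)
        self.cache_pos = torch.arange(0, max_seq_len)

    def update(self, vals: torch.Tensor) -> torch.Tensor:
        seq_len = vals.shape[0]
        # the first seq_len entries of cache_pos are the positions to write into
        self.cache[self.cache_pos[:seq_len]] = vals
        # advance every entry by seq_len; cache_pos[0] then equals the fill level
        self.cache_pos += seq_len
        return self.cache

tc = TinyCache()
tc.update(torch.tensor([1.0, 2.0, 3.0, 4.0]))  # "prefill" of four tokens
tc.update(torch.tensor([5.0]))                 # one-token decode step
print(tc.cache)         # tensor([1., 2., 3., 4., 5., 0., 0., 0.])
print(tc.cache_pos[0])  # tensor(5) -- what the new size property reports

Because cache_pos lives in a buffer and is only ever sliced and advanced by tensor ops, the cache's fill level never flows back into Python control flow inside update, which is what makes it compile-friendly.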

4 comments on commit bae4b27

@psoulos (Contributor) commented on bae4b27 Oct 9, 2024

Hi @SalmanMohammadi, do you have any suggestions for where I can learn more about why this PR fixes recompiles? My guess is that calling cache_pos = torch.arange(self.size, self.size + seq_len, device=k_val.device) moves a new tensor to the device which disrupts compilation, but I would like to learn more!

@SalmanMohammadi (Collaborator, Author) commented:

Hey @psoulos! Did you check out the PR description? There's a high-level summary of the issue and solution there: #1663.

@psoulos (Contributor) commented:

Thank you @SalmanMohammadi! The PR description and discussion are very helpful. I have a more specific question about this line and why it causes an issue:

cache_pos = torch.arange(self.size, self.size + seq_len, device=k_val.device)

Between the first pass, where you cache the prompt, and the second pass, where you cache the first generated token, there is a size mismatch, so that recompilation is expected. But why is there a recompilation between the second and third passes? The arange changes by an index of 1, but in both cases the resulting cache_pos should have a tensor size of [1].

Thanks for helping me understand compilation better!

@SalmanMohammadi (Collaborator, Author) commented on bae4b27 Oct 13, 2024

Excellent question. Feel free to ping me on the torchtune Discord btw if you'd ever like to chat about this more!

I'll try to explain as best I can - most of my compile knowledge is pretty hacky and I kind of stumbled upon this solution.

The arange changes by an index of 1, but in both cases the resulting cache_pos should have a tensor size of [1].

So you're absolutely right here and this was my intuition as well. The part of the torch.compile stack that is responsible for tracing the code and generating intermediate graph representations is torch dynamo. Now, dynamo will use FakeTensors to trace the graph and guard on properties of the tensor such as size, dtype, etc. You're right about the first recompilation where dynamo will replace the size of the cache with a SymInt.

The next recompiles happen not on the tensor, but on KVCache.size, the property of the object itself.

V0924 15:11:42.775000 98048 torch/_dynamo/guards.py:2830] [0/1] [__recompiles]     - 0/0: L['args'][1]._modules['layers']._modules['0']._modules['attn']._modules['kv_cache'].size == 4
tensor(-0.0110)
V0924 15:11:45.333000 98048 torch/_dynamo/guards.py:2830] [0/2] [__recompiles]     triggered by the following guard failure(s):
V0924 15:11:45.333000 98048 torch/_dynamo/guards.py:2830] [0/2] [__recompiles]     - 0/1: L['args'][1]._modules['layers']._modules['0']._modules['attn']._modules['kv_cache'].size == 5
V0924 15:11:45.333000 98048 torch/_dynamo/guards.py:2830] [0/2] [__recompiles]     - 0/0: L['args'][1]._modules['layers']._modules['0']._modules['attn']._modules['kv_cache'].size == 4

You can see that we're actually setting up guards on each specific KVCache.size, not on the cache_pos tensor. I think what's going on here is that dynamo needs to resolve KVCache.size to actually infer the size of cache_pos, and generally torch compile isn't great at any kind of data-dependent dynamism - it seems to set up guards on every concrete value of self.size and re-generate graphs each time.

EDIT: The above is specifically related to tensors with a size of 1, I think (see https://docs.google.com/document/d/1y5CRfMLdwEoF1nTk9q8qEu1mgMUuUtvhklPKJ2emLU8/edit?pli=1&tab=t.0#heading=h.usdkciy5xgk0)

Sorry that my answer isn't super put-together (and probably not even accurate in many places!), but this is my line of thought here.
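
If it helps, here is a rough, self-contained repro sketch (not from this PR; the class names are made up) that contrasts the two position-tracking strategies using torch._dynamo.testing.CompileCounter. The exact recompile counts depend on the PyTorch version, but the int-tracked cache should report more than one compiled frame while the buffer-tracked cache should report exactly one.

import torch
import torch._dynamo.testing


class IntPositionCache(torch.nn.Module):
    """Old-style tracking: a plain Python int is read and mutated inside the compiled region."""

    def __init__(self, max_len: int = 32):
        super().__init__()
        self.register_buffer("cache", torch.zeros(max_len))
        self.size = 0

    def forward(self, val: torch.Tensor) -> torch.Tensor:
        # the write positions are rebuilt from a Python int on every call
        pos = torch.arange(self.size, self.size + val.shape[0])
        self.cache[pos] = val
        self.size += val.shape[0]
        return self.cache


class TensorPositionCache(torch.nn.Module):
    """New-style tracking: positions live in a buffer, so no Python int is guarded on."""

    def __init__(self, max_len: int = 32):
        super().__init__()
        self.register_buffer("cache", torch.zeros(max_len))
        self.register_buffer("cache_pos", torch.arange(0, max_len))

    def forward(self, val: torch.Tensor) -> torch.Tensor:
        seq_len = val.shape[0]
        self.cache[self.cache_pos[:seq_len]] = val
        self.cache_pos += seq_len
        return self.cache


for cache in (IntPositionCache(), TensorPositionCache()):
    cnts = torch._dynamo.testing.CompileCounter()
    compiled = torch.compile(cache, backend=cnts)
    for _ in range(8):
        compiled(torch.randn(1))  # one-token updates, like decoding
    # the int-tracked cache typically recompiles as `size` changes between calls,
    # while the buffer-tracked cache should stay at a single compiled frame
    print(type(cache).__name__, "frame_count =", cnts.frame_count)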
