Update KV Cache to use num_kv_heads instead of num_heads (#1961)
mirceamironenco authored Nov 10, 2024 · 1 parent 08efaed · commit e1caa9f
Showing 6 changed files with 28 additions and 43 deletions.
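In short: the KVCache constructor argument num_heads becomes num_kv_heads, and the cache buffers are now allocated at key/value-head width rather than query-head width, since keys and values are cached before being expanded to match the query heads. A minimal sketch of the new call from a caller's point of view (the import path and the 32/8 head split are illustrative assumptions, not taken from this commit):

import torch
from torchtune.modules import KVCache  # assumed import path

# Before: cache sized by query heads, e.g. num_heads=32 -> k_cache of shape (4, 32, 2048, 128).
# After: cache sized by kv heads, so GQA models allocate num_heads / num_kv_heads times less memory.
cache = KVCache(
    batch_size=4,
    max_seq_len=2048,
    num_kv_heads=8,   # e.g. 32 query heads grouped over 8 kv heads
    head_dim=128,
    dtype=torch.bfloat16,
)
print(cache.k_cache.shape)  # torch.Size([4, 8, 2048, 128])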
@@ -256,7 +256,6 @@ def compare_attn(
max_seq_len: int,
use_kv_cache: bool,
):

torch.manual_seed(16)
inputs = torch.randn(4, 2048, 4096)

@@ -269,8 +268,9 @@ def compare_attn(
kv_cache = KVCache(
batch_size=4,
max_seq_len=max_seq_len,
n_kv_heads=num_heads,
num_kv_heads=num_kv_heads,
head_dim=head_dim,
dtype=inputs.dtype,
)
else:
kv_cache = None
@@ -330,7 +330,6 @@ def compare_attn(


if __name__ == "__main__":

# compare mha
mha = {
"num_heads": 32,
@@ -33,7 +33,6 @@ def compare_lora_attention(
lora_rank: int,
lora_alpha: float,
) -> None:

# make sure we have the right seed for generating outputs
# this should match up the seed value set in the corresponding
# unit test
@@ -68,8 +67,9 @@ def compare_lora_attention(
KVCache(
batch_size=batch_size,
max_seq_len=max_seq_len,
n_kv_heads=num_heads,
num_kv_heads=num_kv_heads,
head_dim=head_dim,
dtype=x.dtype,
)
if batch_size is not None
else None
8 changes: 3 additions & 5 deletions tests/torchtune/modules/test_attention.py
@@ -123,7 +123,7 @@ def gqa_kv_cache(
kv_cache = KVCache(
batch_size=4,
max_seq_len=max_seq_len,
num_heads=num_heads,
num_kv_heads=num_kv_heads,
head_dim=head_dim,
dtype=torch.float32,
)
@@ -178,7 +178,7 @@ def mha_kv_cache(
kv_cache = KVCache(
batch_size=4,
max_seq_len=max_seq_len,
num_heads=num_heads,
num_kv_heads=num_kv_heads,
head_dim=head_dim,
dtype=torch.float32,
)
@@ -233,7 +233,7 @@ def mqa_kv_cache(
kv_cache = KVCache(
batch_size=4,
max_seq_len=max_seq_len,
num_heads=num_heads,
num_kv_heads=num_kv_heads,
head_dim=head_dim,
dtype=torch.float32,
)
@@ -267,7 +267,6 @@ def test_forward_gqa(self, input: torch.Tensor, gqa: MultiHeadAttention) -> None
def test_forward_gqa_kv_cache(
self, input: torch.Tensor, gqa_kv_cache: MultiHeadAttention, attn_params_gqa
) -> None:

_, _, _, max_seq_len = attn_params_gqa
_, seq_len, _ = input.shape

@@ -293,7 +292,6 @@ def test_forward_mha(self, input: torch.Tensor, mha: MultiHeadAttention) -> None
def test_forward_mha_kv_cache(
self, input: torch.Tensor, mha_kv_cache: MultiHeadAttention, attn_params_mha
) -> None:

_, _, _, max_seq_len = attn_params_mha
_, seq_len, _ = input.shape

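The gqa_kv_cache, mha_kv_cache, and mqa_kv_cache fixtures above now size their caches by num_kv_heads, which is exactly what distinguishes the three variants: num_kv_heads == num_heads for MHA, 1 < num_kv_heads < num_heads for GQA, and num_kv_heads == 1 for MQA. A rough sketch of the resulting k_cache shapes under assumed head counts (batch_size=4 matches the fixtures; max_seq_len, head_dim, and the head counts are illustrative and may differ from the real test parameters):

import torch
from torchtune.modules import KVCache  # assumed import path

def k_cache_shape(num_kv_heads: int) -> tuple:
    # Allocate a cache the way the fixtures do and report its key-buffer shape.
    cache = KVCache(
        batch_size=4,
        max_seq_len=2048,
        num_kv_heads=num_kv_heads,
        head_dim=128,
        dtype=torch.float32,
    )
    return tuple(cache.k_cache.shape)

print(k_cache_shape(32))  # MHA: (4, 32, 2048, 128) -- unchanged by this commit
print(k_cache_shape(8))   # GQA: (4, 8, 2048, 128)  -- 4x smaller than a query-head-sized cache
print(k_cache_shape(1))   # MQA: (4, 1, 2048, 128)  -- 32x smaller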
6 changes: 3 additions & 3 deletions torchtune/models/gemma2/_attention.py
@@ -149,7 +149,7 @@ def setup_cache(
self.kv_cache = KVCache(
batch_size=batch_size,
max_seq_len=max_seq_len,
num_heads=self.num_heads,
num_kv_heads=self.num_heads,
head_dim=self.head_dim,
dtype=dtype,
)
@@ -211,9 +211,9 @@ def forward(
- h_d: head dim
"""
# until flex attention implementation exists, we do not accept block masks
if (mask is not None) and (type(mask) != torch.Tensor()):
if mask is not None and (not isinstance(mask, torch.Tensor)):
raise NotImplementedError(
"Block masks are not implemeted yet, use packed=False"
"Block masks are not implemeted yet, use packed=False."
)

# x has shape [b, s_x, d]
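Aside from adopting num_kv_heads, this hunk also corrects the mask guard: type(mask) != torch.Tensor() compares mask's class against a freshly constructed empty tensor rather than against the torch.Tensor type, so it never performs the intended check, whereas isinstance(mask, torch.Tensor) does, and also accepts Tensor subclasses. A small standalone illustration of the difference (not part of the commit):

import torch

mask = torch.ones(2, 2, dtype=torch.bool)
print(isinstance(mask, torch.Tensor))   # True -- the intended "is this a plain tensor mask?" test

# Even an exact-type comparison against the class itself would be fragile:
# it rejects Tensor subclasses that isinstance() accepts.
param = torch.nn.Parameter(torch.ones(2, 2))
print(type(param) == torch.Tensor)      # False -- Parameter is a Tensor subclass
print(isinstance(param, torch.Tensor))  # True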
38 changes: 14 additions & 24 deletions torchtune/modules/attention.py
@@ -164,7 +164,7 @@ def setup_cache(
self.kv_cache = KVCache(
batch_size=batch_size,
max_seq_len=max_seq_len,
num_heads=self.num_heads,
num_kv_heads=self.num_kv_heads,
head_dim=self.head_dim,
dtype=dtype,
)
@@ -258,47 +258,37 @@ def forward(
else:
# Update k and v shape, positional embeddings, and normalization

# k has shape [b, s_y, num_kv_heads * head_dim]
# v has shape [b, s_y, num_kv_heads * head_dim]
# k,v shape [b, s_y, num_kv_heads * head_dim]
k = self.k_proj(y)
v = self.v_proj(y)

# Apply positional embeddings
# k: [b, s_y, n_kv, h_d]
# k,v shape: [b, s_y, n_kv, h_d]
k = k.view(b, s_y, -1, self.head_dim)
v = v.view(b, s_y, -1, self.head_dim)
if self.pos_embeddings is not None:
k = self.pos_embeddings(k, input_pos=input_pos)

# View + expand + reshape bring num_kv_heads to num_heads for k and v
# to match q.
# k,v shape: [b, n_kv, s_y, h_d]
k = k.transpose(1, 2)
v = v.transpose(1, 2)

# k: [b, s_y, n_kv, 1, h_d]
# v: [b, s_y, n_kv, 1, h_d]
k = k.view(b, s_y, self.num_kv_heads, 1, self.head_dim)
v = v.view(b, s_y, self.num_kv_heads, 1, self.head_dim)
# Update key-value cache
if self.kv_cache is not None and self.cache_enabled:
k, v = self.kv_cache.update(k, v)

# If needed, expand the key and value tensors to have the same shape
# as the query tensor by copying values across the relevant dim
# k,v shape: [b, n_h, s, h_d]
if self.num_heads != self.num_kv_heads:
k = k.expand(b, s_y, self.num_kv_heads, q_per_kv, self.head_dim)
v = v.expand(b, s_y, self.num_kv_heads, q_per_kv, self.head_dim)

# [b, s, n_h, h_d]
k = k.reshape(b, s_y, -1, self.head_dim)
v = v.reshape(b, s_y, -1, self.head_dim)

# [b, n_h, s, h_d]
k = k.transpose(1, 2)
v = v.transpose(1, 2)
expand_shape = (-1, -1, q_per_kv, -1, -1)
k = k.unsqueeze(2).expand(expand_shape).flatten(1, 2)
v = v.unsqueeze(2).expand(expand_shape).flatten(1, 2)

# Normalize k
if self.k_norm is not None:
k = self.k_norm(k)

# Update key-value cache
if self.kv_cache is not None and self.cache_enabled:
k, v = self.kv_cache.update(k, v)

output = self._attention_call(
q,
k,
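The restructured forward path above keeps k and v at num_kv_heads width through the projections, positional embeddings, and the cache update, and only afterwards (when num_heads != num_kv_heads) broadcasts them up to the query head count with unsqueeze + expand + flatten. A shape-only sketch of that expansion step, with names mirroring the diff and illustrative sizes:

import torch

b, s, h_d = 2, 16, 64               # batch, sequence length, head dim (illustrative)
num_heads, num_kv_heads = 32, 8
q_per_kv = num_heads // num_kv_heads

# k as it comes out of the cache update: [b, n_kv, s, h_d]
k = torch.randn(b, num_kv_heads, s, h_d)

# [b, n_kv, s, h_d] -> [b, n_kv, 1, s, h_d] -> [b, n_kv, q_per_kv, s, h_d] -> [b, n_h, s, h_d]
expand_shape = (-1, -1, q_per_kv, -1, -1)
k_expanded = k.unsqueeze(2).expand(expand_shape).flatten(1, 2)

print(k_expanded.shape)  # torch.Size([2, 32, 16, 64])
# Every group of q_per_kv query heads sees the same kv head:
print(torch.equal(k_expanded[:, 0], k_expanded[:, q_per_kv - 1]))  # True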
10 changes: 4 additions & 6 deletions torchtune/modules/kv_cache.py
@@ -17,9 +17,7 @@ class KVCache(nn.Module):
Args:
batch_size (int): batch size model will be run with
max_seq_len (int): maximum sequence length model will be run with
num_heads (int): number of heads. We take num_heads instead of num_kv_heads because
the cache is created after we've expanded the key and value tensors to have the
same shape as the query tensor. See attention.py for more details
num_kv_heads (int): number of key/value heads.
head_dim (int): per-attention head embedding dimension
dtype (torch.dtype): dtype for the caches
"""
@@ -28,12 +26,12 @@ def __init__(
self,
batch_size: int,
max_seq_len: int,
num_heads: int,
num_kv_heads: int,
head_dim: int,
dtype: torch.dtype,
) -> None:
super().__init__()
cache_shape = (batch_size, num_heads, max_seq_len, head_dim)
cache_shape = (batch_size, num_kv_heads, max_seq_len, head_dim)
self.register_buffer(
"k_cache", torch.zeros(cache_shape, dtype=dtype), persistent=False
)
@@ -66,7 +64,7 @@ def update(
already been filled, use ``.reset()``, which will reset the cache to the zero-th position.
Example:
>>> cache = KVCache(batch_size=2, max_seq_len=16, num_heads=4, head_dim=32, dtype=torch.bfloat16)
>>> cache = KVCache(batch_size=2, max_seq_len=16, num_kv_heads=4, head_dim=32, dtype=torch.bfloat16)
>>> keys, values = torch.ones((2, 4, 8, 32)), torch.ones((2, 4, 8, 32))
>>> cache.update(keys, values)
>>> # now positions 0 through 7 are filled
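With the cache allocated at num_kv_heads width, per-layer cache memory shrinks by a factor of num_heads / num_kv_heads for GQA and MQA models. A back-of-the-envelope sketch under assumed Llama-style settings (32 query heads, 8 kv heads, head_dim 128, bf16; none of these values come from the commit):

batch_size, max_seq_len, head_dim = 4, 2048, 128
num_heads, num_kv_heads = 32, 8
bytes_per_elem = 2  # bf16

def cache_bytes(heads: int) -> int:
    # k_cache + v_cache, each of shape (batch_size, heads, max_seq_len, head_dim)
    return 2 * batch_size * heads * max_seq_len * head_dim * bytes_per_elem

print(cache_bytes(num_heads) // 2**20)     # 128 MiB per layer when sized by query heads
print(cache_bytes(num_kv_heads) // 2**20)  # 32 MiB per layer when sized by kv heads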
