Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update KV Cache to use num_kv_heads instead of num_heads #1961

Merged
merged 6 commits into from
Nov 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,6 @@ def compare_attn(
max_seq_len: int,
use_kv_cache: bool,
):

torch.manual_seed(16)
inputs = torch.randn(4, 2048, 4096)

Expand All @@ -269,8 +268,9 @@ def compare_attn(
kv_cache = KVCache(
batch_size=4,
max_seq_len=max_seq_len,
n_kv_heads=num_heads,
num_kv_heads=num_kv_heads,
head_dim=head_dim,
dtype=inputs.dtype,
)
else:
kv_cache = None
Expand Down Expand Up @@ -330,7 +330,6 @@ def compare_attn(


if __name__ == "__main__":

# compare mha
mha = {
"num_heads": 32,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ def compare_lora_attention(
lora_rank: int,
lora_alpha: float,
) -> None:

# make sure we have the right seed for generating outputs
# this should match up the seed value set in the corresponding
# unit test
Expand Down Expand Up @@ -68,8 +67,9 @@ def compare_lora_attention(
KVCache(
batch_size=batch_size,
max_seq_len=max_seq_len,
n_kv_heads=num_heads,
num_kv_heads=num_kv_heads,
head_dim=head_dim,
dtype=x.dtype,
)
if batch_size is not None
else None
Expand Down
8 changes: 3 additions & 5 deletions tests/torchtune/modules/test_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def gqa_kv_cache(
kv_cache = KVCache(
batch_size=4,
max_seq_len=max_seq_len,
num_heads=num_heads,
num_kv_heads=num_kv_heads,
head_dim=head_dim,
dtype=torch.float32,
)
Expand Down Expand Up @@ -178,7 +178,7 @@ def mha_kv_cache(
kv_cache = KVCache(
batch_size=4,
max_seq_len=max_seq_len,
num_heads=num_heads,
num_kv_heads=num_kv_heads,
head_dim=head_dim,
dtype=torch.float32,
)
Expand Down Expand Up @@ -233,7 +233,7 @@ def mqa_kv_cache(
kv_cache = KVCache(
batch_size=4,
max_seq_len=max_seq_len,
num_heads=num_heads,
num_kv_heads=num_kv_heads,
head_dim=head_dim,
dtype=torch.float32,
)
Expand Down Expand Up @@ -267,7 +267,6 @@ def test_forward_gqa(self, input: torch.Tensor, gqa: MultiHeadAttention) -> None
def test_forward_gqa_kv_cache(
self, input: torch.Tensor, gqa_kv_cache: MultiHeadAttention, attn_params_gqa
) -> None:

_, _, _, max_seq_len = attn_params_gqa
_, seq_len, _ = input.shape

Expand All @@ -293,7 +292,6 @@ def test_forward_mha(self, input: torch.Tensor, mha: MultiHeadAttention) -> None
def test_forward_mha_kv_cache(
self, input: torch.Tensor, mha_kv_cache: MultiHeadAttention, attn_params_mha
) -> None:

_, _, _, max_seq_len = attn_params_mha
_, seq_len, _ = input.shape

Expand Down
6 changes: 3 additions & 3 deletions torchtune/models/gemma2/_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def setup_cache(
self.kv_cache = KVCache(
batch_size=batch_size,
max_seq_len=max_seq_len,
num_heads=self.num_heads,
num_kv_heads=self.num_heads,
head_dim=self.head_dim,
dtype=dtype,
)
Expand Down Expand Up @@ -211,9 +211,9 @@ def forward(
- h_d: head dim
"""
# until flex attention implementation exists, we do not accept block masks
if (mask is not None) and (type(mask) != torch.Tensor()):
if mask is not None and (not isinstance(mask, torch.Tensor)):
raise NotImplementedError(
"Block masks are not implemeted yet, use packed=False"
"Block masks are not implemeted yet, use packed=False."
)

# x has shape [b, s_x, d]
Expand Down
38 changes: 14 additions & 24 deletions torchtune/modules/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def setup_cache(
self.kv_cache = KVCache(
batch_size=batch_size,
max_seq_len=max_seq_len,
num_heads=self.num_heads,
num_kv_heads=self.num_kv_heads,
head_dim=self.head_dim,
dtype=dtype,
)
Expand Down Expand Up @@ -258,47 +258,37 @@ def forward(
else:
# Update k and v shape, positional embeddings, and normalization

# k has shape [b, s_y, num_kv_heads * head_dim]
# v has shape [b, s_y, num_kv_heads * head_dim]
# k,v shape [b, s_y, num_kv_heads * head_dim]
k = self.k_proj(y)
v = self.v_proj(y)

# Apply positional embeddings
# k: [b, s_y, n_kv, h_d]
# k,v shape: [b, s_y, n_kv, h_d]
k = k.view(b, s_y, -1, self.head_dim)
v = v.view(b, s_y, -1, self.head_dim)
if self.pos_embeddings is not None:
k = self.pos_embeddings(k, input_pos=input_pos)

# View + expand + reshape bring num_kv_heads to num_heads for k and v
# to match q.
# k,v shape: [b, n_kv, s_y, h_d]
k = k.transpose(1, 2)
v = v.transpose(1, 2)

# k: [b, s_y, n_kv, 1, h_d]
# v: [b, s_y, n_kv, 1, h_d]
k = k.view(b, s_y, self.num_kv_heads, 1, self.head_dim)
v = v.view(b, s_y, self.num_kv_heads, 1, self.head_dim)
# Update key-value cache
if self.kv_cache is not None and self.cache_enabled:
k, v = self.kv_cache.update(k, v)

# If needed, expand the key and value tensors to have the same shape
# as the query tensor by copying values across the relevant dim
# k,v shape: [b, n_h, s, h_d]
if self.num_heads != self.num_kv_heads:
k = k.expand(b, s_y, self.num_kv_heads, q_per_kv, self.head_dim)
v = v.expand(b, s_y, self.num_kv_heads, q_per_kv, self.head_dim)

# [b, s, n_h, h_d]
k = k.reshape(b, s_y, -1, self.head_dim)
v = v.reshape(b, s_y, -1, self.head_dim)

# [b, n_h, s, h_d]
k = k.transpose(1, 2)
v = v.transpose(1, 2)
expand_shape = (-1, -1, q_per_kv, -1, -1)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like this

k = k.unsqueeze(2).expand(expand_shape).flatten(1, 2)
v = v.unsqueeze(2).expand(expand_shape).flatten(1, 2)

# Normalize k
if self.k_norm is not None:
k = self.k_norm(k)

# Update key-value cache
if self.kv_cache is not None and self.cache_enabled:
k, v = self.kv_cache.update(k, v)

output = self._attention_call(
q,
k,
Expand Down
10 changes: 4 additions & 6 deletions torchtune/modules/kv_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,7 @@ class KVCache(nn.Module):
Args:
batch_size (int): batch size model will be run with
max_seq_len (int): maximum sequence length model will be run with
num_heads (int): number of heads. We take num_heads instead of num_kv_heads because
the cache is created after we've expanded the key and value tensors to have the
same shape as the query tensor. See attention.py for more details
num_kv_heads (int): number of key/value heads.
head_dim (int): per-attention head embedding dimension
dtype (torch.dtype): dtype for the caches
"""
Expand All @@ -28,12 +26,12 @@ def __init__(
self,
batch_size: int,
max_seq_len: int,
num_heads: int,
num_kv_heads: int,
head_dim: int,
dtype: torch.dtype,
) -> None:
super().__init__()
cache_shape = (batch_size, num_heads, max_seq_len, head_dim)
cache_shape = (batch_size, num_kv_heads, max_seq_len, head_dim)
self.register_buffer(
"k_cache", torch.zeros(cache_shape, dtype=dtype), persistent=False
)
Expand Down Expand Up @@ -66,7 +64,7 @@ def update(
already been filled, use ``.reset()``, which will reset the cache to the zero-th position.

Example:
>>> cache = KVCache(batch_size=2, max_seq_len=16, num_heads=4, head_dim=32, dtype=torch.bfloat16)
>>> cache = KVCache(batch_size=2, max_seq_len=16, num_kv_heads=4, head_dim=32, dtype=torch.bfloat16)
>>> keys, values = torch.ones((2, 4, 8, 32)), torch.ones((2, 4, 8, 32))
>>> cache.update(keys, values)
>>> # now positions 0 through 7 are filled
Expand Down
Loading