
Commit

Fix attention fwd pass shapes
mirceamironenco committed Nov 7, 2024
1 parent 3671260 commit c00aea7
Showing 4 changed files with 29 additions and 27 deletions.
@@ -256,7 +256,6 @@ def compare_attn(
max_seq_len: int,
use_kv_cache: bool,
):

torch.manual_seed(16)
inputs = torch.randn(4, 2048, 4096)

@@ -269,8 +268,9 @@ def compare_attn(
kv_cache = KVCache(
batch_size=4,
max_seq_len=max_seq_len,
n_kv_heads=num_heads,
num_kv_heads=num_kv_heads,
head_dim=head_dim,
dtype=inputs.dtype,
)
else:
kv_cache = None
@@ -330,7 +330,6 @@ def compare_attn(


if __name__ == "__main__":

# compare mha
mha = {
"num_heads": 32,
@@ -33,7 +33,6 @@ def compare_lora_attention(
lora_rank: int,
lora_alpha: float,
) -> None:

# make sure we have the right seed for generating outputs
# this should match up the seed value set in the corresponding
# unit test
@@ -68,8 +67,9 @@ def compare_lora_attention(
KVCache(
batch_size=batch_size,
max_seq_len=max_seq_len,
n_kv_heads=num_heads,
num_kv_heads=num_kv_heads,
head_dim=head_dim,
dtype=x.dtype,
)
if batch_size is not None
else None
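Both compare_* scripts above construct the cache the same way. A minimal sketch of the updated call, assuming the `KVCache` signature shown in this diff; the sizes and the `x` tensor are illustrative only:

```python
import torch

from torchtune.modules.kv_cache import KVCache

# Illustrative sizes only; a GQA config where num_kv_heads < num_heads.
batch_size, max_seq_len, num_kv_heads, head_dim = 4, 2048, 8, 128
x = torch.randn(batch_size, max_seq_len, num_kv_heads * head_dim)

# The keyword is `num_kv_heads` (not `n_kv_heads`), it receives the number of
# KV heads rather than the number of query heads, and the cache dtype is given
# explicitly so the cached k/v tensors match the activations.
kv_cache = KVCache(
    batch_size=batch_size,
    max_seq_len=max_seq_len,
    num_kv_heads=num_kv_heads,
    head_dim=head_dim,
    dtype=x.dtype,
)
```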
34 changes: 13 additions & 21 deletions torchtune/modules/attention.py
@@ -9,7 +9,11 @@

import torch
from torch import nn
from torchtune.modules.attention_utils import _MaskType, _sdpa_or_flex_attention
from torchtune.modules.attention_utils import (
_MaskType,
_sdpa_or_flex_attention,
repeat_interleave,
)
from torchtune.modules.kv_cache import KVCache

logger = logging.getLogger(__name__)
@@ -258,42 +262,30 @@ def forward(
else:
# Update k and v shape, positional embeddings, and normalization

# k has shape [b, s_y, num_kv_heads * head_dim]
# v has shape [b, s_y, num_kv_heads * head_dim]
# k,v shape [b, s_y, num_kv_heads * head_dim]
k = self.k_proj(y)
v = self.v_proj(y)

# Apply positional embeddings
# k: [b, s_y, n_kv, h_d]
# k,v shape: [b, s_y, n_kv, h_d]
k = k.view(b, s_y, -1, self.head_dim)
v = v.view(b, s_y, -1, self.head_dim)
if self.pos_embeddings is not None:
k = self.pos_embeddings(k, input_pos=input_pos)

# [b, n_h, s, h_d]
k = k.transpose(1, 2)
v = v.transpose(1, 2)
# k,v shape: [b, n_kv, s_y, h_d]
k, v = k.transpose(1, 2), v.transpose(1, 2)

# Update key-value cache
if self.kv_cache is not None and self.cache_enabled:
k, v = self.kv_cache.update(k, v)

# View + expand + reshape bring num_kv_heads to num_heads for k and v
# to match q.

# k: [b, n_kv, 1, s_y, h_d]
# v: [b, n_kv, 1, s_y, h_d]
k = k.view(b, self.num_kv_heads, 1, s_y, self.head_dim)
v = v.view(b, self.num_kv_heads, 1, s_y, self.head_dim)

# If needed, expand the key and value tensors to have the same shape
# as the query tensor by copying values across the relevant dim
# k,v shape: [b, n_h, s, h_d]
if self.num_heads != self.num_kv_heads:
k = k.expand(b, self.num_kv_heads, q_per_kv, s_y, self.head_dim)
v = v.expand(b, self.num_kv_heads, q_per_kv, s_y, self.head_dim)

# [b, s, n_h, h_d]
k = k.reshape(b, -1, s_y, self.head_dim)
v = v.reshape(b, -1, s_y, self.head_dim)
k = repeat_interleave(k, dim=1, repeat=q_per_kv)
v = repeat_interleave(v, dim=1, repeat=q_per_kv)

# Normalize k
if self.k_norm is not None:
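For reference, a self-contained sketch of the head-expansion step this hunk replaces, with the old view/expand/reshape path written next to the new helper call. Sizes and the local copy of `repeat_interleave` are illustrative; the actual helper is added to `torchtune/modules/attention_utils.py` in the diff below.

```python
import torch


def repeat_interleave(x: torch.Tensor, *, dim: int, repeat: int) -> torch.Tensor:
    # Local copy of the helper from attention_utils.py, for illustration only.
    if repeat == 1:
        return x
    dim = dim + x.ndim if dim < 0 else dim
    shape = [-1] * (x.ndim + 1)
    shape[dim + 1] = repeat
    return x.unsqueeze(dim + 1).expand(shape).flatten(dim, dim + 1)


# Illustrative GQA sizes: 32 query heads sharing 8 kv heads.
b, s, num_heads, num_kv_heads, head_dim = 2, 16, 32, 8, 64
q_per_kv = num_heads // num_kv_heads

k = torch.randn(b, num_kv_heads, s, head_dim)  # [b, n_kv, s, h_d]

# Old path: hard-codes the sequence length in every view/expand/reshape, which
# can stop matching once the kv-cache update returns tensors whose sequence
# length is the cache length rather than s_y.
k_old = (
    k.view(b, num_kv_heads, 1, s, head_dim)
    .expand(b, num_kv_heads, q_per_kv, s, head_dim)
    .reshape(b, -1, s, head_dim)
)

# New path: derives every size from the tensor itself.
k_new = repeat_interleave(k, dim=1, repeat=q_per_kv)

assert k_new.shape == (b, num_heads, s, head_dim)
assert torch.equal(k_old, k_new)
```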
13 changes: 12 additions & 1 deletion torchtune/modules/attention_utils.py
@@ -183,7 +183,6 @@ def _attention_call(
dropout_p: float,
is_causal: bool,
) -> torch.Tensor:

# Flex attention uses the BlockMask
# (https://github.com/pytorch/pytorch/blob/main/torch/nn/attention/flex_attention.py#L168)
# instead of a traditional boolean tensor mask. If this is passed in,
@@ -247,3 +246,15 @@ def _attention_call(
)

return _attention_call


def repeat_interleave(x: torch.Tensor, *, dim: int, repeat: int) -> torch.Tensor:
if repeat == 1:
return x

dim = dim + x.ndim if dim < 0 else dim

shape = [-1] * (x.ndim + 1)
shape[dim + 1] = repeat

return x.unsqueeze(dim + 1).expand(shape).flatten(dim, dim + 1)
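
A quick sanity check of the new helper (illustrative, not part of the commit): for an integer repeat count it matches `torch.repeat_interleave` along the chosen dimension, built here from an `unsqueeze`/`expand` view followed by a single `flatten`.

```python
import torch

from torchtune.modules.attention_utils import repeat_interleave

x = torch.randn(2, 4, 3, 8)  # e.g. [b, n_kv, s, h_d]

out = repeat_interleave(x, dim=1, repeat=4)
ref = torch.repeat_interleave(x, repeats=4, dim=1)

assert out.shape == (2, 16, 3, 8)
assert torch.equal(out, ref)

# repeat == 1 is a no-op and returns the input tensor unchanged.
assert repeat_interleave(x, dim=1, repeat=1) is x
```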
