simplify expand for num_kv_heads
mirceamironenco committed Nov 8, 2024
1 parent 8654736 commit c06bd7f
Showing 2 changed files with 4 additions and 19 deletions.
11 changes: 4 additions & 7 deletions torchtune/modules/attention.py
@@ -9,11 +9,7 @@
 
 import torch
 from torch import nn
-from torchtune.modules.attention_utils import (
-    _MaskType,
-    _sdpa_or_flex_attention,
-    repeat_interleave,
-)
+from torchtune.modules.attention_utils import _MaskType, _sdpa_or_flex_attention
 from torchtune.modules.kv_cache import KVCache
 
 logger = logging.getLogger(__name__)
@@ -284,8 +280,9 @@ def forward(
         # as the query tensor by copying values across the relevant dim
         # k,v shape: [b, n_h, s, h_d]
         if self.num_heads != self.num_kv_heads:
-            k = repeat_interleave(k, dim=1, repeat=q_per_kv)
-            v = repeat_interleave(v, dim=1, repeat=q_per_kv)
+            expand_shape = (-1, -1, q_per_kv, -1, -1)
+            k = k.unsqueeze(2).expand(expand_shape).flatten(1, 2)
+            v = v.unsqueeze(2).expand(expand_shape).flatten(1, 2)
 
         # Normalize k
         if self.k_norm is not None:
12 changes: 0 additions & 12 deletions torchtune/modules/attention_utils.py
@@ -246,15 +246,3 @@ def _attention_call(
             )
 
     return _attention_call
-
-
-def repeat_interleave(x: torch.Tensor, *, dim: int, repeat: int) -> torch.Tensor:
-    if repeat == 1:
-        return x
-
-    dim = dim + x.ndim if dim < 0 else dim
-
-    shape = [-1] * (x.ndim + 1)
-    shape[dim + 1] = repeat
-
-    return x.unsqueeze(dim + 1).expand(shape).flatten(dim, dim + 1)
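
For context, the removed helper and the new inline expansion produce the same head-repeated layout. Below is a minimal standalone sketch (not part of the commit) that checks the commit's unsqueeze/expand/flatten sequence against PyTorch's built-in torch.repeat_interleave; the shape values and variable names are illustrative assumptions, not values taken from torchtune.

import torch

# Illustrative shapes (assumed values): batch, kv heads, sequence length, head dim.
b, n_kv, s, h_d = 2, 2, 16, 64
num_heads = 8
q_per_kv = num_heads // n_kv  # query heads served by each kv head

k = torch.randn(b, n_kv, s, h_d)

# Approach from this commit: insert a size-1 dim after the kv-head dim,
# broadcast it to q_per_kv copies, then flatten back to [b, n_h, s, h_d].
expand_shape = (-1, -1, q_per_kv, -1, -1)
k_expanded = k.unsqueeze(2).expand(expand_shape).flatten(1, 2)

# Reference result using PyTorch's built-in repeat_interleave.
k_reference = torch.repeat_interleave(k, repeats=q_per_kv, dim=1)

torch.testing.assert_close(k_expanded, k_reference)
print(k_expanded.shape)  # torch.Size([2, 8, 16, 64])

The expand step creates a broadcast view rather than allocating the repeated heads up front; the subsequent flatten then yields the [b, n_h, s, h_d] tensor shape the attention call expects.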
