
Commit

Update gemma-2 kvcache constructor and fix mask type check.
mirceamironenco committed Nov 9, 2024
1 parent 39b9801 commit 5f0d2b7
Showing 1 changed file with 3 additions and 3 deletions.
torchtune/models/gemma2/_attention.py (3 additions, 3 deletions)
@@ -149,7 +149,7 @@ def setup_cache(
             self.kv_cache = KVCache(
                 batch_size=batch_size,
                 max_seq_len=max_seq_len,
-                num_heads=self.num_heads,
+                num_kv_heads=self.num_heads,
                 head_dim=self.head_dim,
                 dtype=dtype,
             )
@@ -211,9 +211,9 @@ def forward(
             - h_d: head dim
         """
         # until flex attention implementation exists, we do not accept block masks
-        if (mask is not None) and (type(mask) != torch.Tensor()):
+        if mask is not None and (not isinstance(mask, torch.Tensor)):
             raise NotImplementedError(
-                "Block masks are not implemeted yet, use packed=False"
+                "Block masks are not implemeted yet, use packed=False."
             )

         # x has shape [b, s_x, d]
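
Note on the mask check: the old guard compared type(mask), which is a class, against torch.Tensor(), which is a freshly constructed empty tensor, so it could never express the intended "is this a plain tensor?" test. isinstance is the idiomatic check. Below is a minimal sketch of the corrected guard, written as a hypothetical standalone helper rather than the module's forward method:

    import torch

    def reject_block_masks(mask) -> None:
        # Mirrors the corrected guard: only plain tensor masks (or no mask at all)
        # are accepted until a flex-attention block-mask path exists.
        if mask is not None and not isinstance(mask, torch.Tensor):
            raise NotImplementedError(
                "Block masks are not implemented yet, use packed=False."
            )

    reject_block_masks(torch.ones(4, 4, dtype=torch.bool))  # plain tensor mask: accepted
    reject_block_masks(None)                                 # no mask: accepted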

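Note on the constructor change: the commit renames the keyword from num_heads to num_kv_heads so that the call matches the KVCache constructor. A small sketch of the corrected construction follows; the import path and the sizes are assumptions for illustration, not part of this diff:

    import torch
    from torchtune.modules import KVCache

    # Illustrative sizes only; Gemma2Attention passes its own configured values.
    kv_cache = KVCache(
        batch_size=2,
        max_seq_len=128,
        num_kv_heads=8,  # keyword renamed from num_heads in this commit
        head_dim=64,
        dtype=torch.bfloat16,
    )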