diff --git a/torchtune/models/gemma2/_attention.py b/torchtune/models/gemma2/_attention.py
index b00612d032..1b7bf38447 100644
--- a/torchtune/models/gemma2/_attention.py
+++ b/torchtune/models/gemma2/_attention.py
@@ -149,7 +149,7 @@ def setup_cache(
         self.kv_cache = KVCache(
             batch_size=batch_size,
             max_seq_len=max_seq_len,
-            num_heads=self.num_heads,
+            num_kv_heads=self.num_heads,
             head_dim=self.head_dim,
             dtype=dtype,
         )
@@ -211,9 +211,9 @@ def forward(
             - h_d: head dim
         """
         # until flex attention implementation exists, we do not accept block masks
-        if (mask is not None) and (type(mask) != torch.Tensor()):
+        if mask is not None and (not isinstance(mask, torch.Tensor)):
             raise NotImplementedError(
-                "Block masks are not implemeted yet, use packed=False"
+                "Block masks are not implemented yet, use packed=False."
             )
 
         # x has shape [b, s_x, d]
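
Note (not part of the diff): a minimal sketch of the mask check the second hunk
introduces, using only standard PyTorch; dense_mask is a hypothetical example
tensor. The old expression compared type(mask) against a freshly constructed
tensor instance, torch.Tensor(), rather than against the torch.Tensor class, so
it never expressed the intended "is this a plain dense mask?" test; isinstance
is the idiomatic way to write it:

    import torch

    # A dense boolean attention mask of shape [b, s, s], as used when packed=False.
    dense_mask = torch.ones(2, 8, 8, dtype=torch.bool)

    print(isinstance(dense_mask, torch.Tensor))  # True  -> the mask is accepted
    print(isinstance(object(), torch.Tensor))    # False -> NotImplementedError is raised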