diff --git a/tests/torchtune/models/llama2/scripts/compare_fused_attention.py b/tests/torchtune/models/llama2/scripts/compare_fused_attention.py
index 328d1c528f..0c6c3e938a 100644
--- a/tests/torchtune/models/llama2/scripts/compare_fused_attention.py
+++ b/tests/torchtune/models/llama2/scripts/compare_fused_attention.py
@@ -256,7 +256,6 @@ def compare_attn(
     max_seq_len: int,
     use_kv_cache: bool,
 ):
-
     torch.manual_seed(16)
 
     inputs = torch.randn(4, 2048, 4096)
@@ -269,8 +268,9 @@ def compare_attn(
         kv_cache = KVCache(
             batch_size=4,
             max_seq_len=max_seq_len,
-            n_kv_heads=num_heads,
+            num_kv_heads=num_kv_heads,
             head_dim=head_dim,
+            dtype=inputs.dtype,
         )
     else:
         kv_cache = None
@@ -330,7 +330,6 @@
 
 
 if __name__ == "__main__":
-
     # compare mha
     mha = {
         "num_heads": 32,
diff --git a/tests/torchtune/models/llama2/scripts/compare_lora_attention.py b/tests/torchtune/models/llama2/scripts/compare_lora_attention.py
index c6073297da..fb70c5b464 100644
--- a/tests/torchtune/models/llama2/scripts/compare_lora_attention.py
+++ b/tests/torchtune/models/llama2/scripts/compare_lora_attention.py
@@ -33,7 +33,6 @@ def compare_lora_attention(
     lora_rank: int,
     lora_alpha: float,
 ) -> None:
-
     # make sure we have the right seed for generating outputs
     # this should match up the seed value set in the corresponding
     # unit test
@@ -68,8 +67,9 @@ def compare_lora_attention(
         KVCache(
             batch_size=batch_size,
             max_seq_len=max_seq_len,
-            n_kv_heads=num_heads,
+            num_kv_heads=num_kv_heads,
             head_dim=head_dim,
+            dtype=x.dtype,
         )
         if batch_size is not None
         else None
diff --git a/tests/torchtune/modules/test_attention.py b/tests/torchtune/modules/test_attention.py
index 872f6684de..0d9dcb5434 100644
--- a/tests/torchtune/modules/test_attention.py
+++ b/tests/torchtune/modules/test_attention.py
@@ -123,7 +123,7 @@ def gqa_kv_cache(
         kv_cache = KVCache(
             batch_size=4,
             max_seq_len=max_seq_len,
-            num_heads=num_heads,
+            num_kv_heads=num_kv_heads,
             head_dim=head_dim,
             dtype=torch.float32,
         )
@@ -178,7 +178,7 @@ def mha_kv_cache(
         kv_cache = KVCache(
             batch_size=4,
             max_seq_len=max_seq_len,
-            num_heads=num_heads,
+            num_kv_heads=num_kv_heads,
             head_dim=head_dim,
             dtype=torch.float32,
         )
@@ -233,7 +233,7 @@ def mqa_kv_cache(
         kv_cache = KVCache(
             batch_size=4,
             max_seq_len=max_seq_len,
-            num_heads=num_heads,
+            num_kv_heads=num_kv_heads,
             head_dim=head_dim,
             dtype=torch.float32,
         )
@@ -267,7 +267,6 @@ def test_forward_gqa(self, input: torch.Tensor, gqa: MultiHeadAttention) -> None
     def test_forward_gqa_kv_cache(
         self, input: torch.Tensor, gqa_kv_cache: MultiHeadAttention, attn_params_gqa
     ) -> None:
-
         _, _, _, max_seq_len = attn_params_gqa
         _, seq_len, _ = input.shape
 
@@ -293,7 +292,6 @@ def test_forward_mha(self, input: torch.Tensor, mha: MultiHeadAttention) -> None
     def test_forward_mha_kv_cache(
         self, input: torch.Tensor, mha_kv_cache: MultiHeadAttention, attn_params_mha
     ) -> None:
-
         _, _, _, max_seq_len = attn_params_mha
         _, seq_len, _ = input.shape
 
diff --git a/torchtune/models/gemma2/_attention.py b/torchtune/models/gemma2/_attention.py
index b00612d032..1b7bf38447 100644
--- a/torchtune/models/gemma2/_attention.py
+++ b/torchtune/models/gemma2/_attention.py
@@ -149,7 +149,7 @@ def setup_cache(
             self.kv_cache = KVCache(
                 batch_size=batch_size,
                 max_seq_len=max_seq_len,
-                num_heads=self.num_heads,
+                num_kv_heads=self.num_heads,
                 head_dim=self.head_dim,
                 dtype=dtype,
             )
@@ -211,9 +211,9 @@ def forward(
             - h_d: head dim
         """
         # until flex attention implementation exists, we do not accept block masks
-        if (mask is not None) and (type(mask) != torch.Tensor()):
+        if mask is not None and (not isinstance(mask, torch.Tensor)):
             raise NotImplementedError(
-                "Block masks are not implemeted yet, use packed=False"
+                "Block masks are not implemented yet, use packed=False."
             )
 
         # x has shape [b, s_x, d]
diff --git a/torchtune/modules/attention.py b/torchtune/modules/attention.py
index 879f0679cf..b74c70113e 100644
--- a/torchtune/modules/attention.py
+++ b/torchtune/modules/attention.py
@@ -164,7 +164,7 @@ def setup_cache(
             self.kv_cache = KVCache(
                 batch_size=batch_size,
                 max_seq_len=max_seq_len,
-                num_heads=self.num_heads,
+                num_kv_heads=self.num_kv_heads,
                 head_dim=self.head_dim,
                 dtype=dtype,
             )
@@ -258,47 +258,37 @@ def forward(
         else:
             # Update k and v shape, positional embeddings, and normalization
 
-            # k has shape [b, s_y, num_kv_heads * head_dim]
-            # v has shape [b, s_y, num_kv_heads * head_dim]
+            # k,v shape [b, s_y, num_kv_heads * head_dim]
             k = self.k_proj(y)
             v = self.v_proj(y)
 
             # Apply positional embeddings
-            # k: [b, s_y, n_kv, h_d]
+            # k,v shape: [b, s_y, n_kv, h_d]
             k = k.view(b, s_y, -1, self.head_dim)
+            v = v.view(b, s_y, -1, self.head_dim)
             if self.pos_embeddings is not None:
                 k = self.pos_embeddings(k, input_pos=input_pos)
 
-            # View + expand + reshape bring num_kv_heads to num_heads for k and v
-            # to match q.
+            # k,v shape: [b, n_kv, s_y, h_d]
+            k = k.transpose(1, 2)
+            v = v.transpose(1, 2)
 
-            # k: [b, s_y, n_kv, 1, h_d]
-            # v: [b, s_y, n_kv, 1, h_d]
-            k = k.view(b, s_y, self.num_kv_heads, 1, self.head_dim)
-            v = v.view(b, s_y, self.num_kv_heads, 1, self.head_dim)
+            # Update key-value cache
+            if self.kv_cache is not None and self.cache_enabled:
+                k, v = self.kv_cache.update(k, v)
 
             # If needed, expand the key and value tensors to have the same shape
             # as the query tensor by copying values across the relevant dim
+            # k,v shape: [b, n_h, s, h_d]
             if self.num_heads != self.num_kv_heads:
-                k = k.expand(b, s_y, self.num_kv_heads, q_per_kv, self.head_dim)
-                v = v.expand(b, s_y, self.num_kv_heads, q_per_kv, self.head_dim)
-
-            # [b, s, n_h, h_d]
-            k = k.reshape(b, s_y, -1, self.head_dim)
-            v = v.reshape(b, s_y, -1, self.head_dim)
-
-            # [b, n_h, s, h_d]
-            k = k.transpose(1, 2)
-            v = v.transpose(1, 2)
+                expand_shape = (-1, -1, q_per_kv, -1, -1)
+                k = k.unsqueeze(2).expand(expand_shape).flatten(1, 2)
+                v = v.unsqueeze(2).expand(expand_shape).flatten(1, 2)
 
             # Normalize k
             if self.k_norm is not None:
                 k = self.k_norm(k)
 
-            # Update key-value cache
-            if self.kv_cache is not None and self.cache_enabled:
-                k, v = self.kv_cache.update(k, v)
-
         output = self._attention_call(
             q,
             k,
diff --git a/torchtune/modules/kv_cache.py b/torchtune/modules/kv_cache.py
index facd9703ca..e96491c22a 100644
--- a/torchtune/modules/kv_cache.py
+++ b/torchtune/modules/kv_cache.py
@@ -17,9 +17,7 @@ class KVCache(nn.Module):
     Args:
         batch_size (int): batch size model will be run with
        max_seq_len (int): maximum sequence length model will be run with
-        num_heads (int): number of heads. We take num_heads instead of num_kv_heads because
-            the cache is created after we've expanded the key and value tensors to have the
-            same shape as the query tensor. See attention.py for more details
+        num_kv_heads (int): number of key/value heads.
         head_dim (int): per-attention head embedding dimension
         dtype (torch.dtype): dtype for the caches
     """
@@ -28,12 +26,12 @@ def __init__(
         self,
         batch_size: int,
         max_seq_len: int,
-        num_heads: int,
+        num_kv_heads: int,
         head_dim: int,
         dtype: torch.dtype,
     ) -> None:
         super().__init__()
-        cache_shape = (batch_size, num_heads, max_seq_len, head_dim)
+        cache_shape = (batch_size, num_kv_heads, max_seq_len, head_dim)
         self.register_buffer(
             "k_cache", torch.zeros(cache_shape, dtype=dtype), persistent=False
         )
@@ -66,7 +64,7 @@ def update(
             already been filled, use ``.reset()``, which will reset the cache
             to the zero-th position.
         Example:
-            >>> cache = KVCache(batch_size=2, max_seq_len=16, num_heads=4, head_dim=32, dtype=torch.bfloat16)
+            >>> cache = KVCache(batch_size=2, max_seq_len=16, num_kv_heads=4, head_dim=32, dtype=torch.bfloat16)
             >>> keys, values = torch.ones((2, 4, 8, 32)), torch.ones((2, 4, 8, 32))
            >>> cache.update(keys, values)
             >>> # now positions 0 through 7 are filled
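
For reference, a minimal standalone sketch (not code from this patch) of the call pattern the hunks above produce: the cache is allocated with ``num_kv_heads`` heads, ``update()`` runs on the un-expanded key/value tensors, and only afterwards are k and v expanded up to ``num_heads``. It assumes ``KVCache`` is exposed via ``torchtune.modules``; the sizes are illustrative.

import torch

from torchtune.modules import KVCache

batch_size, max_seq_len, num_heads, num_kv_heads, head_dim = 2, 16, 8, 2, 32
q_per_kv = num_heads // num_kv_heads

# The cache now holds num_kv_heads heads rather than num_heads.
cache = KVCache(
    batch_size=batch_size,
    max_seq_len=max_seq_len,
    num_kv_heads=num_kv_heads,
    head_dim=head_dim,
    dtype=torch.float32,
)

# k, v reach the cache before any GQA expansion: [b, n_kv, s, h_d]
seq_len = 8
k = torch.randn(batch_size, num_kv_heads, seq_len, head_dim)
v = torch.randn(batch_size, num_kv_heads, seq_len, head_dim)
k, v = cache.update(k, v)  # cached k, v still have num_kv_heads heads

# Expansion to num_heads happens only after the cache update, mirroring
# the reordered logic in MultiHeadAttention.forward above.
expand_shape = (-1, -1, q_per_kv, -1, -1)
k = k.unsqueeze(2).expand(expand_shape).flatten(1, 2)  # heads dim: n_kv -> n_h
v = v.unsqueeze(2).expand(expand_shape).flatten(1, 2)

assert k.shape[:2] == (batch_size, num_heads)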