diff --git a/torchtune/modules/kv_cache.py b/torchtune/modules/kv_cache.py index 6d6eac3266..be05a88433 100644 --- a/torchtune/modules/kv_cache.py +++ b/torchtune/modules/kv_cache.py @@ -17,7 +17,7 @@ class KVCache(nn.Module): Args: batch_size (int): batch size model will be run with max_seq_len (int): maximum sequence length model will be run with - num_kv_heads (int): number key/value heads. + num_kv_heads (int): number of key/value heads. head_dim (int): per-attention head embedding dimension dtype (torch.dtype): dtype for the caches """