
Commit 65f7498

cleanup nits
mirceamironenco committed Nov 8, 2024
1 parent c06bd7f commit 65f7498
Showing 3 changed files with 4 additions and 2 deletions.
torchtune/modules/attention.py (2 additions, 1 deletion)
@@ -270,7 +270,8 @@ def forward(
         k = self.pos_embeddings(k, input_pos=input_pos)
 
         # k,v shape: [b, n_kv, s_y, h_d]
-        k, v = k.transpose(1, 2), v.transpose(1, 2)
+        k = k.transpose(1, 2)
+        v = v.transpose(1, 2)
 
         # Update key-value cache
         if self.kv_cache is not None and self.cache_enabled:
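The one-line tuple assignment and the two-statement form are equivalent; the split is purely stylistic. A minimal standalone sketch of the shape change the diff comment describes (dimension sizes are made up for illustration, and the pre-transpose [b, s_y, n_kv, h_d] layout is assumed from that comment):

import torch

b, s_y, n_kv, h_d = 2, 16, 4, 64

# Assumed pre-transpose layout: [b, s_y, n_kv, h_d].
k = torch.randn(b, s_y, n_kv, h_d)
v = torch.randn(b, s_y, n_kv, h_d)

# The old one-liner and the new two-statement form produce the same
# [b, n_kv, s_y, h_d] layout the diff comment refers to.
k1, v1 = k.transpose(1, 2), v.transpose(1, 2)
k2 = k.transpose(1, 2)
v2 = v.transpose(1, 2)

assert k1.shape == k2.shape == (b, n_kv, s_y, h_d)
assert torch.equal(k1, k2) and torch.equal(v1, v2)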
torchtune/modules/attention_utils.py (1 addition, 0 deletions)
@@ -183,6 +183,7 @@ def _attention_call(
         dropout_p: float,
         is_causal: bool,
     ) -> torch.Tensor:
+
         # Flex attention uses the BlockMask
         # (https://github.com/pytorch/pytorch/blob/main/torch/nn/attention/flex_attention.py#L168)
         # instead of a traditional boolean tensor mask. If this is passed in,
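The comment touched here refers to flex attention's BlockMask. A rough sketch of the dispatch it describes, not torchtune's exact implementation (requires PyTorch 2.5+ for torch.nn.attention.flex_attention; the function name and simplified mask handling are illustrative):

import torch.nn.functional as F
from torch.nn.attention.flex_attention import BlockMask, flex_attention

def attention_call_sketch(q, k, v, mask=None, dropout_p=0.0, is_causal=False):
    # A BlockMask routes to flex_attention, which consumes it natively...
    if isinstance(mask, BlockMask):
        return flex_attention(q, k, v, block_mask=mask)
    # ...while a traditional boolean tensor mask (or None) goes to SDPA.
    return F.scaled_dot_product_attention(
        q, k, v, attn_mask=mask, dropout_p=dropout_p, is_causal=is_causal
    )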
torchtune/modules/kv_cache.py (1 addition, 1 deletion)
@@ -51,7 +51,7 @@ def reset(self) -> None:
 
     @property
     def size(self) -> int:
-        return int(self.cache_pos[0].item())
+        return self.cache_pos[0].item()
 
     def update(
         self, k_val: torch.Tensor, v_val: torch.Tensor
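The dropped int(...) wrapper is redundant at runtime: Tensor.item() already returns a Python int when the tensor has an integer dtype. A quick standalone check (cache_pos here is just an illustrative stand-in for the cache's position counter):

import torch

cache_pos = torch.arange(8)   # integer dtype, as a position counter would be
size = cache_pos[0].item()
assert isinstance(size, int)  # .item() already yields a Python int here

One caveat: static type checkers see item() as returning a generic Number, which is presumably why the explicit cast was there; dropping it is a runtime no-op either way.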
