Commit

Standardize torch.Tensor typing (#1471)
SalmanMohammadi authored Sep 2, 2024
1 parent 0b9f830 commit 7044044
Showing 27 changed files with 182 additions and 170 deletions.
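
The change applied across the 27 files follows one pattern: replace bare `Tensor` annotations with the fully qualified `torch.Tensor`, importing `torch` where needed. A minimal before/after sketch of that pattern (the `scale_rows` function is a hypothetical illustration, not code from this commit):

    # Before: the annotation type was imported directly
    #   from torch import nn, Tensor
    #   def scale_rows(x: Tensor, scale: Tensor) -> Tensor: ...

    # After: import the torch namespace and annotate with the fully qualified type
    import torch


    def scale_rows(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
        # Multiply each row of x by the matching entry of scale.
        return x * scale.unsqueeze(-1)


    out = scale_rows(torch.randn(2, 3), torch.tensor([1.0, 2.0]))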
5 changes: 3 additions & 2 deletions docs/source/tutorials/lora_finetune.rst
@@ -84,7 +84,8 @@ Let's take a look at a minimal implementation of LoRA in native PyTorch.

 .. code-block:: python

-  from torch import nn, Tensor
+  import torch
+  from torch import nn

  class LoRALinear(nn.Module):
      def __init__(
@@ -114,7 +115,7 @@ Let's take a look at a minimal implementation of LoRA in native PyTorch.
          self.lora_a.weight.requires_grad = True
          self.lora_b.weight.requires_grad = True

-      def forward(self, x: Tensor) -> Tensor:
+      def forward(self, x: torch.Tensor) -> torch.Tensor:
          # This would be the output of the original model
          frozen_out = self.linear(x)
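
For context, the hunks above come from the tutorial's minimal LoRA layer. A self-contained sketch of that layer with the updated annotations is below; the rank/alpha scaling and the absence of dropout are simplifying assumptions, not a verbatim copy of the tutorial.

    import torch
    from torch import nn


    class LoRALinear(nn.Module):
        # Minimal sketch: a frozen linear layer plus a trainable low-rank update.
        def __init__(self, in_dim: int, out_dim: int, rank: int, alpha: float) -> None:
            super().__init__()
            self.scale = alpha / rank
            self.linear = nn.Linear(in_dim, out_dim, bias=False)
            self.lora_a = nn.Linear(in_dim, rank, bias=False)
            self.lora_b = nn.Linear(rank, out_dim, bias=False)
            # Freeze the base weight; only the LoRA matrices are trained.
            self.linear.weight.requires_grad = False
            self.lora_a.weight.requires_grad = True
            self.lora_b.weight.requires_grad = True

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # This would be the output of the original model
            frozen_out = self.linear(x)
            # Low-rank update, scaled by alpha / rank
            lora_out = self.lora_b(self.lora_a(x)) * self.scale
            return frozen_out + lora_out


    layer = LoRALinear(in_dim=512, out_dim=512, rank=8, alpha=16.0)
    y = layer(torch.randn(2, 512))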
5 changes: 3 additions & 2 deletions docs/source/tutorials/qlora_finetune.rst
@@ -217,7 +217,8 @@ a vanilla minimal LoRA layer, taken from :ref:`the LoRA tutorial <lora_finetune_

 .. code-block:: python
   :emphasize-lines: 3, 13, 19, 20, 39, 40, 41

-  from torch import nn, Tensor
+  import torch
+  from torch import nn
   import torch.nn.functional as F
   from torchao.dtypes.nf4tensor import linear_nf4, to_nf4
@@ -253,7 +254,7 @@ a vanilla minimal LoRA layer, taken from :ref:`the LoRA tutorial <lora_finetune_
          self.lora_a.weight.requires_grad = True
          self.lora_b.weight.requires_grad = True

-      def forward(self, x: Tensor) -> Tensor:
+      def forward(self, x: torch.Tensor) -> torch.Tensor:
          # frozen_out would be the output of the original model
          if quantize_base:
              # Call into torchao's linear_nf4 to run linear forward pass w/quantized weight.
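
The QLoRA variant in this tutorial only changes how the frozen base weight is stored and applied. A rough sketch of that idea with the new annotations, assuming the `linear_nf4(input, weight)` calling convention from the torchao import above; the randomly initialized placeholder weight and the exact parameter registration are illustrative assumptions.

    import torch
    import torch.nn.functional as F
    from torch import nn
    from torchao.dtypes.nf4tensor import linear_nf4, to_nf4


    class QLoRALinear(nn.Module):
        # Minimal sketch: the frozen base weight is stored in NF4; LoRA matrices stay in full precision.
        def __init__(
            self, in_dim: int, out_dim: int, rank: int, alpha: float, quantize_base: bool = True
        ) -> None:
            super().__init__()
            self.scale = alpha / rank
            self.quantize_base = quantize_base
            # Placeholder randomly initialized weight; a real layer would load pretrained weights.
            weight = torch.randn(out_dim, in_dim)
            self.weight = nn.Parameter(
                to_nf4(weight) if quantize_base else weight, requires_grad=False
            )
            self.lora_a = nn.Linear(in_dim, rank, bias=False)
            self.lora_b = nn.Linear(rank, out_dim, bias=False)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            if self.quantize_base:
                # Run the frozen linear pass against the NF4-quantized weight.
                frozen_out = linear_nf4(input=x, weight=self.weight)
            else:
                frozen_out = F.linear(x, self.weight)
            lora_out = self.lora_b(self.lora_a(x)) * self.scale
            return frozen_out + lora_out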
11 changes: 6 additions & 5 deletions tests/torchtune/models/llama2/scripts/compare_fused_attention.py
@@ -11,6 +11,7 @@
 from torch import nn, Tensor
 from torchtune.modules import KVCache, MultiHeadAttention, RotaryPositionalEmbeddings

+
 # Copy-paste of fused attention for comparison
 class FusedMultiHeadAttention(nn.Module):
     """Multi-headed grouped query self-attention (GQA) layer introduced
@@ -115,15 +116,15 @@ def __init__(

     def forward(
         self,
-        x: Tensor,
-        mask: Optional[Tensor] = None,
+        x: torch.Tensor,
+        mask: Optional[torch.Tensor] = None,
         curr_pos: int = 0,
-    ) -> Tensor:
+    ) -> torch.Tensor:
        """
        Args:
            x (Tensor): input tensor with shape
                [batch_size x seq_length x embed_dim]
-           mask (Optional[Tensor]): boolean mask, defaults to None.
+           mask (Optional[torch.Tensor]): boolean mask, defaults to None.
            curr_pos (int): current position in the sequence, defaults to 0.

        Returns:
@@ -241,7 +242,7 @@ def map_state_dict(
     return mapped_sd


-def _get_mask(inpt: Tensor) -> Tensor:
+def _get_mask(inpt: torch.Tensor) -> torch.Tensor:
     seq_len = inpt.shape[1]
     mask = torch.full((1, 1, seq_len, seq_len), float("-inf"), device=inpt.device)
     mask = torch.triu(mask, diagonal=1).type_as(inpt)
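
The `_get_mask` helper shown above builds an additive causal mask (`-inf` above the diagonal, `0` elsewhere), in contrast to the boolean masks described elsewhere in this commit. A short usage sketch with made-up shapes:

    import torch

    inpt = torch.randn(2, 4, 8)  # hypothetical [batch_size, seq_len, embed_dim] input
    seq_len = inpt.shape[1]
    mask = torch.full((1, 1, seq_len, seq_len), float("-inf"), device=inpt.device)
    mask = torch.triu(mask, diagonal=1).type_as(inpt)
    # Row i is 0 for columns j <= i and -inf for j > i, so adding the mask to the
    # attention scores blocks attention to future positions.
    print(mask[0, 0])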
24 changes: 12 additions & 12 deletions tests/torchtune/modules/test_attention.py
@@ -11,7 +11,7 @@
 import torch

 from tests.test_utils import assert_expected, fixed_init_model
-from torch import nn, Tensor
+from torch import nn

 from torchtune.modules import KVCache, MultiHeadAttention, RotaryPositionalEmbeddings
 from torchtune.utils.seed import set_seed
@@ -40,7 +40,7 @@ def input_params(self) -> Tuple[int, int, int]:
         return batch_size, seq_len, embed_dim

     @pytest.fixture
-    def input(self, input_params: Tuple[int, int, int]) -> Tensor:
+    def input(self, input_params: Tuple[int, int, int]) -> torch.Tensor:
         batch_size, seq_len, embed_dim = input_params
         x = torch.randn(batch_size, seq_len, embed_dim)
         return x
@@ -58,7 +58,7 @@ def input_max_len_exceeded(
         self,
         input_params: Tuple[int, int, int],
         attn_params_gqa: Tuple[int, int, int, int],
-    ) -> Tensor:
+    ) -> torch.Tensor:
         batch_size, seq_len, embed_dim = input_params
         _, _, _, max_seq_len = attn_params_gqa
         seq_len = max_seq_len + 1
@@ -69,7 +69,7 @@ def input_max_bs_exceeded(
         self,
         input_params: Tuple[int, int, int],
         attn_params_gqa: Tuple[int, int, int, int],
-    ) -> Tensor:
+    ) -> torch.Tensor:
         batch_size, seq_len, embed_dim = input_params
         _, _, _, max_seq_len = attn_params_gqa
         batch_size += 1
@@ -253,7 +253,7 @@ def mqa_kv_cache(
         attn.eval()
         return attn

-    def test_forward_gqa(self, input: Tensor, gqa: MultiHeadAttention) -> None:
+    def test_forward_gqa(self, input: torch.Tensor, gqa: MultiHeadAttention) -> None:
         with torch.no_grad():
             output = gqa(input)
         assert_expected(
@@ -262,7 +262,7 @@ def test_forward_gqa(self, input: Tensor, gqa: MultiHeadAttention) -> None:
         assert_expected(output.shape, input.shape)

     def test_forward_gqa_kv_cache(
-        self, input: Tensor, gqa_kv_cache: MultiHeadAttention, attn_params_gqa
+        self, input: torch.Tensor, gqa_kv_cache: MultiHeadAttention, attn_params_gqa
     ) -> None:

         _, _, _, max_seq_len = attn_params_gqa
@@ -279,7 +279,7 @@ def test_forward_gqa_kv_cache(
         )
         assert_expected(output.shape, input.shape)

-    def test_forward_mha(self, input: Tensor, mha: MultiHeadAttention) -> None:
+    def test_forward_mha(self, input: torch.Tensor, mha: MultiHeadAttention) -> None:
         with torch.no_grad():
             output = mha(input)
         assert_expected(
@@ -288,7 +288,7 @@ def test_forward_mha(self, input: Tensor, mha: MultiHeadAttention) -> None:
         assert_expected(output.shape, input.shape)

     def test_forward_mha_kv_cache(
-        self, input: Tensor, mha_kv_cache: MultiHeadAttention, attn_params_mha
+        self, input: torch.Tensor, mha_kv_cache: MultiHeadAttention, attn_params_mha
     ) -> None:

         _, _, _, max_seq_len = attn_params_mha
@@ -305,7 +305,7 @@ def test_forward_mha_kv_cache(
         )
         assert_expected(output.shape, input.shape)

-    def test_forward_mqa(self, input: Tensor, mqa: MultiHeadAttention) -> None:
+    def test_forward_mqa(self, input: torch.Tensor, mqa: MultiHeadAttention) -> None:
         with torch.no_grad():
             output = mqa(input)
         assert_expected(
@@ -314,7 +314,7 @@ def test_forward_mqa(self, input: Tensor, mqa: MultiHeadAttention) -> None:
         assert_expected(output.shape, input.shape)

     def test_forward_mqa_kv_cache(
-        self, input: Tensor, mqa_kv_cache: MultiHeadAttention, attn_params_mqa
+        self, input: torch.Tensor, mqa_kv_cache: MultiHeadAttention, attn_params_mqa
     ) -> None:
         _, _, _, max_seq_len = attn_params_mqa
         _, seq_len, _ = input.shape
@@ -332,15 +332,15 @@ def test_forward_mqa_kv_cache(

     def test_max_seq_len_exceeded(
         self,
-        input_max_len_exceeded: Tensor,
+        input_max_len_exceeded: torch.Tensor,
         gqa: MultiHeadAttention,
     ) -> None:
         with pytest.raises(Exception):
             _ = gqa(input_max_len_exceeded)

     def test_batch_size_exceeded(
         self,
-        input_max_bs_exceeded: Tensor,
+        input_max_bs_exceeded: torch.Tensor,
         gqa_kv_cache: MultiHeadAttention,
     ) -> None:
         with pytest.raises(Exception):
6 changes: 3 additions & 3 deletions tests/torchtune/modules/test_feed_forward.py
@@ -11,7 +11,7 @@
 import torch

 from tests.test_utils import assert_expected, fixed_init_model
-from torch import nn, Tensor
+from torch import nn

 from torchtune.modules import FeedForward
 from torchtune.utils.seed import set_seed
@@ -32,7 +32,7 @@ def input_params(self) -> Tuple[int, int]:
         return dim, hidden_dim

     @pytest.fixture
-    def input(self, input_params: Tuple[int, int]) -> Tensor:
+    def input(self, input_params: Tuple[int, int]) -> torch.Tensor:
         dim, _ = input_params
         return torch.randn(1, dim)

@@ -49,7 +49,7 @@ def ffn(self, input_params: Tuple[int, int]) -> FeedForward:
         ff.eval()
         return ff

-    def test_forward(self, input: Tensor, ffn: FeedForward) -> None:
+    def test_forward(self, input: torch.Tensor, ffn: FeedForward) -> None:
         with torch.no_grad():
             x_out = ffn(input)
         assert_expected(x_out.mean(), torch.tensor(251.5356), atol=1e-7, rtol=1e-3)
26 changes: 13 additions & 13 deletions tests/torchtune/modules/test_transformer_decoder.py
@@ -12,7 +12,7 @@

 from tests.test_utils import assert_expected

-from torch import nn, Tensor
+from torch import nn

 from torchtune.models.llama2 import llama2
 from torchtune.models.llama2._component_builders import llama2_mlp
@@ -54,7 +54,7 @@ def input_params(self) -> Tuple[int, int, int]:
         return batch_size, seq_len, embed_dim

     @pytest.fixture
-    def input(self, input_params: Tuple[int, int, int]) -> Tensor:
+    def input(self, input_params: Tuple[int, int, int]) -> torch.Tensor:
         batch_size, seq_len, embed_dim = input_params
         return torch.randn(batch_size, seq_len, embed_dim)

@@ -100,7 +100,7 @@ def transformer_layer(
         return transformer_layer

     def test_forward(
-        self, input: Tensor, transformer_layer: TransformerSelfAttentionLayer
+        self, input: torch.Tensor, transformer_layer: TransformerSelfAttentionLayer
     ) -> None:
         with torch.no_grad():
             output = transformer_layer(input)
@@ -125,7 +125,7 @@ def input_params(self) -> Tuple[int, int, int, int]:
         return batch_size, seq_len, encoder_seq_len, embed_dim

     @pytest.fixture
-    def input(self, input_params: Tuple[int, int, int, int]) -> Tensor:
+    def input(self, input_params: Tuple[int, int, int, int]) -> torch.Tensor:
         batch_size, seq_len, encoder_seq_len, embed_dim = input_params
         rand_x = torch.randn(batch_size, seq_len, embed_dim)
         rand_y = torch.randn(batch_size, 128, embed_dim)
@@ -185,7 +185,7 @@ def transformer_layer(

     def test_forward(
         self,
-        input: [Tensor, Tensor, Tensor],
+        input: [torch.Tensor, torch.Tensor, torch.Tensor],
         transformer_layer: TransformerSelfAttentionLayer,
     ) -> None:
         input_x, input_y, mask = input
@@ -215,7 +215,7 @@ def input_params(self) -> Tuple[int, int, int]:
         return batch_size, seq_len, vocab_size

     @pytest.fixture
-    def input(self, input_params: Tuple[int, int, int]) -> Tensor:
+    def input(self, input_params: Tuple[int, int, int]) -> torch.Tensor:
         batch_size, seq_len, vocab_size = input_params
         return torch.randint(low=0, high=vocab_size, size=(batch_size, seq_len))

@@ -234,7 +234,7 @@ def input_max_len_exceeded(
         self,
         input_params: Tuple[int, int, int],
         decoder_params: Tuple[int, int, int, int, int, int],
-    ) -> Tensor:
+    ) -> torch.Tensor:
         batch_size, seq_len, vocab_size = input_params
         _, _, _, _, max_seq_len, _ = decoder_params
         seq_len = max_seq_len + 1
@@ -245,7 +245,7 @@ def input_max_bs_exceeded(
         self,
         input_params: Tuple[int, int, int],
         decoder_params: Tuple[int, int, int, int, int, int],
-    ) -> Tensor:
+    ) -> torch.Tensor:
         batch_size, seq_len, vocab_size = input_params
         _, _, _, _, max_seq_len, _ = decoder_params
         batch_size = batch_size + 1
@@ -306,7 +306,7 @@ def decoder_with_kv_cache_enabled(

     def test_forward(
         self,
-        input: Tensor,
+        input: torch.Tensor,
         input_params: Tuple[int, int, int],
         decoder: TransformerDecoder,
     ) -> None:
@@ -318,15 +318,15 @@ def test_forward(

     def test_max_seq_len_exceeded(
         self,
-        input_max_len_exceeded: Tensor,
+        input_max_len_exceeded: torch.Tensor,
         decoder: TransformerDecoder,
     ) -> None:
         with pytest.raises(Exception):
             output = decoder(input_max_len_exceeded)

     def test_kv_cache(
         self,
-        input: Tensor,
+        input: torch.Tensor,
         decoder_with_kv_cache_enabled: TransformerDecoder,
         decoder: TransformerDecoder,
     ) -> None:
@@ -340,7 +340,7 @@ def test_kv_cache(

     def test_kv_cache_reset_values(
         self,
-        input: Tensor,
+        input: torch.Tensor,
         decoder_with_kv_cache_enabled: TransformerDecoder,
     ) -> None:
         _, seq_len = input.shape
@@ -375,7 +375,7 @@ def test_kv_cache_reset_values_fails_when_not_enabled_first(

     def test_kv_cache_batch_size_exceeded(
         self,
-        input_max_bs_exceeded: Tensor,
+        input_max_bs_exceeded: torch.Tensor,
         decoder_with_kv_cache_enabled: TransformerDecoder,
     ) -> None:
         with pytest.raises(ValueError):
10 changes: 5 additions & 5 deletions torchtune/models/clip/_position_embeddings.py
@@ -42,7 +42,7 @@ def __init__(self, embed_dim: int, tile_size: int, patch_size: int) -> None:
     def forward(self, x: torch.Tensor, *args: Tuple[Any]) -> torch.Tensor:
         """
         Args:
-            x (torch.Tensor): Tensor with shape (..., n_tokens, embed_dim)
+            x (torch.Tensor): torch.Tensor with shape (..., n_tokens, embed_dim)
             *args (Tuple[Any]): Optional args.

         Returns:
@@ -103,8 +103,8 @@ def __init__(
     def forward(self, x: torch.Tensor, aspect_ratio: torch.Tensor) -> torch.Tensor:
         """
         Args:
-            x (torch.Tensor): Tensor with shape (bsz * n_imgs, n_tiles, n_tokens, embed_dim).
-            aspect_ratio (torch.Tensor): Tensor with shape (bsz * n_imgs, 2),
+            x (torch.Tensor): torch.Tensor with shape (bsz * n_imgs, n_tiles, n_tokens, embed_dim).
+            aspect_ratio (torch.Tensor): torch.Tensor with shape (bsz * n_imgs, 2),
                 where aspect_ratio[k] represents the aspect ratio of the k^th image
                 of the batch before tile-cropping, e.g. aspect_ratio[k] = (2,1).
         Returns:
@@ -169,8 +169,8 @@ def __init__(
     def forward(self, x: torch.Tensor, aspect_ratio: torch.Tensor) -> torch.Tensor:
         """
         args:
-            x (torch.Tensor): Tensor with shape (bsz * n_imgs, n_tiles, n_tokens, embed_dim).
-            aspect_ratio (torch.Tensor): Tensor with shape (bsz * n_imgs, 2),
+            x (torch.Tensor): torch.Tensor with shape (bsz * n_imgs, n_tiles, n_tokens, embed_dim).
+            aspect_ratio (torch.Tensor): torch.Tensor with shape (bsz * n_imgs, 2),
                 representing the aspect ratio of the image before tile-cropping, e.g. (2,1).
         returns:
             torch.Tensor: The input tensor with added positional embeddings.
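
The docstrings above pin down the expected input shapes for the tiled position embeddings. A shape-only sketch of inputs that match them; the sizes and the commented-out `pos_embedding(...)` call are hypothetical, since the module constructors are not part of this hunk.

    import torch

    bsz, n_imgs, n_tiles, n_tokens, embed_dim = 2, 1, 4, 1601, 1280  # made-up sizes
    x = torch.randn(bsz * n_imgs, n_tiles, n_tokens, embed_dim)
    # One (height_tiles, width_tiles) pair per image before tile-cropping, e.g. a 2x1 tiling.
    aspect_ratio = torch.tensor([[2, 1]] * (bsz * n_imgs))
    # out = pos_embedding(x, aspect_ratio)  # pos_embedding would be one of the modules above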
15 changes: 7 additions & 8 deletions torchtune/models/gemma/transformer.py
@@ -9,7 +9,6 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch import Tensor
 from torchtune.modules import KVCache

 from torchtune.modules.transformer import _get_clones, TransformerSelfAttentionLayer
@@ -101,20 +100,20 @@ def setup_caches(self, batch_size: int, dtype: torch.dtype) -> None:

     def forward(
         self,
-        tokens: Tensor,
+        tokens: torch.Tensor,
         *,
-        mask: Optional[Tensor] = None,
-        input_pos: Optional[Tensor] = None,
-    ) -> Tensor:
+        mask: Optional[torch.Tensor] = None,
+        input_pos: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         """
         Args:
-            tokens (Tensor): input tensor with shape [b x s]
-            mask (Optional[Tensor]): Optional boolean tensor which contains the attention mask
+            tokens (torch.Tensor): input tensor with shape [b x s]
+            mask (Optional[torch.Tensor]): Optional boolean tensor which contains the attention mask
                 with shape [b x s x s]. This is applied after the query-key multiplication and
                 before the softmax. A value of True in row i and column j means token i attends
                 to token j. A value of False means token i does not attend to token j. If no
                 mask is specified, a causal mask is used by default. Default is None.
-            input_pos (Optional[Tensor]): Optional tensor which contains the position ids
+            input_pos (Optional[torch.Tensor]): Optional tensor which contains the position ids
                 of each token. During training, this is used to indicate the positions
                 of each token relative to its sample when packed, shape [b x s].
                 During inference, this indicates the position of the current token.
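
The docstring above spells out the mask and input_pos semantics for the Gemma decoder's forward. A short sketch of constructing matching inputs; the shapes and vocabulary size are illustrative, and the final call is commented out because the decoder construction is not shown in this hunk.

    import torch

    b, s = 2, 16  # hypothetical batch size and sequence length
    tokens = torch.randint(0, 32000, (b, s))

    # Boolean mask with shape [b x s x s]; True in row i, column j means token i attends to token j.
    causal = torch.tril(torch.ones(s, s, dtype=torch.bool))
    mask = causal.unsqueeze(0).expand(b, s, s)

    # Position ids with shape [b x s], e.g. simple 0..s-1 positions per sample.
    input_pos = torch.arange(s).unsqueeze(0).expand(b, s)

    # logits = decoder(tokens, mask=mask, input_pos=input_pos)  # decoder: the Gemma transformer above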