From 70440446a4acf53e05cf7d74988fab21c8fd32e3 Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Mon, 2 Sep 2024 19:46:51 +0100 Subject: [PATCH] Standardize `torch.Tensor` typing (#1471) --- docs/source/tutorials/lora_finetune.rst | 5 +- docs/source/tutorials/qlora_finetune.rst | 5 +- .../llama2/scripts/compare_fused_attention.py | 11 +-- tests/torchtune/modules/test_attention.py | 24 +++--- tests/torchtune/modules/test_feed_forward.py | 6 +- .../modules/test_transformer_decoder.py | 26 +++--- torchtune/models/clip/_position_embeddings.py | 10 +-- torchtune/models/gemma/transformer.py | 15 ++-- .../models/llama3_1/_position_embeddings.py | 13 +-- torchtune/models/phi3/_position_embeddings.py | 10 ++- .../models/qwen2/_positional_embeddings.py | 10 ++- torchtune/modules/attention.py | 20 ++--- torchtune/modules/feed_forward.py | 5 +- torchtune/modules/kv_cache.py | 14 ++-- torchtune/modules/low_precision/nf4_linear.py | 5 +- torchtune/modules/model_fusion/_fusion.py | 28 +++---- torchtune/modules/peft/dora.py | 6 +- torchtune/modules/peft/lora.py | 9 +- torchtune/modules/position_embeddings.py | 13 +-- torchtune/modules/rlhf/rewards.py | 2 +- torchtune/modules/rms_norm.py | 8 +- torchtune/modules/tanh_gate.py | 8 +- torchtune/modules/transformer.py | 84 +++++++++---------- .../transforms/vision_utils/tile_crop.py | 2 +- torchtune/modules/vision_transformer.py | 4 +- torchtune/training/metric_logging.py | 7 +- torchtune/utils/_distributed.py | 2 +- 27 files changed, 182 insertions(+), 170 deletions(-) diff --git a/docs/source/tutorials/lora_finetune.rst b/docs/source/tutorials/lora_finetune.rst index b625a6ada4..31ca61a137 100644 --- a/docs/source/tutorials/lora_finetune.rst +++ b/docs/source/tutorials/lora_finetune.rst @@ -84,7 +84,8 @@ Let's take a look at a minimal implementation of LoRA in native PyTorch. .. code-block:: python - from torch import nn, Tensor + import torch + from torch import nn class LoRALinear(nn.Module): def __init__( @@ -114,7 +115,7 @@ Let's take a look at a minimal implementation of LoRA in native PyTorch. self.lora_a.weight.requires_grad = True self.lora_b.weight.requires_grad = True - def forward(self, x: Tensor) -> Tensor: + def forward(self, x: torch.Tensor) -> torch.Tensor: # This would be the output of the original model frozen_out = self.linear(x) diff --git a/docs/source/tutorials/qlora_finetune.rst b/docs/source/tutorials/qlora_finetune.rst index 6237ddc2b8..ff887bc39f 100644 --- a/docs/source/tutorials/qlora_finetune.rst +++ b/docs/source/tutorials/qlora_finetune.rst @@ -217,7 +217,8 @@ a vanilla minimal LoRA layer, taken from :ref:`the LoRA tutorial Tensor: + def forward(self, x: torch.Tensor) -> torch.Tensor: # frozen_out would be the output of the original model if quantize_base: # Call into torchao's linear_nf4 to run linear forward pass w/quantized weight. diff --git a/tests/torchtune/models/llama2/scripts/compare_fused_attention.py b/tests/torchtune/models/llama2/scripts/compare_fused_attention.py index e6cd483f7f..328d1c528f 100644 --- a/tests/torchtune/models/llama2/scripts/compare_fused_attention.py +++ b/tests/torchtune/models/llama2/scripts/compare_fused_attention.py @@ -11,6 +11,7 @@ from torch import nn, Tensor from torchtune.modules import KVCache, MultiHeadAttention, RotaryPositionalEmbeddings + # Copy-paste of fused attention for comparison class FusedMultiHeadAttention(nn.Module): """Multi-headed grouped query self-attention (GQA) layer introduced @@ -115,15 +116,15 @@ def __init__( def forward( self, - x: Tensor, - mask: Optional[Tensor] = None, + x: torch.Tensor, + mask: Optional[torch.Tensor] = None, curr_pos: int = 0, - ) -> Tensor: + ) -> torch.Tensor: """ Args: x (Tensor): input tensor with shape [batch_size x seq_length x embed_dim] - mask (Optional[Tensor]): boolean mask, defaults to None. + mask (Optional[torch.Tensor]): boolean mask, defaults to None. curr_pos (int): current position in the sequence, defaults to 0. Returns: @@ -241,7 +242,7 @@ def map_state_dict( return mapped_sd -def _get_mask(inpt: Tensor) -> Tensor: +def _get_mask(inpt: torch.Tensor) -> torch.Tensor: seq_len = inpt.shape[1] mask = torch.full((1, 1, seq_len, seq_len), float("-inf"), device=inpt.device) mask = torch.triu(mask, diagonal=1).type_as(inpt) diff --git a/tests/torchtune/modules/test_attention.py b/tests/torchtune/modules/test_attention.py index dc6466e099..47e5653427 100644 --- a/tests/torchtune/modules/test_attention.py +++ b/tests/torchtune/modules/test_attention.py @@ -11,7 +11,7 @@ import torch from tests.test_utils import assert_expected, fixed_init_model -from torch import nn, Tensor +from torch import nn from torchtune.modules import KVCache, MultiHeadAttention, RotaryPositionalEmbeddings from torchtune.utils.seed import set_seed @@ -40,7 +40,7 @@ def input_params(self) -> Tuple[int, int, int]: return batch_size, seq_len, embed_dim @pytest.fixture - def input(self, input_params: Tuple[int, int, int]) -> Tensor: + def input(self, input_params: Tuple[int, int, int]) -> torch.Tensor: batch_size, seq_len, embed_dim = input_params x = torch.randn(batch_size, seq_len, embed_dim) return x @@ -58,7 +58,7 @@ def input_max_len_exceeded( self, input_params: Tuple[int, int, int], attn_params_gqa: Tuple[int, int, int, int], - ) -> Tensor: + ) -> torch.Tensor: batch_size, seq_len, embed_dim = input_params _, _, _, max_seq_len = attn_params_gqa seq_len = max_seq_len + 1 @@ -69,7 +69,7 @@ def input_max_bs_exceeded( self, input_params: Tuple[int, int, int], attn_params_gqa: Tuple[int, int, int, int], - ) -> Tensor: + ) -> torch.Tensor: batch_size, seq_len, embed_dim = input_params _, _, _, max_seq_len = attn_params_gqa batch_size += 1 @@ -253,7 +253,7 @@ def mqa_kv_cache( attn.eval() return attn - def test_forward_gqa(self, input: Tensor, gqa: MultiHeadAttention) -> None: + def test_forward_gqa(self, input: torch.Tensor, gqa: MultiHeadAttention) -> None: with torch.no_grad(): output = gqa(input) assert_expected( @@ -262,7 +262,7 @@ def test_forward_gqa(self, input: Tensor, gqa: MultiHeadAttention) -> None: assert_expected(output.shape, input.shape) def test_forward_gqa_kv_cache( - self, input: Tensor, gqa_kv_cache: MultiHeadAttention, attn_params_gqa + self, input: torch.Tensor, gqa_kv_cache: MultiHeadAttention, attn_params_gqa ) -> None: _, _, _, max_seq_len = attn_params_gqa @@ -279,7 +279,7 @@ def test_forward_gqa_kv_cache( ) assert_expected(output.shape, input.shape) - def test_forward_mha(self, input: Tensor, mha: MultiHeadAttention) -> None: + def test_forward_mha(self, input: torch.Tensor, mha: MultiHeadAttention) -> None: with torch.no_grad(): output = mha(input) assert_expected( @@ -288,7 +288,7 @@ def test_forward_mha(self, input: Tensor, mha: MultiHeadAttention) -> None: assert_expected(output.shape, input.shape) def test_forward_mha_kv_cache( - self, input: Tensor, mha_kv_cache: MultiHeadAttention, attn_params_mha + self, input: torch.Tensor, mha_kv_cache: MultiHeadAttention, attn_params_mha ) -> None: _, _, _, max_seq_len = attn_params_mha @@ -305,7 +305,7 @@ def test_forward_mha_kv_cache( ) assert_expected(output.shape, input.shape) - def test_forward_mqa(self, input: Tensor, mqa: MultiHeadAttention) -> None: + def test_forward_mqa(self, input: torch.Tensor, mqa: MultiHeadAttention) -> None: with torch.no_grad(): output = mqa(input) assert_expected( @@ -314,7 +314,7 @@ def test_forward_mqa(self, input: Tensor, mqa: MultiHeadAttention) -> None: assert_expected(output.shape, input.shape) def test_forward_mqa_kv_cache( - self, input: Tensor, mqa_kv_cache: MultiHeadAttention, attn_params_mqa + self, input: torch.Tensor, mqa_kv_cache: MultiHeadAttention, attn_params_mqa ) -> None: _, _, _, max_seq_len = attn_params_mqa _, seq_len, _ = input.shape @@ -332,7 +332,7 @@ def test_forward_mqa_kv_cache( def test_max_seq_len_exceeded( self, - input_max_len_exceeded: Tensor, + input_max_len_exceeded: torch.Tensor, gqa: MultiHeadAttention, ) -> None: with pytest.raises(Exception): @@ -340,7 +340,7 @@ def test_max_seq_len_exceeded( def test_batch_size_exceeded( self, - input_max_bs_exceeded: Tensor, + input_max_bs_exceeded: torch.Tensor, gqa_kv_cache: MultiHeadAttention, ) -> None: with pytest.raises(Exception): diff --git a/tests/torchtune/modules/test_feed_forward.py b/tests/torchtune/modules/test_feed_forward.py index 53b1aed593..050fb44ed4 100644 --- a/tests/torchtune/modules/test_feed_forward.py +++ b/tests/torchtune/modules/test_feed_forward.py @@ -11,7 +11,7 @@ import torch from tests.test_utils import assert_expected, fixed_init_model -from torch import nn, Tensor +from torch import nn from torchtune.modules import FeedForward from torchtune.utils.seed import set_seed @@ -32,7 +32,7 @@ def input_params(self) -> Tuple[int, int]: return dim, hidden_dim @pytest.fixture - def input(self, input_params: Tuple[int, int]) -> Tensor: + def input(self, input_params: Tuple[int, int]) -> torch.Tensor: dim, _ = input_params return torch.randn(1, dim) @@ -49,7 +49,7 @@ def ffn(self, input_params: Tuple[int, int]) -> FeedForward: ff.eval() return ff - def test_forward(self, input: Tensor, ffn: FeedForward) -> None: + def test_forward(self, input: torch.Tensor, ffn: FeedForward) -> None: with torch.no_grad(): x_out = ffn(input) assert_expected(x_out.mean(), torch.tensor(251.5356), atol=1e-7, rtol=1e-3) diff --git a/tests/torchtune/modules/test_transformer_decoder.py b/tests/torchtune/modules/test_transformer_decoder.py index 3cc9ff8c0b..e7bd9c5197 100644 --- a/tests/torchtune/modules/test_transformer_decoder.py +++ b/tests/torchtune/modules/test_transformer_decoder.py @@ -12,7 +12,7 @@ from tests.test_utils import assert_expected -from torch import nn, Tensor +from torch import nn from torchtune.models.llama2 import llama2 from torchtune.models.llama2._component_builders import llama2_mlp @@ -54,7 +54,7 @@ def input_params(self) -> Tuple[int, int, int]: return batch_size, seq_len, embed_dim @pytest.fixture - def input(self, input_params: Tuple[int, int, int]) -> Tensor: + def input(self, input_params: Tuple[int, int, int]) -> torch.Tensor: batch_size, seq_len, embed_dim = input_params return torch.randn(batch_size, seq_len, embed_dim) @@ -100,7 +100,7 @@ def transformer_layer( return transformer_layer def test_forward( - self, input: Tensor, transformer_layer: TransformerSelfAttentionLayer + self, input: torch.Tensor, transformer_layer: TransformerSelfAttentionLayer ) -> None: with torch.no_grad(): output = transformer_layer(input) @@ -125,7 +125,7 @@ def input_params(self) -> Tuple[int, int, int, int]: return batch_size, seq_len, encoder_seq_len, embed_dim @pytest.fixture - def input(self, input_params: Tuple[int, int, int, int]) -> Tensor: + def input(self, input_params: Tuple[int, int, int, int]) -> torch.Tensor: batch_size, seq_len, encoder_seq_len, embed_dim = input_params rand_x = torch.randn(batch_size, seq_len, embed_dim) rand_y = torch.randn(batch_size, 128, embed_dim) @@ -185,7 +185,7 @@ def transformer_layer( def test_forward( self, - input: [Tensor, Tensor, Tensor], + input: [torch.Tensor, torch.Tensor, torch.Tensor], transformer_layer: TransformerSelfAttentionLayer, ) -> None: input_x, input_y, mask = input @@ -215,7 +215,7 @@ def input_params(self) -> Tuple[int, int, int]: return batch_size, seq_len, vocab_size @pytest.fixture - def input(self, input_params: Tuple[int, int, int]) -> Tensor: + def input(self, input_params: Tuple[int, int, int]) -> torch.Tensor: batch_size, seq_len, vocab_size = input_params return torch.randint(low=0, high=vocab_size, size=(batch_size, seq_len)) @@ -234,7 +234,7 @@ def input_max_len_exceeded( self, input_params: Tuple[int, int, int], decoder_params: Tuple[int, int, int, int, int, int], - ) -> Tensor: + ) -> torch.Tensor: batch_size, seq_len, vocab_size = input_params _, _, _, _, max_seq_len, _ = decoder_params seq_len = max_seq_len + 1 @@ -245,7 +245,7 @@ def input_max_bs_exceeded( self, input_params: Tuple[int, int, int], decoder_params: Tuple[int, int, int, int, int, int], - ) -> Tensor: + ) -> torch.Tensor: batch_size, seq_len, vocab_size = input_params _, _, _, _, max_seq_len, _ = decoder_params batch_size = batch_size + 1 @@ -306,7 +306,7 @@ def decoder_with_kv_cache_enabled( def test_forward( self, - input: Tensor, + input: torch.Tensor, input_params: Tuple[int, int, int], decoder: TransformerDecoder, ) -> None: @@ -318,7 +318,7 @@ def test_forward( def test_max_seq_len_exceeded( self, - input_max_len_exceeded: Tensor, + input_max_len_exceeded: torch.Tensor, decoder: TransformerDecoder, ) -> None: with pytest.raises(Exception): @@ -326,7 +326,7 @@ def test_max_seq_len_exceeded( def test_kv_cache( self, - input: Tensor, + input: torch.Tensor, decoder_with_kv_cache_enabled: TransformerDecoder, decoder: TransformerDecoder, ) -> None: @@ -340,7 +340,7 @@ def test_kv_cache( def test_kv_cache_reset_values( self, - input: Tensor, + input: torch.Tensor, decoder_with_kv_cache_enabled: TransformerDecoder, ) -> None: _, seq_len = input.shape @@ -375,7 +375,7 @@ def test_kv_cache_reset_values_fails_when_not_enabled_first( def test_kv_cache_batch_size_exceeded( self, - input_max_bs_exceeded: Tensor, + input_max_bs_exceeded: torch.Tensor, decoder_with_kv_cache_enabled: TransformerDecoder, ) -> None: with pytest.raises(ValueError): diff --git a/torchtune/models/clip/_position_embeddings.py b/torchtune/models/clip/_position_embeddings.py index 05897aaccf..580856cd1e 100644 --- a/torchtune/models/clip/_position_embeddings.py +++ b/torchtune/models/clip/_position_embeddings.py @@ -42,7 +42,7 @@ def __init__(self, embed_dim: int, tile_size: int, patch_size: int) -> None: def forward(self, x: torch.Tensor, *args: Tuple[Any]) -> torch.Tensor: """ Args: - x (torch.Tensor): Tensor with shape (..., n_tokens, embed_dim) + x (torch.Tensor): torch.Tensor with shape (..., n_tokens, embed_dim) *args (Tuple[Any]): Optional args. Returns: @@ -103,8 +103,8 @@ def __init__( def forward(self, x: torch.Tensor, aspect_ratio: torch.Tensor) -> torch.Tensor: """ Args: - x (torch.Tensor): Tensor with shape (bsz * n_imgs, n_tiles, n_tokens, embed_dim). - aspect_ratio (torch.Tensor): Tensor with shape (bsz * n_imgs, 2), + x (torch.Tensor): torch.Tensor with shape (bsz * n_imgs, n_tiles, n_tokens, embed_dim). + aspect_ratio (torch.Tensor): torch.Tensor with shape (bsz * n_imgs, 2), where aspect_ratio[k] represents the aspect ratio of the k^th image of the batch before tile-cropping, e.g. aspect_ratio[k] = (2,1). Returns: @@ -169,8 +169,8 @@ def __init__( def forward(self, x: torch.Tensor, aspect_ratio: torch.Tensor) -> torch.Tensor: """ args: - x (torch.Tensor): Tensor with shape (bsz * n_imgs, n_tiles, n_tokens, embed_dim). - aspect_ratio (torch.Tensor): Tensor with shape (bsz * n_imgs, 2), + x (torch.Tensor): torch.Tensor with shape (bsz * n_imgs, n_tiles, n_tokens, embed_dim). + aspect_ratio (torch.Tensor): torch.Tensor with shape (bsz * n_imgs, 2), representing the aspect ratio of the image before tile-cropping, e.g. (2,1). returns: torch.Tensor: The input tensor with added positional embeddings. diff --git a/torchtune/models/gemma/transformer.py b/torchtune/models/gemma/transformer.py index 2ed5d78962..1b4c68f40b 100644 --- a/torchtune/models/gemma/transformer.py +++ b/torchtune/models/gemma/transformer.py @@ -9,7 +9,6 @@ import torch import torch.nn as nn import torch.nn.functional as F -from torch import Tensor from torchtune.modules import KVCache from torchtune.modules.transformer import _get_clones, TransformerSelfAttentionLayer @@ -101,20 +100,20 @@ def setup_caches(self, batch_size: int, dtype: torch.dtype) -> None: def forward( self, - tokens: Tensor, + tokens: torch.Tensor, *, - mask: Optional[Tensor] = None, - input_pos: Optional[Tensor] = None, - ) -> Tensor: + mask: Optional[torch.Tensor] = None, + input_pos: Optional[torch.Tensor] = None, + ) -> torch.Tensor: """ Args: - tokens (Tensor): input tensor with shape [b x s] - mask (Optional[Tensor]): Optional boolean tensor which contains the attention mask + tokens (torch.Tensor): input tensor with shape [b x s] + mask (Optional[torch.Tensor]): Optional boolean tensor which contains the attention mask with shape [b x s x s]. This is applied after the query-key multiplication and before the softmax. A value of True in row i and column j means token i attends to token j. A value of False means token i does not attend to token j. If no mask is specified, a causal mask is used by default. Default is None. - input_pos (Optional[Tensor]): Optional tensor which contains the position ids + input_pos (Optional[torch.Tensor]): Optional tensor which contains the position ids of each token. During training, this is used to indicate the positions of each token relative to its sample when packed, shape [b x s]. During inference, this indicates the position of the current token. diff --git a/torchtune/models/llama3_1/_position_embeddings.py b/torchtune/models/llama3_1/_position_embeddings.py index 21e3788964..8547919cd8 100644 --- a/torchtune/models/llama3_1/_position_embeddings.py +++ b/torchtune/models/llama3_1/_position_embeddings.py @@ -9,7 +9,7 @@ import torch -from torch import nn, Tensor +from torch import nn class Llama3ScaledRoPE(nn.Module): @@ -74,7 +74,8 @@ def build_rope_cache(self, max_seq_len: int = 4096) -> None: def apply_scaling(self, freqs: torch.Tensor): """From the following Meta-Llama code: - https://github.com/meta-llama/llama-models/blob/dc42f22a3b05502e7296402b019a51f57fa045c9/models/llama3_1/api/model.py#L41""" + https://github.com/meta-llama/llama-models/blob/dc42f22a3b05502e7296402b019a51f57fa045c9/models/llama3_1/api/model.py#L41 + """ # Values obtained from grid search scale_factor = 8 low_freq_factor = 1 @@ -98,12 +99,14 @@ def apply_scaling(self, freqs: torch.Tensor): new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq) return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device) - def forward(self, x: Tensor, *, input_pos: Optional[Tensor] = None) -> Tensor: + def forward( + self, x: torch.Tensor, *, input_pos: Optional[torch.Tensor] = None + ) -> torch.Tensor: """ Args: - x (Tensor): input tensor with shape + x (torch.Tensor): input tensor with shape [b, s, n_h, h_d] - input_pos (Optional[Tensor]): Optional tensor which contains the position ids + input_pos (Optional[torch.Tensor]): Optional tensor which contains the position ids of each token. During training, this is used to indicate the positions of each token relative to its sample when packed, shape [b, s]. During inference, this indicates the position of the current token. diff --git a/torchtune/models/phi3/_position_embeddings.py b/torchtune/models/phi3/_position_embeddings.py index 271f6eea9d..3a935147fc 100644 --- a/torchtune/models/phi3/_position_embeddings.py +++ b/torchtune/models/phi3/_position_embeddings.py @@ -8,7 +8,7 @@ import torch -from torch import nn, Tensor +from torch import nn class Phi3RotaryPositionalEmbeddings(nn.Module): @@ -65,12 +65,14 @@ def build_rope_cache(self, max_seq_len: int = 4096) -> None: cache = torch.cat([freqs.cos(), freqs.sin()], dim=-1) self.register_buffer("cache", cache, persistent=False) - def forward(self, x: Tensor, input_pos: Optional[Tensor] = None) -> Tensor: + def forward( + self, x: torch.Tensor, input_pos: Optional[torch.Tensor] = None + ) -> torch.Tensor: """ Args: - x (Tensor): input tensor with shape + x (torch.Tensor): input tensor with shape [b, s, n_h, h_d] - input_pos (Optional[Tensor]): Optional tensor which contains the position ids + input_pos (Optional[torch.Tensor]): Optional tensor which contains the position ids of each token. During training, this is used to indicate the positions of each token relative to its sample when packed, shape [b, s]. During inference, this indicates the position of the current token. diff --git a/torchtune/models/qwen2/_positional_embeddings.py b/torchtune/models/qwen2/_positional_embeddings.py index 78ad17a43a..61e8682783 100644 --- a/torchtune/models/qwen2/_positional_embeddings.py +++ b/torchtune/models/qwen2/_positional_embeddings.py @@ -8,7 +8,7 @@ import torch -from torch import nn, Tensor +from torch import nn class Qwen2RotaryPositionalEmbeddings(nn.Module): @@ -65,12 +65,14 @@ def build_rope_cache(self, max_seq_len: int = 4096) -> None: cache = torch.cat([freqs.cos(), freqs.sin()], dim=-1) self.register_buffer("cache", cache, persistent=False) - def forward(self, x: Tensor, input_pos: Optional[Tensor] = None) -> Tensor: + def forward( + self, x: torch.Tensor, input_pos: Optional[torch.Tensor] = None + ) -> torch.Tensor: """ Args: - x (Tensor): input tensor with shape + x (torch.Tensor): input tensor with shape [b, s, n_h, h_d] - input_pos (Optional[Tensor]): Optional tensor which contains the position ids + input_pos (Optional[torch.Tensor]): Optional tensor which contains the position ids of each token. During training, this is used to indicate the positions of each token relative to its sample when packed, shape [b, s]. During inference, this indicates the position of the current token. diff --git a/torchtune/modules/attention.py b/torchtune/modules/attention.py index 99ddb17b1c..354f4943f1 100644 --- a/torchtune/modules/attention.py +++ b/torchtune/modules/attention.py @@ -8,7 +8,7 @@ from typing import Optional import torch -from torch import nn, Tensor +from torch import nn from torchtune.modules.kv_cache import KVCache logger = logging.getLogger(__name__) @@ -168,23 +168,23 @@ def reset_cache(self): def forward( self, - x: Tensor, - y: Optional[Tensor] = None, + x: torch.Tensor, + y: Optional[torch.Tensor] = None, *, - mask: Optional[Tensor] = None, - input_pos: Optional[Tensor] = None, - ) -> Tensor: + mask: Optional[torch.Tensor] = None, + input_pos: Optional[torch.Tensor] = None, + ) -> torch.Tensor: """ Args: - x (Tensor): input tensor with shape [b x s_x x d] - y (Optional[Tensor]): second input tensor for cross attention with shape [b x s_y x d] - mask (Optional[Tensor]): Optional boolean tensor which contains the attention mask + x (torch.Tensor): input tensor with shape [b x s_x x d] + y (Optional[torch.Tensor]): second input tensor for cross attention with shape [b x s_y x d] + mask (Optional[torch.Tensor]): Optional boolean tensor which contains the attention mask with shape [batch_size x seq_length x seq_length]. This is applied after the query-key multiplication and before the softmax. A value of True in row i and column j means token i attends to token j. A value of False means token i does not attend to token j. If no mask is specified, a causal mask is used by default. Default is None. - input_pos (Optional[Tensor]): Optional tensor which contains the position ids + input_pos (Optional[torch.Tensor]): Optional tensor which contains the position ids of each token. During training, this is used to indicate the positions of each token relative to its sample when packed, shape [b x s]. During inference, this indicates the position of the current token. diff --git a/torchtune/modules/feed_forward.py b/torchtune/modules/feed_forward.py index c69cd17ae6..fedb7bb608 100644 --- a/torchtune/modules/feed_forward.py +++ b/torchtune/modules/feed_forward.py @@ -4,7 +4,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from torch import nn, Tensor +import torch +from torch import nn class FeedForward(nn.Module): @@ -33,5 +34,5 @@ def __init__( self.w3 = up_proj self.activation = activation - def forward(self, x: Tensor) -> Tensor: + def forward(self, x: torch.Tensor) -> torch.Tensor: return self.w2(self.activation(self.w1(x)) * self.w3(x)) diff --git a/torchtune/modules/kv_cache.py b/torchtune/modules/kv_cache.py index 1ad55a4b8e..06b85898e8 100644 --- a/torchtune/modules/kv_cache.py +++ b/torchtune/modules/kv_cache.py @@ -7,7 +7,7 @@ from typing import Tuple import torch -from torch import nn, Tensor +from torch import nn class KVCache(nn.Module): @@ -49,19 +49,19 @@ def reset(self) -> None: self.v_cache.zero_() def update( - self, input_pos: Tensor, k_val: Tensor, v_val: Tensor - ) -> Tuple[Tensor, Tensor]: + self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: """Update KV cache with the new k_val, v_val and return the updated cache. Raises an assertion error if ``input_pos`` is longer than the maximum sequence length. Args: - input_pos (Tensor): Current position tensor with shape [S] - k_val (Tensor): Current key tensor with shape [B, H, S, D] - v_val (Tensor): Current value tensor with shape [B, H, S, D] + input_pos (torch.Tensor): Current position tensor with shape [S] + k_val (torch.Tensor): Current key tensor with shape [B, H, S, D] + v_val (torch.Tensor): Current value tensor with shape [B, H, S, D] Returns: - Tuple[Tensor, Tensor]: Updated KV cache with key first + Tuple[torch.Tensor, torch.Tensor]: Updated KV cache with key first """ assert input_pos.shape[0] == k_val.shape[2] self.size = input_pos.max().item() + 1 diff --git a/torchtune/modules/low_precision/nf4_linear.py b/torchtune/modules/low_precision/nf4_linear.py index 6626688d45..9b0eaf53a3 100644 --- a/torchtune/modules/low_precision/nf4_linear.py +++ b/torchtune/modules/low_precision/nf4_linear.py @@ -9,7 +9,6 @@ import torch import torch.nn as nn -from torch import Tensor from torchao.dtypes.nf4tensor import linear_nf4, to_nf4 @@ -47,13 +46,13 @@ def __init__( self.weight, torch.nn.Parameter(self.nf4_weight, requires_grad=False) ) - def forward(self, input: Tensor) -> Tensor: + def forward(self, input: torch.Tensor) -> torch.Tensor: """ Runs linear operation with input tensor as given by `input`. Computation happens in higher precision, though only the nf4 weight is saved for backward for gradient computation to ensure additional memory is not used. Args: - input (Tensor): input tensor + input (torch.Tensor): input tensor Returns: Tensor: output tensor diff --git a/torchtune/modules/model_fusion/_fusion.py b/torchtune/modules/model_fusion/_fusion.py index 689d823dec..e3a9d708b6 100644 --- a/torchtune/modules/model_fusion/_fusion.py +++ b/torchtune/modules/model_fusion/_fusion.py @@ -7,7 +7,7 @@ from typing import Dict, List, Optional, Union import torch -from torch import nn, Tensor +from torch import nn from torchtune.modules import TransformerDecoder @@ -116,10 +116,10 @@ def fusion_params(self) -> List[str]: ] return fusion_params - def forward(self, x: Tensor, **kwargs: Dict) -> Tensor: + def forward(self, x: torch.Tensor, **kwargs: Dict) -> torch.Tensor: """ Args: - x (Tensor): input tensor with shape + x (torch.Tensor): input tensor with shape [batch_size x seq_length x embed_dim] **kwargs (Dict): all additional layer args @@ -219,10 +219,10 @@ def _fused_embed(self, bs, seq_len): dtype = self.embedding.weight.dtype return torch.empty(bs, seq_len, self.dim, device=device, dtype=dtype) - def forward(self, input: Tensor) -> Tensor: + def forward(self, input: torch.Tensor) -> torch.Tensor: """ Args: - input (Tensor): input integer tensor with shape + input (torch.Tensor): input integer tensor with shape [batch_size x seq_length] Returns: @@ -323,26 +323,26 @@ def reset_caches(self): def forward( self, - tokens: Tensor, + tokens: torch.Tensor, *, - mask: Optional[Tensor] = None, + mask: Optional[torch.Tensor] = None, encoder_input: Optional[Dict] = None, - encoder_mask: Optional[Tensor] = None, - input_pos: Optional[Tensor] = None, - ) -> Union[Tensor, List[Tensor]]: + encoder_mask: Optional[torch.Tensor] = None, + input_pos: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, List[torch.Tensor]]: """ Args: - tokens (Tensor): input tensor with shape [b x s] - mask (Optional[Tensor]): Optional boolean tensor which contains the attention mask + tokens (torch.Tensor): input tensor with shape [b x s] + mask (Optional[torch.Tensor]): Optional boolean tensor which contains the attention mask with shape [b x s x s]. This is applied after the query-key multiplication and before the softmax. A value of True in row i and column j means token i attends to token j. A value of False means token i does not attend to token j. If no mask is specified, a causal mask is used by default. Default is None. encoder_input (Optional[Dict]): Optional input for the encoder. - encoder_mask (Optional[Tensor]): Boolean tensor defining a relational matrix between + encoder_mask (Optional[torch.Tensor]): Boolean tensor defining a relational matrix between tokens and encoder embeddings. A True value at position i,j means token i can attend to embedding j in the decoder. Mask has shape [b x s x s_e]. Default is None. - input_pos (Optional[Tensor]): Optional tensor which contains the position ids + input_pos (Optional[torch.Tensor]): Optional tensor which contains the position ids of each token. During training, this is used to indicate the positions of each token relative to its sample when packed, shape [b x s]. During inference, this indicates the position of the current token. diff --git a/torchtune/modules/peft/dora.py b/torchtune/modules/peft/dora.py index 25ff63e609..d8ef8016b1 100644 --- a/torchtune/modules/peft/dora.py +++ b/torchtune/modules/peft/dora.py @@ -10,7 +10,7 @@ import torch import torch.nn.functional as F -from torch import nn, Tensor +from torch import nn from torchao.dtypes.nf4tensor import linear_nf4, to_nf4 from torchtune.modules.low_precision import _register_nf4_dispatch_ops # noqa: F401 @@ -113,10 +113,10 @@ def adapter_params(self) -> List[str]: adapter_params = ["lora_a.weight", "lora_b.weight", "magnitude"] return adapter_params - def forward(self, x: Tensor) -> Tensor: + def forward(self, x: torch.Tensor) -> torch.Tensor: """ Args: - x (Tensor): input tensor with shape ``(..., in_dim)`` + x (torch.Tensor): input tensor with shape ``(..., in_dim)`` Returns: Tensor: output tensor with shape ``(..., out_dim)`` diff --git a/torchtune/modules/peft/lora.py b/torchtune/modules/peft/lora.py index 7c542deb17..9ecc676db3 100644 --- a/torchtune/modules/peft/lora.py +++ b/torchtune/modules/peft/lora.py @@ -6,9 +6,10 @@ import math from typing import List +import torch import torch.nn.functional as F -from torch import nn, Tensor +from torch import nn from torchao.dtypes.nf4tensor import linear_nf4, to_nf4 from torchtune.modules.low_precision import _register_nf4_dispatch_ops # noqa: F401 @@ -111,13 +112,13 @@ def adapter_params(self) -> List[str]: adapter_params = ["lora_a.weight", "lora_b.weight"] return adapter_params - def forward(self, x: Tensor) -> Tensor: + def forward(self, x: torch.Tensor) -> torch.Tensor: """ Args: - x (Tensor): input tensor with shape ``(..., in_dim)`` + x (torch.Tensor): input tensor with shape ``(..., in_dim)`` Returns: - Tensor: output tensor with shape ``(..., out_dim)`` + torch.Tensor: output tensor with shape ``(..., out_dim)`` """ if self._quantize_base: diff --git a/torchtune/modules/position_embeddings.py b/torchtune/modules/position_embeddings.py index 194b75ca9f..cd928730b0 100644 --- a/torchtune/modules/position_embeddings.py +++ b/torchtune/modules/position_embeddings.py @@ -7,8 +7,7 @@ from typing import Optional import torch - -from torch import nn, Tensor +from torch import nn class RotaryPositionalEmbeddings(nn.Module): @@ -72,19 +71,21 @@ def build_rope_cache(self, max_seq_len: int = 4096) -> None: cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1) self.register_buffer("cache", cache, persistent=False) - def forward(self, x: Tensor, *, input_pos: Optional[Tensor] = None) -> Tensor: + def forward( + self, x: torch.Tensor, *, input_pos: Optional[torch.Tensor] = None + ) -> torch.Tensor: """ Args: - x (Tensor): input tensor with shape + x (torch.Tensor): input tensor with shape [b, s, n_h, h_d] - input_pos (Optional[Tensor]): Optional tensor which contains the position ids + input_pos (Optional[torch.Tensor]): Optional tensor which contains the position ids of each token. During training, this is used to indicate the positions of each token relative to its sample when packed, shape [b, s]. During inference, this indicates the position of the current token. If none, assume the index of the token is its position id. Default is None. Returns: - Tensor: output tensor with RoPE applied + torch.Tensor: output tensor with RoPE applied Notation used for tensor shapes: - b: batch size diff --git a/torchtune/modules/rlhf/rewards.py b/torchtune/modules/rlhf/rewards.py index 0e5994ff1d..4abb0742cf 100644 --- a/torchtune/modules/rlhf/rewards.py +++ b/torchtune/modules/rlhf/rewards.py @@ -26,7 +26,7 @@ def get_reward_penalty_mask( - If ``penalise_no_eos`` is True, scores for sequences with no EOS token are penalised. Args: - padding_masks (torch.Tensor): Tensor where True indicates a padding token in the generated + padding_masks (torch.Tensor): torch.Tensor where True indicates a padding token in the generated sequence, and False otherwise. Shape: (b, reponse_len) seq_lens (torch.Tensor): The length of each generated sequence. Shape: (b,) penalise_no_eos (bool, optional): Whether to penalise sequences with no EOS token. Defaults to True. diff --git a/torchtune/modules/rms_norm.py b/torchtune/modules/rms_norm.py index a2e4e2a7df..78e3e0a316 100644 --- a/torchtune/modules/rms_norm.py +++ b/torchtune/modules/rms_norm.py @@ -6,7 +6,7 @@ import torch -from torch import nn, Tensor +from torch import nn class RMSNorm(nn.Module): @@ -28,13 +28,13 @@ def __init__(self, dim: int, eps: float = 1e-6) -> None: self.eps = eps self.scale = nn.Parameter(torch.ones(dim)) - def forward(self, x: Tensor) -> Tensor: + def forward(self, x: torch.Tensor) -> torch.Tensor: """ Args: - x (Tensor): input tensor to normalize + x (torch.Tensor): input tensor to normalize Returns: - Tensor: The output tensor after applying RMSNorm. + torch.Tensor: The output tensor after applying RMSNorm. """ # computation is in fp32 x_fp32 = x.float() diff --git a/torchtune/modules/tanh_gate.py b/torchtune/modules/tanh_gate.py index 29a4813967..f877ad6776 100644 --- a/torchtune/modules/tanh_gate.py +++ b/torchtune/modules/tanh_gate.py @@ -6,7 +6,7 @@ import torch -from torch import nn, Tensor +from torch import nn class TanhGate(nn.Module): @@ -16,12 +16,12 @@ def __init__(self) -> None: super().__init__() self.scale = nn.Parameter(torch.zeros(1)) - def forward(self, x: Tensor) -> Tensor: + def forward(self, x: torch.Tensor) -> torch.Tensor: """ Args: - x (Tensor): input tensor to gate + x (torch.Tensor): input tensor to gate Returns: - Tensor: The output tensor after gating. + torch.Tensor: The output tensor after gating. """ return x * self.scale.tanh() diff --git a/torchtune/modules/transformer.py b/torchtune/modules/transformer.py index b9e88bbd05..9a22744424 100644 --- a/torchtune/modules/transformer.py +++ b/torchtune/modules/transformer.py @@ -8,7 +8,7 @@ import torch import torch.nn.functional as F -from torch import nn, Tensor +from torch import nn from torchtune.modules import MultiHeadAttention @@ -63,23 +63,23 @@ def reset_cache(self): def forward( self, - x: Tensor, + x: torch.Tensor, *, - mask: Optional[Tensor] = None, - input_pos: Optional[Tensor] = None, + mask: Optional[torch.Tensor] = None, + input_pos: Optional[torch.Tensor] = None, **kwargs: Dict, - ) -> Tensor: + ) -> torch.Tensor: """ Args: - x (Tensor): input tensor with shape + x (torch.Tensor): input tensor with shape [batch_size x seq_length x embed_dim] - mask (Optional[Tensor]): Optional boolean tensor which contains the attention mask + mask (Optional[torch.Tensor]): Optional boolean tensor which contains the attention mask with shape [batch_size x seq_length x seq_length]. This is applied after the query-key multiplication and before the softmax. A value of True in row i and column j means token i attends to token j. A value of False means token i does not attend to token j. If no mask is specified, a causal mask is used by default. Default is None. - input_pos (Optional[Tensor]): Optional tensor which contains the position ids + input_pos (Optional[torch.Tensor]): Optional tensor which contains the position ids of each token. During training, this is used to indicate the positions of each token relative to its sample when packed, shape [b x s]. During inference, this indicates the position of the current token. @@ -87,7 +87,7 @@ def forward( **kwargs (Dict): transformer layer inputs not relevant to self attention. Returns: - Tensor: output tensor with same shape as input + torch.Tensor: output tensor with same shape as input [batch_size x seq_length x embed_dim] TODO: @@ -166,7 +166,7 @@ def reset_cache(self): """Reset the key value caches.""" self.attn.reset_cache() - def _skip_mask(self, mask: Optional[Tensor]) -> Optional[Tensor]: + def _skip_mask(self, mask: Optional[torch.Tensor]) -> Optional[torch.Tensor]: """Some tokens in x may not attend to any encoder inputs due to the cross attention mask (encoder_mask). This results in a full row of the attention matrix being masked out. @@ -203,26 +203,26 @@ def _skip_mask(self, mask: Optional[Tensor]) -> Optional[Tensor]: def forward( self, - x: Tensor, + x: torch.Tensor, *, - encoder_input: Optional[Tensor] = None, - encoder_mask: Optional[Tensor] = None, + encoder_input: Optional[torch.Tensor] = None, + encoder_mask: Optional[torch.Tensor] = None, **kwargs: Dict, - ) -> Tensor: + ) -> torch.Tensor: """ Args: - x (Tensor): input tensor with shape + x (torch.Tensor): input tensor with shape [batch_size x seq_length x embed_dim] - encoder_input (Optional[Tensor]): Optional input embeds from the encoder. Shape + encoder_input (Optional[torch.Tensor]): Optional input embeds from the encoder. Shape [batch_size x token_sequence x embed_dim] - encoder_mask (Optional[Tensor]): Boolean tensor defining a relational matrix between + encoder_mask (Optional[torch.Tensor]): Boolean tensor defining a relational matrix between tokens and encoder embeddings. A True value at position i,j means token i can attend to embedding j in the decoder. Mask has shape [batch_size x token_sequence x embed_sequence]. Default is None. **kwargs (Dict): transformer layer inputs not relevant to self attention. Returns: - Tensor: output tensor with same shape as input + torch.Tensor: output tensor with same shape as input [batch_size x seq_length x embed_dim] """ # During decoding, it's possible encoder_input is None because the embeds @@ -377,26 +377,26 @@ def reset_caches(self): def forward( self, - tokens: Tensor, + tokens: torch.Tensor, *, - mask: Optional[Tensor] = None, - encoder_input: Optional[Tensor] = None, - encoder_mask: Optional[Tensor] = None, - input_pos: Optional[Tensor] = None, - ) -> Union[Tensor, List[Tensor]]: + mask: Optional[torch.Tensor] = None, + encoder_input: Optional[torch.Tensor] = None, + encoder_mask: Optional[torch.Tensor] = None, + input_pos: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, List[torch.Tensor]]: """ Args: - tokens (Tensor): input tensor with shape [b x s] - mask (Optional[Tensor]): Optional boolean tensor which contains the attention mask + tokens (torch.Tensor): input tensor with shape [b x s] + mask (Optional[torch.Tensor]): Optional boolean tensor which contains the attention mask with shape [b x s x s]. This is applied after the query-key multiplication and before the softmax. A value of True in row i and column j means token i attends to token j. A value of False means token i does not attend to token j. If no mask is specified, a causal mask is used by default. Default is None. - encoder_input (Optional[Tensor]): Optional input embeds from the encoder. Shape [b x s_e x d_e] - encoder_mask (Optional[Tensor]): Boolean tensor defining a relational matrix between + encoder_input (Optional[torch.Tensor]): Optional input embeds from the encoder. Shape [b x s_e x d_e] + encoder_mask (Optional[torch.Tensor]): Boolean tensor defining a relational matrix between tokens and encoder embeddings. A True value at position i,j means token i can attend to embedding j in the decoder. Mask has shape [b x s x s_e]. Default is None. - input_pos (Optional[Tensor]): Optional tensor which contains the position ids + input_pos (Optional[torch.Tensor]): Optional tensor which contains the position ids of each token. During training, this is used to indicate the positions of each token relative to its sample when packed, shape [b x s]. During inference, this indicates the position of the current token. @@ -408,7 +408,7 @@ def forward( KV values for each position. Returns: - Union[Tensor, List[Tensor]]: output tensor with shape [b x s x v] or a list of layer + Union[torch.Tensor, List[torch.Tensor]]: output tensor with shape [b x s x v] or a list of layer output tensors defined by ``output_hidden_states`` with the final output tensor appended to the list. @@ -586,26 +586,26 @@ def reset_caches(self): def forward( self, - tokens: Tensor, + tokens: torch.Tensor, *, - mask: Optional[Tensor] = None, - encoder_input: Optional[Tensor] = None, - encoder_mask: Optional[Tensor] = None, - input_pos: Optional[Tensor] = None, - ) -> Tensor: + mask: Optional[torch.Tensor] = None, + encoder_input: Optional[torch.Tensor] = None, + encoder_mask: Optional[torch.Tensor] = None, + input_pos: Optional[torch.Tensor] = None, + ) -> torch.Tensor: """ Args: - tokens (Tensor): input tensor with shape [b x s] - mask (Optional[Tensor]): Optional boolean tensor which contains the attention mask + tokens (torch.Tensor): input tensor with shape [b x s] + mask (Optional[torch.Tensor]): Optional boolean tensor which contains the attention mask with shape [b x s x s]. This is applied after the query-key multiplication and before the softmax. A value of True in row i and column j means token i attends to token j. A value of False means token i does not attend to token j. If no mask is specified, a causal mask is used by default. Default is None. - encoder_input (Optional[Tensor]): Optional input embeds from the encoder. Shape [b x s_e x d_e] - encoder_mask (Optional[Tensor]): Boolean tensor defining a relational matrix between + encoder_input (Optional[torch.Tensor]): Optional input embeds from the encoder. Shape [b x s_e x d_e] + encoder_mask (Optional[torch.Tensor]): Boolean tensor defining a relational matrix between tokens and encoder embeddings. A True value at position i,j means token i can attend to embedding j in the decoder. Mask has shape [b x s x s_e]. Default is None. - input_pos (Optional[Tensor]): Optional tensor which contains the position ids + input_pos (Optional[torch.Tensor]): Optional tensor which contains the position ids of each token. During training, this is used to indicate the positions of each token relative to its sample when packed, shape [b x s]. During inference, this indicates the position of the current token. @@ -617,7 +617,7 @@ def forward( KV values for each position. Returns: - Tensor: output tensor with shape [b x s x v] or a list of layer + torch.Tensor: output tensor with shape [b x s x v] or a list of layer output tensors defined by ``output_hidden_states`` with the final output tensor appended to the list. diff --git a/torchtune/modules/transforms/vision_utils/tile_crop.py b/torchtune/modules/transforms/vision_utils/tile_crop.py index 17e173c3f7..42cbfd0492 100644 --- a/torchtune/modules/transforms/vision_utils/tile_crop.py +++ b/torchtune/modules/transforms/vision_utils/tile_crop.py @@ -20,7 +20,7 @@ def tile_crop(image: torch.Tensor, tile_size: int) -> torch.Tensor: tile_size (int): Size of each tile. Returns: - torch.Tensor: Tensor of shape [num_tiles, channel_size, tile_size, tile_size] + torch.Tensor: torch.Tensor of shape [num_tiles, channel_size, tile_size, tile_size] Examples: >>> image = torch.rand(3, 200, 300) diff --git a/torchtune/modules/vision_transformer.py b/torchtune/modules/vision_transformer.py index 51801a102d..dc1700f7c8 100644 --- a/torchtune/modules/vision_transformer.py +++ b/torchtune/modules/vision_transformer.py @@ -268,8 +268,8 @@ def forward( Notice that to batch it, you will have to pad n_imgs to max_n_imgs and max_num_tiles. Args: - images (torch.Tensor): Tensor with shape (bsz, n_imgs, n_tiles, n_channels, tile_size, tile_size). - aspect_ratio (Optional[torch.Tensor]): Tensor with shape (bsz, n_imgs, 2). If all + images (torch.Tensor): torch.Tensor with shape (bsz, n_imgs, n_tiles, n_channels, tile_size, tile_size). + aspect_ratio (Optional[torch.Tensor]): torch.Tensor with shape (bsz, n_imgs, 2). If all images have a single tile, i.e. they were not tile-cropped, it should be None. Used to calculate the positional embeddings for the tiles. diff --git a/torchtune/training/metric_logging.py b/torchtune/training/metric_logging.py index b432afd266..8c19aa910d 100644 --- a/torchtune/training/metric_logging.py +++ b/torchtune/training/metric_logging.py @@ -10,15 +10,16 @@ from typing import Any, Dict, List, Mapping, Optional, Union +import torch + from numpy import ndarray from omegaconf import DictConfig, OmegaConf -from torch import Tensor from torchtune.utils import get_logger from torchtune.utils._distributed import get_world_size_and_rank from typing_extensions import Protocol -Scalar = Union[Tensor, ndarray, int, float] +Scalar = Union[torch.Tensor, ndarray, int, float] log = get_logger("DEBUG") @@ -261,7 +262,7 @@ class TensorBoardLogger(MetricLoggerInterface): """Logger for use w/ PyTorch's implementation of TensorBoard (https://pytorch.org/docs/stable/tensorboard.html). Args: - log_dir (str): TensorBoard log directory + log_dir (str): torch.TensorBoard log directory organize_logs (bool): If `True`, this class will create a subdirectory within `log_dir` for the current run. Having sub-directories allows you to compare logs across runs. When TensorBoard is passed a logdir at startup, it recursively walks the directory tree rooted at logdir looking for diff --git a/torchtune/utils/_distributed.py b/torchtune/utils/_distributed.py index 0438576a4d..44093663d2 100644 --- a/torchtune/utils/_distributed.py +++ b/torchtune/utils/_distributed.py @@ -97,7 +97,7 @@ def _broadcast_tensor(tensor: torch.Tensor, src: int = 0) -> torch.Tensor: """Broadcasts a tensor from a source to all other processes. Args: - tensor (torch.Tensor): Tensor to broadcast. + tensor (torch.Tensor): torch.Tensor to broadcast. src (int, optional): Source rank. Defaults to 0. Returns: