diff --git a/tests/torchtune/data/test_collate.py b/tests/torchtune/data/test_collate.py
index c14123fd3d..ca4bccd5d5 100644
--- a/tests/torchtune/data/test_collate.py
+++ b/tests/torchtune/data/test_collate.py
@@ -10,6 +10,7 @@
 import pytest
 import torch
+import torch.nn.functional as F
 
 from tests.test_utils import gpu_test
 from torchtune.data import (
     left_pad_sequence,
@@ -120,15 +121,20 @@ def test_right_pad_sequence(self, batch):
 
     def test_left_pad_sequence(self, batch):
         actual = padded_collate_tiled_images_and_mask(
-            batch=batch, padding_idx=0, ignore_idx=-100, pad_direction="left"
+            batch=batch,
+            padding_idx=0,
+            ignore_idx=-100,
+            pad_direction="left",
+            pad_max_images=4,
         )
 
         mask_1 = torch.concat([torch.ones(4, 5 * 2), torch.zeros(4, 10)], dim=1)
         mask_2 = torch.concat([torch.ones(4, 5 * 3), torch.zeros(4, 5)], dim=1)
-        mask_3 = torch.concat([torch.ones(2, 5 * 4), torch.zeros(2, 20)], dim=0)
+        mask_3 = torch.concat([torch.zeros(2, 20), torch.ones(2, 5 * 4)], dim=0)
         sample_1 = torch.stack([mask_1, mask_2])
         sample_2 = torch.stack([mask_3, torch.zeros(4, 20)])
         expected_mask = torch.stack([sample_1, sample_2]).view(2, 4, -1)
+        expected_mask = F.pad(expected_mask, (0, 40), value=0)
 
         expected = {
             "tokens": torch.tensor([[1, 2, 1, 3], [0, 0, 1, 4]]),
diff --git a/tests/torchtune/modules/transforms/test_transforms.py b/tests/torchtune/modules/transforms/test_transforms.py
index 26acefdcf9..4bdc137dcf 100644
--- a/tests/torchtune/modules/transforms/test_transforms.py
+++ b/tests/torchtune/modules/transforms/test_transforms.py
@@ -10,6 +10,7 @@
 
 
 IMAGE_TOKEN_ID = 1
+MAX_NUM_TILES = 4
 
 
 class TestVisionCrossAttentionMask:
@@ -53,6 +54,7 @@ def cross_attn_mask_transform(self, tile_size, patch_size):
             tile_size=tile_size,
             patch_size=patch_size,
             image_token_id=IMAGE_TOKEN_ID,
+            max_num_tiles=MAX_NUM_TILES,
         )
 
     def test_get_image_attention_intervals(self, cross_attn_mask_transform, tokens):
@@ -78,3 +80,24 @@ def test_call(self, cross_attn_mask_transform, tokens, images, image_num_tokens)
 
         assert actual["tokens"] == tokens
         assert actual["hello"] == dummy_kwargs["hello"]
+
+    def test_inference_call(
+        self, cross_attn_mask_transform, tokens, images, image_num_tokens
+    ):
+        sample = {"tokens": tokens, "encoder_input": {"images": images}}
+        dummy_kwargs = {"hello": 8}
+        sample.update(dummy_kwargs)
+        actual = cross_attn_mask_transform(sample, inference=True)
+        expected = [
+            torch.zeros(len(tokens), image_num_tokens * 2, dtype=torch.bool)
+            for _ in range(len(images))
+        ]
+        expected[0][2:6, :image_num_tokens] = True
+        expected[1][3:6, :image_num_tokens] = True
+        expected[2][6:9, :image_num_tokens] = True
+        for i in range(len(images)):
+            torch.testing.assert_close(actual["encoder_mask"][i], expected[i])
+            torch.testing.assert_close(actual["encoder_input"]["images"][i], images[i])
+
+        assert actual["tokens"] == tokens
+        assert actual["hello"] == dummy_kwargs["hello"]
diff --git a/torchtune/data/_collate.py b/torchtune/data/_collate.py
index fb9a2fc234..7897b1eee1 100644
--- a/torchtune/data/_collate.py
+++ b/torchtune/data/_collate.py
@@ -3,7 +3,7 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
-from typing import Any, Dict, List, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
@@ -223,6 +223,7 @@ def padded_collate_tiled_images_and_mask(
     padding_idx: int = 0,
     ignore_idx: int = CROSS_ENTROPY_IGNORE_IDX,
     pad_direction: str = "right",
+    pad_max_images: Optional[int] = None,
 ) -> Dict[str, torch.Tensor]:
     """Pad a batch of text sequences, tiled image tensors, aspect ratios, and cross attention masks.
     This can be used for both training and inference.
@@ -259,6 +260,8 @@
             :func:`torch.nn.utils.rnn.pad_sequence`, otherwise if ``pad_direction="left"``, we use
             :func:`torchtune.data.left_pad_sequence`. For training, we typically want to pad from the right.
             For inference, we typically want to pad from the left. Defaults to "right".
+        pad_max_images (Optional[int]): Maximum number of images to pad to. If None, will pad to the largest number of images
+            in the batch. Defaults to None.
 
     Returns:
         Dict[str, Tensor]: Collated tokens, labels, images, encoder_mask, aspect_ratio tensors.
@@ -370,14 +373,28 @@
             text_seq_len, image_seq_len = mask.shape
             tokens_per_tile = image_seq_len // n_tiles
             padding_tiles = max_num_tiles - n_tiles
-            padding_text = max_seq_len - text_seq_len
+            right_padding_text = (
+                max_seq_len - text_seq_len if pad_direction == "right" else 0
+            )
+            left_padding_text = (
+                max_seq_len - text_seq_len if pad_direction == "left" else 0
+            )
+
             # Image should now have shape (max_num_tiles, c, h, w)
             padded_image = F.pad(image, (0, 0, 0, 0, 0, 0, 0, padding_tiles), value=0)
             # Mask should now have shape (max_seq_len, max_image_seq_len), where
             # max_image_seq_len = max_num_tiles * tokens_per_tile
             padded_mask = F.pad(
-                mask, (0, padding_tiles * tokens_per_tile, 0, padding_text), value=0
+                mask,
+                (
+                    0,
+                    padding_tiles * tokens_per_tile,
+                    left_padding_text,
+                    right_padding_text,
+                ),
+                value=0,
             )
+
             sample_images.append(padded_image)
             sample_masks.append(padded_mask)
         # Stack multiple images and masks per sample in num_images dimension
@@ -396,6 +413,11 @@
 
     # Concatenate masks for multiple images across image_seq_len dimension
     concat_masks = collated_masks.view(bsz, max_seq_len, -1)
+    if pad_max_images is not None:
+        _, _, img_seq = concat_masks.shape
+        concat_masks = F.pad(
+            concat_masks, (0, pad_max_images * image_seq_len - img_seq)
+        )
 
     batch_dict = {
         "tokens": collated_text["tokens"],
diff --git a/torchtune/models/flamingo/_transform.py b/torchtune/models/flamingo/_transform.py
index 94e5b21539..d88111c935 100644
--- a/torchtune/models/flamingo/_transform.py
+++ b/torchtune/models/flamingo/_transform.py
@@ -40,7 +40,6 @@ class FlamingoTransform(ModelTokenizer, Transform):
             Llama3 special tokens.
         max_seq_len (Optional[int]): maximum sequence length for tokenizing a single list of messages,
             after which the input will be truncated. Default is None.
-        encoder_max_seq_len (Optional[int]): maximum sequence length for the encoder input. Default is None.
         image_mean (Optional[Tuple[float, float, float]]): Mean values of each channel, used for normalization.
         image_std (Optional[Tuple[float, float, float]]): Standard deviations for each channel, used for normalization.
         prompt_template (Optional[PromptTemplate]): template used to format the messages based on their role. This is used
@@ -71,7 +70,6 @@ def __init__(
         max_num_tiles: int = 4,
         special_tokens: Optional[Dict[str, int]] = None,
         max_seq_len: Optional[int] = None,
-        encoder_max_seq_len: Optional[int] = None,
         image_mean: Optional[Tuple[float, float, float]] = None,
         image_std: Optional[Tuple[float, float, float]] = None,
         prompt_template: Optional[PromptTemplate] = None,
@@ -96,7 +94,7 @@ def __init__(
             tile_size=tile_size,
             patch_size=patch_size,
             image_token_id=self.tokenizer.image_id,
-            encoder_max_seq_len=encoder_max_seq_len,
+            max_num_tiles=max_num_tiles,
         )
 
         self.stop_tokens = self.tokenizer.stop_tokens
diff --git a/torchtune/modules/transforms/_transforms.py b/torchtune/modules/transforms/_transforms.py
index 4e479c5f5f..006f224144 100644
--- a/torchtune/modules/transforms/_transforms.py
+++ b/torchtune/modules/transforms/_transforms.py
@@ -57,8 +57,8 @@ class VisionCrossAttentionMask(Transform):
            E.g. for patch_size = 40, a tile of shape (400, 400) will have 10x10 grid of patches
             with shape (40, 40) each.
         image_token_id (int): Token ID of the image special token.
-        encoder_max_seq_len (Optional[int]): Maximum sequence length of the vision sequence, used to
-            pad mask during inference. Defaults to None.
+        max_num_tiles (Optional[int]): Maximum number of tiles in an image, used to
+            pad mask during inference. Defaults to None.
     """
 
     def __init__(
@@ -66,13 +66,12 @@
         tile_size: int,
         patch_size: int,
         image_token_id: int,
-        encoder_max_seq_len: Optional[int] = None,
+        max_num_tiles: Optional[int] = None,
     ):
         patch_grid_size = tile_size // patch_size
         self.patches_per_tile = patch_grid_size**2
         self.image_token_id = image_token_id
-
-        self.encoder_max_seq_len = encoder_max_seq_len
+        self.max_num_tiles = max_num_tiles
 
     def _get_image_attention_intervals(self, tokens: List[int]) -> List[List[int]]:
         """
@@ -164,7 +163,9 @@ def __call__(
         # which can vary based on number of tiles since they are not yet tile padded.
         # The masks are padded and concatenated together in the batch collator
         text_seq_len = len(tokens)
-        max_image_size = self.encoder_max_seq_len if inference else None
+        max_image_size = None
+        if inference and self.max_num_tiles is not None:
+            max_image_size = self.max_num_tiles * (self.patches_per_tile + 1)
         masks = []
         for image_num, interval in enumerate(intervals):
            # Identify what part of text sequence should be attended
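
Usage sketch (a minimal illustration, not part of the patch): with the toy sizes from test_left_pad_sequence above (tokens_per_tile = 5, max_num_tiles = 4, so each image contributes image_seq_len = 20 mask columns, and the largest sample in the batch has 2 images), pad_max_images=4 right-pads the collated encoder mask to a static width of 4 * 20 = 80:

    import torch
    import torch.nn.functional as F

    bsz, max_seq_len, image_seq_len = 2, 4, 20
    # After per-sample collation the mask covers the batch max of 2 images.
    concat_masks = torch.zeros(bsz, max_seq_len, 2 * image_seq_len)

    # The same padding the collator now applies when pad_max_images is set:
    pad_max_images = 4
    _, _, img_seq = concat_masks.shape
    concat_masks = F.pad(concat_masks, (0, pad_max_images * image_seq_len - img_seq))
    print(concat_masks.shape)  # torch.Size([2, 4, 80])

This gives inference callers a fixed mask width regardless of how many images each batch happens to contain.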
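A companion sketch of the transform-side change: VisionCrossAttentionMask now derives the inference-time mask width from max_num_tiles rather than taking encoder_max_seq_len directly. The tile/patch values and the torchtune.modules.transforms import path below are assumptions for illustration:

    from torchtune.modules.transforms import VisionCrossAttentionMask

    transform = VisionCrossAttentionMask(
        tile_size=224,    # assumed values, not from the patch
        patch_size=56,
        image_token_id=1,
        max_num_tiles=4,  # replaces the old encoder_max_seq_len argument
    )

    # Equivalent to what __call__ computes when inference=True; the "+ 1"
    # mirrors patches_per_tile + 1 (presumably one extra token per tile).
    patches_per_tile = (224 // 56) ** 2          # 16
    max_image_size = 4 * (patches_per_tile + 1)  # 68 mask columns per image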