Commit d46aaa6

Fixes for MM Masking and Collation (#1601)

pbontrager authored Sep 16, 2024
1 parent 7045e96 commit d46aaa6
Showing 5 changed files with 64 additions and 14 deletions.
tests/torchtune/data/test_collate.py (8 additions & 2 deletions)
@@ -10,6 +10,7 @@

import pytest
import torch
+import torch.nn.functional as F
from tests.test_utils import gpu_test
from torchtune.data import (
left_pad_sequence,
@@ -120,15 +121,20 @@ def test_right_pad_sequence(self, batch):

def test_left_pad_sequence(self, batch):
actual = padded_collate_tiled_images_and_mask(
-            batch=batch, padding_idx=0, ignore_idx=-100, pad_direction="left"
+            batch=batch,
+            padding_idx=0,
+            ignore_idx=-100,
+            pad_direction="left",
+            pad_max_images=4,
)

mask_1 = torch.concat([torch.ones(4, 5 * 2), torch.zeros(4, 10)], dim=1)
mask_2 = torch.concat([torch.ones(4, 5 * 3), torch.zeros(4, 5)], dim=1)
-        mask_3 = torch.concat([torch.ones(2, 5 * 4), torch.zeros(2, 20)], dim=0)
+        mask_3 = torch.concat([torch.zeros(2, 20), torch.ones(2, 5 * 4)], dim=0)
sample_1 = torch.stack([mask_1, mask_2])
sample_2 = torch.stack([mask_3, torch.zeros(4, 20)])
expected_mask = torch.stack([sample_1, sample_2]).view(2, 4, -1)
+        expected_mask = F.pad(expected_mask, (0, 40), value=0)

expected = {
"tokens": torch.tensor([[1, 2, 1, 3], [0, 0, 1, 4]]),
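A quick sanity check of the arithmetic behind the new `F.pad(expected_mask, (0, 40), value=0)` line: with 5 tokens per tile and at most 4 tiles, each image contributes 20 mask columns; the two image slots give 40 columns, and `pad_max_images=4` pads the mask out to 4 * 20 = 80 columns. A minimal sketch (the values mirror the test fixtures; the variable names are illustrative):

```python
import torch
import torch.nn.functional as F

tokens_per_tile, max_num_tiles, pad_max_images = 5, 4, 4
cols_per_image = tokens_per_tile * max_num_tiles      # 20 mask columns per image slot
mask = torch.ones(2, 4, 2 * cols_per_image)           # (bsz=2, text_seq_len=4, 2 image slots)
padded = F.pad(mask, (0, pad_max_images * cols_per_image - mask.shape[-1]), value=0)
assert padded.shape == (2, 4, 80)                     # 4 image slots * 20 columns
```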
tests/torchtune/modules/transforms/test_transforms.py (23 additions & 0 deletions)
@@ -10,6 +10,7 @@


IMAGE_TOKEN_ID = 1
+MAX_NUM_TILES = 4


class TestVisionCrossAttentionMask:
@@ -53,6 +54,7 @@ def cross_attn_mask_transform(self, tile_size, patch_size):
tile_size=tile_size,
patch_size=patch_size,
image_token_id=IMAGE_TOKEN_ID,
+            max_num_tiles=MAX_NUM_TILES,
)

def test_get_image_attention_intervals(self, cross_attn_mask_transform, tokens):
@@ -78,3 +80,24 @@ def test_call(self, cross_attn_mask_transform, tokens, images, image_num_tokens)

assert actual["tokens"] == tokens
assert actual["hello"] == dummy_kwargs["hello"]

+    def test_inference_call(
+        self, cross_attn_mask_transform, tokens, images, image_num_tokens
+    ):
+        sample = {"tokens": tokens, "encoder_input": {"images": images}}
+        dummy_kwargs = {"hello": 8}
+        sample.update(dummy_kwargs)
+        actual = cross_attn_mask_transform(sample, inference=True)
+        expected = [
+            torch.zeros(len(tokens), image_num_tokens * 2, dtype=torch.bool)
+            for _ in range(len(images))
+        ]
+        expected[0][2:6, :image_num_tokens] = True
+        expected[1][3:6, :image_num_tokens] = True
+        expected[2][6:9, :image_num_tokens] = True
+        for i in range(len(images)):
+            torch.testing.assert_close(actual["encoder_mask"][i], expected[i])
+            torch.testing.assert_close(actual["encoder_input"]["images"][i], images[i])
+
+        assert actual["tokens"] == tokens
+        assert actual["hello"] == dummy_kwargs["hello"]
torchtune/data/_collate.py (25 additions & 3 deletions)
@@ -3,7 +3,7 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
-from typing import Any, Dict, List, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
@@ -223,6 +223,7 @@ def padded_collate_tiled_images_and_mask(
padding_idx: int = 0,
ignore_idx: int = CROSS_ENTROPY_IGNORE_IDX,
pad_direction: str = "right",
+    pad_max_images: Optional[int] = None,
) -> Dict[str, torch.Tensor]:
"""Pad a batch of text sequences, tiled image tensors, aspect ratios,
and cross attention masks. This can be used for both training and inference.
@@ -259,6 +260,8 @@
:func:`torch.nn.utils.rnn.pad_sequence`, otherwise if ``pad_direction="left"``,
we use :func:`torchtune.data.left_pad_sequence`. For training, we typically want to pad from the right.
For inference, we typically want to pad from the left. Defaults to "right".
+        pad_max_images (Optional[int]): Maximum number of images to pad to. If None, will pad
+            to the largest number of images in the batch. Defaults to None.
Returns:
Dict[str, Tensor]: Collated tokens, labels, images, encoder_mask, aspect_ratio tensors.
@@ -370,14 +373,28 @@ def padded_collate_tiled_images_and_mask(
text_seq_len, image_seq_len = mask.shape
tokens_per_tile = image_seq_len // n_tiles
padding_tiles = max_num_tiles - n_tiles
-        padding_text = max_seq_len - text_seq_len
+        right_padding_text = (
+            max_seq_len - text_seq_len if pad_direction == "right" else 0
+        )
+        left_padding_text = (
+            max_seq_len - text_seq_len if pad_direction == "left" else 0
+        )

# Image should now have shape (max_num_tiles, c, h, w)
padded_image = F.pad(image, (0, 0, 0, 0, 0, 0, 0, padding_tiles), value=0)
# Mask should now have shape (max_seq_len, max_image_seq_len), where
# max_image_seq_len = max_num_tiles * tokens_per_tile
padded_mask = F.pad(
-            mask, (0, padding_tiles * tokens_per_tile, 0, padding_text), value=0
+            mask,
+            (
+                0,
+                padding_tiles * tokens_per_tile,
+                left_padding_text,
+                right_padding_text,
+            ),
+            value=0,
)

sample_images.append(padded_image)
sample_masks.append(padded_mask)
# Stack multiple images and masks per sample in num_images dimension
@@ -396,6 +413,11 @@

# Concatenate masks for multiple images across image_seq_len dimension
concat_masks = collated_masks.view(bsz, max_seq_len, -1)
+    if pad_max_images is not None:
+        _, _, img_seq = concat_masks.shape
+        concat_masks = F.pad(
+            concat_masks, (0, pad_max_images * image_seq_len - img_seq)
+        )

batch_dict = {
"tokens": collated_text["tokens"],
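A hedged usage sketch of the collator with the new `pad_max_images` argument. The batch below is hand-built with toy shapes (5 mask columns per tile, tiny 2x2 tiles) rather than real model inputs:

```python
import torch
from torchtune.data import padded_collate_tiled_images_and_mask

# Two samples, one image each: 2 tiles vs. 3 tiles.
batch = [
    {
        "tokens": [1, 2, 1, 3],
        "labels": [4, 5, 6, 7],
        "encoder_input": {
            "images": [torch.randn(2, 3, 2, 2)],
            "aspect_ratio": [torch.tensor([1, 2])],
        },
        "encoder_mask": [torch.ones(4, 2 * 5)],
    },
    {
        "tokens": [1, 4],
        "labels": [8, 9],
        "encoder_input": {
            "images": [torch.randn(3, 3, 2, 2)],
            "aspect_ratio": [torch.tensor([1, 3])],
        },
        "encoder_mask": [torch.ones(2, 3 * 5)],
    },
]

out = padded_collate_tiled_images_and_mask(
    batch=batch,
    padding_idx=0,
    ignore_idx=-100,
    pad_direction="right",
    pad_max_images=4,
)

# Masks are tile-padded to the batch max (3 tiles -> 15 columns per image),
# then padded out to 4 image slots: (bsz, max_seq_len, 4 * 15).
print(out["encoder_mask"].shape)  # torch.Size([2, 4, 60])
```

Note that the final width is computed from `image_seq_len` as of the last image processed, so this sketch puts the sample with the most tiles last, as in the updated test above.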
torchtune/models/flamingo/_transform.py (1 addition & 3 deletions)
@@ -40,7 +40,6 @@ class FlamingoTransform(ModelTokenizer, Transform):
Llama3 special tokens.
max_seq_len (Optional[int]): maximum sequence length for tokenizing a single list of messages,
after which the input will be truncated. Default is None.
-        encoder_max_seq_len (Optional[int]): maximum sequence length for the encoder input. Default is None.
image_mean (Optional[Tuple[float, float, float]]): Mean values of each channel, used for normalization.
image_std (Optional[Tuple[float, float, float]]): Standard deviations for each channel, used for normalization.
prompt_template (Optional[PromptTemplate]): template used to format the messages based on their role. This is used
@@ -71,7 +70,6 @@ def __init__(
max_num_tiles: int = 4,
special_tokens: Optional[Dict[str, int]] = None,
max_seq_len: Optional[int] = None,
-        encoder_max_seq_len: Optional[int] = None,
image_mean: Optional[Tuple[float, float, float]] = None,
image_std: Optional[Tuple[float, float, float]] = None,
prompt_template: Optional[PromptTemplate] = None,
@@ -96,7 +94,7 @@
tile_size=tile_size,
patch_size=patch_size,
image_token_id=self.tokenizer.image_id,
-            encoder_max_seq_len=encoder_max_seq_len,
+            max_num_tiles=max_num_tiles,
)

self.stop_tokens = self.tokenizer.stop_tokens
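The transform now forwards `max_num_tiles` to `VisionCrossAttentionMask` instead of a precomputed encoder sequence length. A hypothetical construction, with a placeholder tokenizer path and illustrative tile/patch sizes (both are assumptions, not values from this commit):

```python
from torchtune.models.flamingo import FlamingoTransform

# Path and sizes are placeholders for illustration only.
transform = FlamingoTransform(
    path="/tmp/tokenizer.model",
    tile_size=448,
    patch_size=14,
    max_num_tiles=4,
    max_seq_len=2048,
)
# Internally this builds VisionCrossAttentionMask(..., max_num_tiles=4), which
# derives the inference mask width as max_num_tiles * (patches_per_tile + 1).
```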
torchtune/modules/transforms/_transforms.py (7 additions & 6 deletions)

@@ -57,22 +57,21 @@ class VisionCrossAttentionMask(Transform):
E.g. for patch_size = 40, a tile of shape (400, 400) will have 10x10 grid of patches
with shape (40, 40) each.
image_token_id (int): Token ID of the image special token.
-        encoder_max_seq_len (Optional[int]): Maximum sequence length of the vision sequence, used to
-            pad mask during inference. Defaults to None.
+        max_num_tiles (Optional[int]): Maximum number of tiles in an image, used to
+            pad mask during inference. Defaults to None.
"""

def __init__(
self,
tile_size: int,
patch_size: int,
image_token_id: int,
-        encoder_max_seq_len: Optional[int] = None,
+        max_num_tiles: Optional[int] = None,
):
patch_grid_size = tile_size // patch_size
self.patches_per_tile = patch_grid_size**2
self.image_token_id = image_token_id

-        self.encoder_max_seq_len = encoder_max_seq_len
+        self.max_num_tiles = max_num_tiles

def _get_image_attention_intervals(self, tokens: List[int]) -> List[List[int]]:
"""
@@ -164,7 +163,9 @@ def __call__(
# which can vary based on number of tiles since they are not yet tile padded.
# The masks are padded and concatenated together in the batch collator
text_seq_len = len(tokens)
-        max_image_size = self.encoder_max_seq_len if inference else None
+        max_image_size = None
+        if inference and self.max_num_tiles is not None:
+            max_image_size = self.max_num_tiles * (self.patches_per_tile + 1)
masks = []
for image_num, interval in enumerate(intervals):
# Identify what part of text sequence should be attended
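The new expression replaces a user-supplied encoder length with one derived from the tile geometry: each tile contributes its patch embeddings plus one extra token (presumably the per-tile CLS embedding). A worked example with illustrative sizes:

```python
# Illustrative values; real models may differ.
tile_size, patch_size, max_num_tiles = 448, 14, 4

patches_per_tile = (tile_size // patch_size) ** 2    # 32 * 32 = 1024
max_image_size = max_num_tiles * (patches_per_tile + 1)
print(max_image_size)  # 4 * 1025 = 4100 mask columns at inference
```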
