diff --git a/recipes/dev/generate_v2.py b/recipes/dev/generate_v2.py index e63ea2dcb0..3ce95a9fdf 100644 --- a/recipes/dev/generate_v2.py +++ b/recipes/dev/generate_v2.py @@ -152,7 +152,10 @@ def generate(self, cfg: DictConfig): batch = {} if is_multimodal_input: batch = padded_collate_tiled_images_and_mask( - [model_inputs], pad_direction="left", pad_max_images=1 + [model_inputs], + pad_direction="left", + pad_max_images=1, + pad_max_tiles=self.model_transform.max_num_tiles, ) batch["encoder_mask"] = batch["encoder_mask"][:, :seq_len] prompt = batch.pop("tokens").to(self._device) diff --git a/tests/torchtune/data/test_collate.py b/tests/torchtune/data/test_collate.py index d17ce2b0ae..02477ad7e5 100644 --- a/tests/torchtune/data/test_collate.py +++ b/tests/torchtune/data/test_collate.py @@ -138,17 +138,14 @@ def test_left_pad_sequence(self, batch): ) imgs, tiles = actual["encoder_input"]["images"].shape[1:3] seq_len = actual["encoder_mask"].shape[-1] - assert imgs * tiles * self.tokens_per_tile == seq_len + assert 5 * 4 * self.tokens_per_tile == seq_len - mask_1 = torch.concat( - [torch.ones(4, 5 * 2), torch.zeros(4, 5 * 3)], dim=1 - ) # pad 3 extra tiles - mask_2 = torch.concat( - [torch.ones(4, 5 * 3), torch.zeros(4, 5 * 2)], dim=1 - ) # pad 2 extra tiles - mask_3 = torch.concat( - [torch.zeros(2, 20), torch.ones(2, 5 * 4)], dim=0 - ) # Left pad text tokens + # pad 3 extra tiles + mask_1 = torch.concat([torch.ones(4, 5 * 2), torch.zeros(4, 5 * 3)], dim=1) + # pad 2 extra tiles + mask_2 = torch.concat([torch.ones(4, 5 * 3), torch.zeros(4, 5 * 2)], dim=1) + # Left pad text tokens + mask_3 = torch.concat([torch.zeros(2, 20), torch.ones(2, 5 * 4)], dim=0) mask_3 = F.pad(mask_3, (0, 5), value=0) # pad 5th tile sample_1 = torch.stack([mask_1, mask_2]) sample_2 = torch.stack([mask_3, torch.zeros(4, 25)]) @@ -163,14 +160,10 @@ def test_left_pad_sequence(self, batch): [ [[[[1.0]]], [[[1.0]]], [[[0.0]]], [[[0.0]]], [[[0.0]]]], [[[[1.0]]], [[[1.0]]], [[[1.0]]], [[[0.0]]], [[[0.0]]]], - [[[[0.0]]], [[[0.0]]], [[[0.0]]], [[[0.0]]], [[[0.0]]]], - [[[[0.0]]], [[[0.0]]], [[[0.0]]], [[[0.0]]], [[[0.0]]]], ], [ [[[[1.0]]], [[[1.0]]], [[[1.0]]], [[[1.0]]], [[[0.0]]]], [[[[0.0]]], [[[0.0]]], [[[0.0]]], [[[0.0]]], [[[0.0]]]], - [[[[0.0]]], [[[0.0]]], [[[0.0]]], [[[0.0]]], [[[0.0]]]], - [[[[0.0]]], [[[0.0]]], [[[0.0]]], [[[0.0]]], [[[0.0]]]], ], ] ),