Skip to content

Commit

Permalink
Update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
pbontrager committed Oct 29, 2024
1 parent 3685b78 commit e324f89
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 15 deletions.
5 changes: 4 additions & 1 deletion recipes/dev/generate_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,10 @@ def generate(self, cfg: DictConfig):
batch = {}
if is_multimodal_input:
batch = padded_collate_tiled_images_and_mask(
[model_inputs], pad_direction="left", pad_max_images=1
[model_inputs],
pad_direction="left",
pad_max_images=1,
pad_max_tiles=self.model_transform.max_num_tiles,
)
batch["encoder_mask"] = batch["encoder_mask"][:, :seq_len]
prompt = batch.pop("tokens").to(self._device)
Expand Down
21 changes: 7 additions & 14 deletions tests/torchtune/data/test_collate.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,17 +138,14 @@ def test_left_pad_sequence(self, batch):
)
imgs, tiles = actual["encoder_input"]["images"].shape[1:3]
seq_len = actual["encoder_mask"].shape[-1]
assert imgs * tiles * self.tokens_per_tile == seq_len
assert 5 * 4 * self.tokens_per_tile == seq_len

mask_1 = torch.concat(
[torch.ones(4, 5 * 2), torch.zeros(4, 5 * 3)], dim=1
) # pad 3 extra tiles
mask_2 = torch.concat(
[torch.ones(4, 5 * 3), torch.zeros(4, 5 * 2)], dim=1
) # pad 2 extra tiles
mask_3 = torch.concat(
[torch.zeros(2, 20), torch.ones(2, 5 * 4)], dim=0
) # Left pad text tokens
# pad 3 extra tiles
mask_1 = torch.concat([torch.ones(4, 5 * 2), torch.zeros(4, 5 * 3)], dim=1)
# pad 2 extra tiles
mask_2 = torch.concat([torch.ones(4, 5 * 3), torch.zeros(4, 5 * 2)], dim=1)
# Left pad text tokens
mask_3 = torch.concat([torch.zeros(2, 20), torch.ones(2, 5 * 4)], dim=0)
mask_3 = F.pad(mask_3, (0, 5), value=0) # pad 5th tile
sample_1 = torch.stack([mask_1, mask_2])
sample_2 = torch.stack([mask_3, torch.zeros(4, 25)])
Expand All @@ -163,14 +160,10 @@ def test_left_pad_sequence(self, batch):
[
[[[[1.0]]], [[[1.0]]], [[[0.0]]], [[[0.0]]], [[[0.0]]]],
[[[[1.0]]], [[[1.0]]], [[[1.0]]], [[[0.0]]], [[[0.0]]]],
[[[[0.0]]], [[[0.0]]], [[[0.0]]], [[[0.0]]], [[[0.0]]]],
[[[[0.0]]], [[[0.0]]], [[[0.0]]], [[[0.0]]], [[[0.0]]]],
],
[
[[[[1.0]]], [[[1.0]]], [[[1.0]]], [[[1.0]]], [[[0.0]]]],
[[[[0.0]]], [[[0.0]]], [[[0.0]]], [[[0.0]]], [[[0.0]]]],
[[[[0.0]]], [[[0.0]]], [[[0.0]]], [[[0.0]]], [[[0.0]]]],
[[[[0.0]]], [[[0.0]]], [[[0.0]]], [[[0.0]]], [[[0.0]]]],
],
]
),
Expand Down

0 comments on commit e324f89

Please sign in to comment.