Skip to content

Commit

Permalink
Fix bug in loading multimodal datasets and update tests accordingly (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
joecummings authored Dec 4, 2024
1 parent e9b9ea5 commit 9b41f49
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 0 deletions.
57 changes: 57 additions & 0 deletions tests/torchtune/data/test_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import pytest

from PIL import Image
from tests.common import ASSETS
from tests.test_utils import (
assert_dialogue_equal,
CHAT_SAMPLE,
Expand All @@ -24,6 +25,8 @@
validate_messages,
)

# Reference image opened once at import time from the test assets directory.
# The parametrized image test below uses it both as an already-loaded input
# and as the expected image content in the resulting user message.
PYTORCH_RGB_IMAGE_AS_PIL = Image.open(ASSETS / "rgb_pytorch.png")


class TestMessage:
@pytest.fixture
Expand Down Expand Up @@ -106,6 +109,60 @@ def sample(self):
"maybe_output": "hello world",
}

@pytest.mark.parametrize(
    "input_image, expected_image",
    [
        ("rgb_pytorch.png", PYTORCH_RGB_IMAGE_AS_PIL),
        (ASSETS / "rgb_pytorch.png", PYTORCH_RGB_IMAGE_AS_PIL),
        (PYTORCH_RGB_IMAGE_AS_PIL, PYTORCH_RGB_IMAGE_AS_PIL),
    ],
)
def test_call_with_image(self, sample, input_image, expected_image):
    """Check that an image column ends up in the user message content.

    The image is supplied in three forms: a bare filename string, a full
    path under ``ASSETS``, and an already-opened PIL image. In every case
    the user message is expected to carry the reference PIL image.
    """
    sample["image"] = input_image

    # Only the bare-string case passes ``image_dir``; that is the case which
    # checks the directory is properly joined with the image filename.
    if isinstance(input_image, str):
        image_dir = ASSETS
    else:
        image_dir = None

    columns = {
        "input": "maybe_input",
        "output": "maybe_output",
        "image": "image",
    }
    to_messages = InputOutputToMessages(column_map=columns, image_dir=image_dir)
    result = to_messages(sample)

    user_turn = Message(
        role="user",
        content=[
            {"type": "text", "content": "hello world"},
            {"type": "image", "content": expected_image},
        ],
        masked=True,
        eot=True,
    )
    assistant_turn = Message(
        role="assistant", content="hello world", masked=False, eot=True
    )
    assert_dialogue_equal(result["messages"], [user_turn, assistant_turn])

def test_call_with_image_fails_when_bad_image_inputs_are_passed(self, sample):
    """Check that ``image_dir`` without an 'image' column mapping is rejected.

    Constructing ``InputOutputToMessages`` with an ``image_dir`` but a
    ``column_map`` that has no 'image' key is expected to raise
    ``ValueError`` at construction time.
    """
    # A column_map that deliberately omits the 'image' key.
    column_map = {
        "input": "maybe_input",
        "output": "maybe_output",
    }

    # The constructor itself should raise, so its result is not bound to a
    # name: the assignment would never complete and the name would be unused.
    with pytest.raises(
        ValueError,
        match="Please specify an 'image' key in column_map",
    ):
        InputOutputToMessages(
            column_map=column_map,
            image_dir=ASSETS,
        )

def test_call(self, sample):
transform = InputOutputToMessages(
column_map={"input": "maybe_input", "output": "maybe_output"}
Expand Down
14 changes: 14 additions & 0 deletions torchtune/data/_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ class InputOutputToMessages(Transform):
Raises:
ValueError: If ``column_map`` is provided and ``input`` not in ``column_map``, or
``output`` not in ``column_map``.
ValueError: If ``image_dir`` is provided but ``image`` not in ``column_map``.
"""

def __init__(
Expand All @@ -196,6 +197,14 @@ def __init__(
else:
self.column_map = {"input": "input", "output": "output", "image": "image"}

# Ensure that if a user seems to want to construct a multimodal transform, they provide
# a proper column_mapping
if "image" not in self.column_map.keys() and image_dir is not None:
raise ValueError(
f"image_dir is specified as {image_dir} but 'image' is not in column_map. "
"Please specify an 'image' key in column_map."
)

self.image_dir = image_dir

def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
Expand All @@ -206,8 +215,13 @@ def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
if is_multimodal:
image_path = sample[self.column_map["image"]]
if isinstance(image_path, str):
# Convert image_path to Path obj
image_path = Path(image_path)

# If image_dir is not None, prepend image_dir to image_path
if self.image_dir is not None:
image_path = self.image_dir / image_path

# Load if not loaded
pil_image = load_image(image_path)
else:
Expand Down

0 comments on commit 9b41f49

Please sign in to comment.