diff --git a/tests/torchtune/data/test_messages.py b/tests/torchtune/data/test_messages.py index a46cfd9349..86b7d7319f 100644 --- a/tests/torchtune/data/test_messages.py +++ b/tests/torchtune/data/test_messages.py @@ -9,6 +9,7 @@ import pytest from PIL import Image +from tests.common import ASSETS from tests.test_utils import ( assert_dialogue_equal, CHAT_SAMPLE, @@ -24,6 +25,8 @@ validate_messages, ) +PYTORCH_RGB_IMAGE_AS_PIL = Image.open(ASSETS / "rgb_pytorch.png") + class TestMessage: @pytest.fixture @@ -106,6 +109,60 @@ def sample(self): "maybe_output": "hello world", } + @pytest.mark.parametrize( + "input_image, expected_image", + [ + ("rgb_pytorch.png", PYTORCH_RGB_IMAGE_AS_PIL), + (ASSETS / "rgb_pytorch.png", PYTORCH_RGB_IMAGE_AS_PIL), + (PYTORCH_RGB_IMAGE_AS_PIL, PYTORCH_RGB_IMAGE_AS_PIL), + ], + ) + def test_call_with_image(self, sample, input_image, expected_image): + # Add the image to the sample + sample["image"] = input_image + + # Create the transform + transform = InputOutputToMessages( + column_map={ + "input": "maybe_input", + "output": "maybe_output", + "image": "image", + }, + # Need to test if the image_dir is properly joined w/ image + image_dir=ASSETS if isinstance(input_image, str) else None, + ) + actual = transform(sample) + expected = [ + Message( + role="user", + content=[ + {"type": "text", "content": "hello world"}, + {"type": "image", "content": expected_image}, + ], + masked=True, + eot=True, + ), + Message(role="assistant", content="hello world", masked=False, eot=True), + ] + assert_dialogue_equal(actual["messages"], expected) + + def test_call_with_image_fails_when_bad_image_inputs_are_passed(self, sample): + # Construct a bad column_map without an 'image' key + column_map = { + "input": "maybe_input", + "output": "maybe_output", + } + + # Create a transform that expects an image column + with pytest.raises( + ValueError, + match="Please specify an 'image' key in column_map", + ): + transform = InputOutputToMessages( + column_map=column_map, + image_dir=ASSETS, + ) + def test_call(self, sample): transform = InputOutputToMessages( column_map={"input": "maybe_input", "output": "maybe_output"} diff --git a/torchtune/data/_messages.py b/torchtune/data/_messages.py index a6b356b0ca..bbd3ae5981 100644 --- a/torchtune/data/_messages.py +++ b/torchtune/data/_messages.py @@ -170,6 +170,7 @@ class InputOutputToMessages(Transform): Raises: ValueError: If ``column_map`` is provided and ``input`` not in ``column_map``, or ``output`` not in ``column_map``. + ValueError: If ``image_dir`` is provided but ``image`` not in ``column_map``. """ def __init__( @@ -196,6 +197,14 @@ def __init__( else: self.column_map = {"input": "input", "output": "output", "image": "image"} + # Ensure that if a user seems to want to construct a multimodal transform, they provide + # a proper column_mapping + if "image" not in self.column_map.keys() and image_dir is not None: + raise ValueError( + f"image_dir is specified as {image_dir} but 'image' is not in column_map. " + "Please specify an 'image' key in column_map." + ) + self.image_dir = image_dir def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]: @@ -206,8 +215,13 @@ def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]: if is_multimodal: image_path = sample[self.column_map["image"]] if isinstance(image_path, str): + # Convert image_path to Path obj + image_path = Path(image_path) + + # If image_dir is not None, prepend image_dir to image_path if self.image_dir is not None: image_path = self.image_dir / image_path + # Load if not loaded pil_image = load_image(image_path) else: