diff --git a/torchtune/models/clip/_transform.py b/torchtune/models/clip/_transform.py
index a9b60624ff..f0f7f2c3c5 100644
--- a/torchtune/models/clip/_transform.py
+++ b/torchtune/models/clip/_transform.py
@@ -159,10 +159,9 @@ def __call__(
         assert isinstance(image, Image.Image), "Input image must be a PIL image."
 
         # Make image torch.tensor((3, H, W), dtype=dtype), 0<=values<=1
-        if hasattr(image, "mode") and image.mode == "RGBA":
+        if image.mode != "RGB":
             image = image.convert("RGB")
         image = F.to_image(image)
-        image = F.grayscale_to_rgb_image(image)
         image = F.to_dtype(image, dtype=self.dtype, scale=True)
 
         # Find the best canvas to fit the image without distortion
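
Reviewer note: a minimal sketch (not part of this patch) of why the blanket
`image.mode != "RGB"` check subsumes both the old RGBA special case and the
removed F.grayscale_to_rgb_image call. The torchvision import path is an
assumption inferred from the file's use of `F`.

    # Rough sketch, assuming torchvision's transforms.v2 functional API.
    from PIL import Image
    import torchvision.transforms.v2.functional as F

    for mode in ("L", "LA", "P", "RGBA", "CMYK", "RGB"):
        image = Image.new(mode, (4, 4))
        # New logic: any non-RGB mode (grayscale, palette, RGBA, CMYK, ...)
        # is normalized to RGB up front by PIL.
        if image.mode != "RGB":
            image = image.convert("RGB")
        tensor = F.to_image(image)
        # Every input mode now yields a 3-channel tensor, so a follow-up
        # grayscale-to-RGB conversion is redundant.
        assert tensor.shape[0] == 3, (mode, tensor.shape)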