diff --git a/tests/torchtune/models/clip/test_clip_image_transform.py b/tests/torchtune/models/clip/test_clip_image_transform.py index 0b99a5df1b..58920ff32f 100644 --- a/tests/torchtune/models/clip/test_clip_image_transform.py +++ b/tests/torchtune/models/clip/test_clip_image_transform.py @@ -10,7 +10,7 @@ import torch -from torchtune.models.clip._transforms import CLIPImageTransform +from torchtune.models.clip._transforms import CLIPImageTransform, ImageTransformConfig class TestPipelines: @@ -41,7 +41,7 @@ class TestPipelines: ) def test_clip_image_transform(self, params): - image_transform = CLIPImageTransform( + config = ImageTransformConfig( image_mean=None, image_std=None, tile_size=224, @@ -49,8 +49,12 @@ def test_clip_image_transform(self, params): max_num_tiles=4, resample="bilinear", resize_to_max_canvas=params["resize_to_max_canvas"], + max_upscaling_size=None, + normalize=False, ) + image_transform = CLIPImageTransform(config=config) + image_size = params["image_size"] # Create a random image diff --git a/tests/torchtune/modules/transforms/test_tile_crop.py b/tests/torchtune/modules/transforms/test_tile_crop.py index 7afde495a3..78fad28196 100644 --- a/tests/torchtune/modules/transforms/test_tile_crop.py +++ b/tests/torchtune/modules/transforms/test_tile_crop.py @@ -8,7 +8,7 @@ import torch -from torchtune.modules.transforms import tile_crop +from torchtune.modules.transforms import TileCrop class TestTransforms: @@ -55,7 +55,7 @@ def test_tile_crop(self, params): image = torch.rand(*image_size) # Create a random image tensor if status == "Passed": - tiles = tile_crop(image, tile_size) + tiles = TileCrop()(image, tile_size) expected_output_shape = params["expected_output_shape"] assert ( tiles.shape == expected_output_shape @@ -73,7 +73,7 @@ def test_tile_crop(self, params): elif status == "Failed": with pytest.raises(Exception) as exc_info: - tile_crop(image, tile_size) + TileCrop()(image, tile_size) expected_error = params["error"] actual_error = str(exc_info.value) assert ( diff --git a/torchtune/models/clip/_transforms.py b/torchtune/models/clip/_transforms.py index 781e7abb91..e2f7af0bfd 100644 --- a/torchtune/models/clip/_transforms.py +++ b/torchtune/models/clip/_transforms.py @@ -5,24 +5,131 @@ # LICENSE file in the root directory of this source tree. import logging +import math +from dataclasses import dataclass from typing import Any, List, Mapping, Optional, Tuple import torch -import torchvision from PIL import Image from torchtune.modules.transforms import ( find_supported_resolutions, get_canvas_best_fit, - resize_with_pad, - tile_crop, + TileCrop, ) +from torchvision.transforms import v2 +from torchvision.transforms._functional_tensor import resize from torchvision.transforms.v2 import functional as F logger = logging.getLogger(__name__) +@dataclass +class ImageTransformConfig: + """ + image_mean (Optional[List[float]]): Mean values of each channel, used for normalization. + Should be the same used for the pre-trained model. If None, no normalization is performed. Default None. + image_std (Optional[List[float]]): Standard deviation values of each channel, used for normalization. + Should be the same used for the pre-trained model. If None, no normalization is performed. Default None. + possible_resolutions (Optional[List[Tuple[int, int]]]): List of possible resolutions as tuples (height, width). + where each tuple represents a possible canvas to fit the image into when calling ``get_canvas_best_fit``. 
+        If None, this will be calculated using max_num_tiles and tile_size. Default None.
+    tile_size (int): Size of the tiles to divide the image into. Default 224.
+    max_num_tiles (Optional[int]): Only used if possible_resolutions is NOT given.
+        Maximum number of tiles to break an image into.
+        This will be used to generate possible_resolutions,
+        e.g. [(224, 224), (224, 448), (448, 224)] if max_num_tiles = 2 and tile_size = 224.
+        Default 4.
+    resample (str): Resampling method used when resizing images. Supports any enum of
+        ``torchvision.transforms.InterpolationMode``, e.g. "nearest", "nearest_exact", "bilinear", "bicubic".
+        Default 'bilinear'.
+    resize_to_max_canvas (bool): If True, the image will be upscaled without distortion to fit the largest possible
+        resolution from possible_resolutions.
+        If False, it will pick the resolution that minimizes downscaling, including no downscaling at all.
+        In this case, the image will only be upscaled if its size < tile_size. Default False.
+    max_upscaling_size (Optional[int]): Limit on the upscaling performed on the image when resizing.
+        If None, no limit is applied. Default None.
+    normalize (bool): If True, the image will be normalized using image_mean and image_std. Default True.
+    max_image_size (int): Max image size, used to bound the tensor sizes for export. Default 5000.
+    """
+
+    image_mean: Optional[List[float]] = None
+    image_std: Optional[List[float]] = None
+    possible_resolutions: Optional[List[Tuple[int, int]]] = None
+    tile_size: int = 224
+    max_num_tiles: Optional[int] = 4
+    resample: str = "bilinear"
+    resize_to_max_canvas: bool = False
+    max_upscaling_size: Optional[int] = None
+    normalize: bool = True
+    max_image_size: int = 5000
+
+
+class CLIPImageTransformCore(torch.nn.Module):
+    def __init__(self, config: ImageTransformConfig):
+        super().__init__()
+        self.config = config
+
+    def forward(
+        self, image: torch.Tensor, target_size: torch.Tensor, canvas_size: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Performs the core transformations involved in CLIPImageTransform:
+        1. Resize the image to target_size.
+        2. Pad the image to canvas_size.
+        3. Normalize the image using config.image_mean and config.image_std, if config.normalize is True.
+        4. Reshape the image tensor into [n, channels, config.tile_size, config.tile_size].
+
+        Args:
+            image (torch.Tensor): image as a 3D tensor in form [C, H, W].
+            target_size (torch.Tensor): tensor of shape [2] containing the target_height and target_width for resize.
+            canvas_size (torch.Tensor): tensor of shape [2] containing the canvas_height and canvas_width for padding.
+
+        Returns:
+            Tuple[torch.Tensor, torch.Tensor]: Image tensor of shape [n, channels, config.tile_size, config.tile_size]
+            and aspect ratio tensor of shape [2].
+        """
+        # Resize.
+        target_h, target_w = target_size.tolist()
+        torch._check(target_h > 1)
+        torch._check(target_h <= self.config.max_image_size)
+        torch._check(target_w > 1)
+        torch._check(target_w <= self.config.max_image_size)
+
+        image = resize(
+            image,
+            size=[target_h, target_w],
+            interpolation=self.config.resample,
+            antialias=True,
+        )
+
+        # Pad on the right and bottom so that the output has shape [3, canvas_h, canvas_w].
+        canvas_h, canvas_w = canvas_size.tolist()
+        torch._check(canvas_h > 1)
+        torch._check(canvas_h <= self.config.max_image_size)
+        torch._check(canvas_w > 1)
+        torch._check(canvas_w <= self.config.max_image_size)
+
+        # v2.Pad expects padding as [left, top, right, bottom].
+        padding = [0, 0, canvas_w - target_w, canvas_h - target_h]
+        output = v2.Pad(padding=padding)(image)
+
+        # Normalize.
+        if self.config.normalize:
+            output = v2.functional.normalize(
+                output, self.config.image_mean, self.config.image_std
+            )
+
+        # Reshape.
+ tiles = TileCrop()(output, self.config.tile_size) + + # Calculate aspect ratio. + aspect_ratio = canvas_size // self.config.tile_size + + return tiles, aspect_ratio + + class CLIPImageTransform: """ This class accepts images of any size and dynamically resizes, pads, normalizes and tiles it @@ -36,42 +143,23 @@ class CLIPImageTransform: For example, if an input image is of size 300x800, and we want to allow a maximum of 16 image tiles, with side 224px, then: - If ``resize_to_max_canvas=False``, then: - best_resolution = (448, 896) -> smallest canvas, up to 16 tiles, that doesn't require downscaling + If ``config.resize_to_max_canvas=False``, then: + canvas_size = (448, 896) -> smallest canvas, up to 16 tiles, that doesn't require downscaling image is NOT resized image is padded (300, 800) -> 448,896 Image is tiled 2x4, for a final output shape of (8, 3, 224, 224) - If ``resize_to_max_canvas=True``, then: - best_resolution = (448, 1344) # canvas that allows maximum upscaling, with minimum padding, up to 16 tiles + If ``config.resize_to_max_canvas=True``, then: + canvas_size = (448, 1344) # canvas that allows maximum upscaling, with minimum padding, up to 16 tiles image is resized without distortion (300,800) -> (448, 1194) #448 is the limiting side for the resize image is padded (448, 1194) -> (448, 1344) Image is tiled 2x5, for a final output shape of (10, 3, 224, 224) Args: - image_mean (Optional[List[float]]): Mean values of each channel, used for normalization. - Should be the same used for the pre-trained model. If None, no normalization is performed. Default None. - image_std (Optional[List[float]]): Standard deviation values of each channel, used for normalization. - Should be the same used for the pre-trained model. If None, no normalization is performed. Default None. - possible_resolutions (Optional[List[Tuple[int, int]]]): List of possible resolutions as tuples (height, width). - where each tuple represents a possible canvas to fit the image into when calling ``get_canvas_best_fit``. - If None, this will be calculated using max_num_tiles and tile_size. Default None. - tile_size (int): Size of the tiles to divide the image into. Default 224. - max_num_tiles (Optional[int]): Only used if possible_resolutions is NOT given. - Maximum number of tiles to break an image into. - This will be used to generate possible_resolutions, - e.g. [(224, 224), (224, 448), (448, 224)] if max_num_tiles = 2 and tile_size = 224. - Default 4. - resample (str): Resampling method used when resizing images. Supports any enum of - ``torchvision.transforms.InterpolationMode``, e.g. "nearest", "nearest_exact", "bilinear", "bicubic". - Default 'bilinear'. - resize_to_max_canvas (bool): "If True, the image will be upscaled without distortion to fit the largest possible - resolution from possible_resolutions. - If False, it will pick the resolution that minimizes downscaling, including no downscaling at all. - In this case, the image will only be upscaled if it's size < tile_size. Default False. + config (Optional[ImageTransformConfig]): config values for the image transform. If None, uses the default values. Examples: - >>> image_transform = CLIPImageTransform( + >>> config = ImageTransformConfig( ... image_mean=None, ... image_std=None, ... tile_size=224, @@ -80,6 +168,7 @@ class CLIPImageTransform: ... resample="bilinear", ... resize_to_max_canvas=True, ...) 
+        >>> image_transform = CLIPImageTransform(config=config)
         >>> # create random image
         >>> image = (np.random.rand(100,200,3) * 255).astype(np.uint8)
         >>> image = PIL.Image.fromarray(image)
@@ -92,48 +181,43 @@ class CLIPImageTransform:

     def __init__(
         self,
-        image_mean: Optional[List[float]] = None,
-        image_std: Optional[List[float]] = None,
-        possible_resolutions: Optional[List[Tuple[int, int]]] = None,
-        tile_size: int = 224,
-        max_num_tiles: Optional[int] = 4,
-        resample: str = "bilinear",
-        resize_to_max_canvas: bool = False,
+        config: Optional[ImageTransformConfig] = None,
     ) -> None:
+        if config is None:
+            config = ImageTransformConfig()

         # get_canvas_best_fit
         assert (
-            possible_resolutions is not None or max_num_tiles is not None
-        ), f"Either possible_resolutions or max_num_tiles must be given. Got {possible_resolutions=} and {max_num_tiles=}"
+            config.possible_resolutions is not None or config.max_num_tiles is not None
+        ), (
+            "Either possible_resolutions or max_num_tiles must be given. "
+            + f"Got {config.possible_resolutions=} and {config.max_num_tiles=}"
+        )

         # If possible_resolutions are not given, then calculate possible ones based on max_num_tiles
-        if not possible_resolutions and max_num_tiles:
+        if not config.possible_resolutions and config.max_num_tiles:
             possible_resolutions = find_supported_resolutions(
-                max_num_tiles=max_num_tiles, tile_size=tile_size
+                max_num_tiles=config.max_num_tiles, tile_size=config.tile_size
             )
         else:
-            possible_resolutions = possible_resolutions
-        self.possible_resolutions = torch.tensor(possible_resolutions).reshape(-1, 2)
+            possible_resolutions = config.possible_resolutions
+        # Store as a tensor of shape [n, 2], which is what get_canvas_best_fit expects.
+        config.possible_resolutions = torch.tensor(possible_resolutions).reshape(-1, 2)

         logger.info(
-            f"Found possible_resolutions: {self.possible_resolutions}. Will fit the images into the canvas with best fit."
+            f"Found possible_resolutions: {config.possible_resolutions}. Will fit the images into the canvas with best fit."
         )

-        self.resize_to_max_canvas = resize_to_max_canvas
-
         # normalize
-        assert (image_mean is None) == (
-            image_std is None
-        ), f"Need to provide both or none of image_mean and image_std. Got {image_mean=} and {image_std=}"
-        self.image_mean = image_mean
-        self.image_std = image_std
+        assert (config.image_mean is None) == (
+            config.image_std is None
+        ), f"Need to provide both or none of image_mean and image_std. Got {config.image_mean=} and {config.image_std=}"

         # resize_with_pad
-        self.max_upscaling_size = None if resize_to_max_canvas else tile_size
-        self.resample = torchvision.transforms.InterpolationMode[resample.upper()]
+        config.max_upscaling_size = (
+            None if config.resize_to_max_canvas else config.tile_size
+        )

-        # tile_crop
-        self.tile_size = tile_size
+        self.config = config

     def __call__(self, *, image: Image.Image, **kwargs) -> Mapping[str, Any]:

@@ -144,35 +228,47 @@ def __call__(self, *, image: Image.Image, **kwargs) -> Mapping[str, Any]:
             F.grayscale_to_rgb_image(F.to_image(image)), scale=True
         )

-        # Find the best canvas to fit the image without distortion
-        best_resolution = get_canvas_best_fit(
+        # Find the best canvas to fit the image without distortion.
+        # This finds the resolution of the best tile arrangement given the
+        # image dimensions, aspect ratio, and config.max_num_tiles.
+        canvas_size = get_canvas_best_fit(
             image=image_tensor,
-            possible_resolutions=self.possible_resolutions,
-            resize_to_max_canvas=self.resize_to_max_canvas,
+            possible_resolutions=self.config.possible_resolutions,
+            resize_to_max_canvas=self.config.resize_to_max_canvas,
         )

-        # resize without distortion + pad to fit best_resolution
-        image_tensor = resize_with_pad(
-            image=image_tensor,
-            target_size=best_resolution,
-            resample=self.resample,
-            max_upscaling_size=self.max_upscaling_size,
-        )
+        image_height, image_width = image_tensor.shape[-2:]

-        # Normalize
-        if self.image_mean and self.image_std:
-            image_tensor = F.normalize(
-                image_tensor, mean=self.image_mean, std=self.image_std
+        # If target_size requires upscaling, we might want to limit the upscaling to max_upscaling_size.
+        if self.config.max_upscaling_size is not None:
+            target_height = min(
+                max(image_height, self.config.max_upscaling_size), canvas_size[0]
+            )
+            target_width = min(
+                max(image_width, self.config.max_upscaling_size), canvas_size[1]
             )
+            target_size = (target_height, target_width)
+        else:
+            target_size = canvas_size
+
+        # Calculate the resized size: the largest aspect-ratio-preserving size that
+        # fits within target_size.
+        scale_h = target_size[0] / image_height
+        scale_w = target_size[1] / image_width

-        # Divide the image into equally sized tiles
-        image_tensor = tile_crop(image=image_tensor, tile_size=self.tile_size)
+        new_target_height = min(math.floor(image_height * scale_w), target_size[0])
+        new_target_width = min(math.floor(image_width * scale_h), target_size[1])

-        aspect_ratio = torch.tensor(best_resolution).reshape(-1) // self.tile_size
+        # Call CLIPImageTransformCore to perform the resize, pad, normalize and reshape transforms.
+        tiles, aspect_ratio = CLIPImageTransformCore(self.config)(
+            image=image_tensor,
+            target_size=torch.tensor([new_target_height, new_target_width]),
+            canvas_size=torch.tensor([canvas_size[0], canvas_size[1]]),
+        )

         kwargs.update(
             {
-                "image": image_tensor,
+                "image": tiles,
                 "aspect_ratio": aspect_ratio,
             }
         )
diff --git a/torchtune/modules/transforms/__init__.py b/torchtune/modules/transforms/__init__.py
index c317e7d7ce..d4629c90e2 100644
--- a/torchtune/modules/transforms/__init__.py
+++ b/torchtune/modules/transforms/__init__.py
@@ -12,13 +12,13 @@
 from torchtune.modules.transforms.vision_utils.resize_with_pad import (  # noqa
     resize_with_pad,
 )
-from torchtune.modules.transforms.vision_utils.tile_crop import tile_crop  # noqa
+from torchtune.modules.transforms.vision_utils.tile_crop import TileCrop  # noqa

 __all__ = [
+    "TileCrop",
     "Transform",
     "get_canvas_best_fit",
     "resize_with_pad",
-    "tile_crop",
     "find_supported_resolutions",
     "VisionCrossAttentionMask",
 ]
diff --git a/torchtune/modules/transforms/vision_utils/tile_crop.py b/torchtune/modules/transforms/vision_utils/tile_crop.py
index 17e173c3f7..d2f37b5e6f 100644
--- a/torchtune/modules/transforms/vision_utils/tile_crop.py
+++ b/torchtune/modules/transforms/vision_utils/tile_crop.py
@@ -11,7 +11,10 @@
 logger = logging.getLogger(__name__)


-def tile_crop(image: torch.Tensor, tile_size: int) -> torch.Tensor:
+class TileCrop(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
     """
     Divides a tensor into equally sized tiles. The tensor should be divisible by tile_size.
@@ -24,36 +27,37 @@ def tile_crop(image: torch.Tensor, tile_size: int) -> torch.Tensor:

     Examples:
         >>> image = torch.rand(3, 200, 300)
-        >>> tiles = tile_crop(image, tile_size=50)
+        >>> tiles = TileCrop()(image, tile_size=50)
         >>> tiles.shape # 4x6 = 24 tiles
         torch.Size([24, 3, 50, 50])

         >>> image = torch.rand(3, 400, 600)
-        >>> tiles = tile_crop(image, tile_size=200)
+        >>> tiles = TileCrop()(image, tile_size=200)
         >>> tiles.shape # 2x3 = 6 tiles
         torch.Size([6, 3, 200, 200])
     """
-    channel_size, height, width = image.shape
-
-    # assert sizes are divisible
-    assert (
-        height % tile_size == 0 and width % tile_size == 0
-    ), f"Image size {height}x{width} is not divisible by tile size {tile_size}"
-
-    # Reshape to split height and width into tile_size blocks
-    tiles_height = height // tile_size
-    tiles_width = width // tile_size
-
-    reshaped = image.view(channel_size, tiles_height, tile_size, tiles_width, tile_size)
-
-    # Transpose to bring tiles together
-    # We want [tiles_height, tiles_width, channel_size, tile_size, tile_size]
-    transposed = reshaped.permute(1, 3, 0, 2, 4)
-
-    # Flatten the tiles
-    tiles = transposed.contiguous().view(
-        tiles_height * tiles_width, channel_size, tile_size, tile_size
-    )
-
-    return tiles
+    def forward(self, image: torch.Tensor, tile_size: int) -> torch.Tensor:
+        channel_size, height, width = image.shape
+        # assert sizes are divisible
+        assert (
+            height % tile_size == 0 and width % tile_size == 0
+        ), f"Image size {height}x{width} is not divisible by tile size {tile_size}"
+
+        # Reshape to split height and width into tile_size blocks
+        tiles_height = height // tile_size
+        tiles_width = width // tile_size
+
+        reshaped = image.view(
+            channel_size, tiles_height, tile_size, tiles_width, tile_size
+        )
+
+        # Transpose to bring tiles together
+        # We want [tiles_height, tiles_width, channel_size, tile_size, tile_size]
+        transposed = reshaped.permute(1, 3, 0, 2, 4)
+
+        # Flatten the tiles
+        tiles = transposed.contiguous().view(
+            tiles_height * tiles_width, channel_size, tile_size, tile_size
+        )
+        return tiles
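
Usage sketch (not part of the patch): the snippet below is a minimal end-to-end exercise of the new config-driven API, assuming the patched `torchtune.models.clip._transforms` and `torchtune.modules.transforms` modules from this diff are importable. The printed shapes are illustrative; a 100x200 input with `resize_to_max_canvas=True`, `tile_size=224`, and up to 4 tiles should tile 1x2, mirroring the `CLIPImageTransform` docstring example.

```python
import numpy as np
import PIL.Image
import torch

from torchtune.models.clip._transforms import CLIPImageTransform, ImageTransformConfig
from torchtune.modules.transforms import TileCrop

# Configure the transform: 224px tiles, at most 4 tiles, no normalization.
config = ImageTransformConfig(
    image_mean=None,
    image_std=None,
    possible_resolutions=None,
    tile_size=224,
    max_num_tiles=4,
    resample="bilinear",
    resize_to_max_canvas=True,
    normalize=False,
)
image_transform = CLIPImageTransform(config=config)

# Random 100x200 RGB image, as in the CLIPImageTransform docstring example.
image = PIL.Image.fromarray((np.random.rand(100, 200, 3) * 255).astype(np.uint8))

output = image_transform(image=image)
print(output["image"].shape)   # e.g. torch.Size([2, 3, 224, 224])
print(output["aspect_ratio"])  # e.g. tensor([1, 2])

# TileCrop can also be used standalone; the input must be divisible by tile_size.
tiles = TileCrop()(torch.rand(3, 448, 896), tile_size=224)
print(tiles.shape)             # torch.Size([8, 3, 224, 224])
```

The split also appears intended to keep the data-dependent canvas selection in eager Python inside `CLIPImageTransform.__call__`, while `CLIPImageTransformCore` is an `nn.Module` whose intermediate sizes are bounded via `torch._check` and `max_image_size` (which the config docstring notes is there to bound tensor sizes for export).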