Skip to content

Commit

Permalink
Export preprocess
Browse files Browse the repository at this point in the history
  • Loading branch information
lucylq committed Jul 29, 2024
1 parent 898670f commit c4055f5
Show file tree
Hide file tree
Showing 5 changed files with 210 additions and 106 deletions.
8 changes: 6 additions & 2 deletions tests/torchtune/models/clip/test_clip_image_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import torch

from torchtune.models.clip._transforms import CLIPImageTransform
from torchtune.models.clip._transforms import CLIPImageTransform, ImageTransformConfig


class TestPipelines:
Expand Down Expand Up @@ -41,16 +41,20 @@ class TestPipelines:
)
def test_clip_image_transform(self, params):

image_transform = CLIPImageTransform(
config = ImageTransformConfig(
image_mean=None,
image_std=None,
tile_size=224,
possible_resolutions=None,
max_num_tiles=4,
resample="bilinear",
resize_to_max_canvas=params["resize_to_max_canvas"],
max_upscaling_size=None,
normalize=False,
)

image_transform = CLIPImageTransform(config=config)

image_size = params["image_size"]

# Create a random image
Expand Down
6 changes: 3 additions & 3 deletions tests/torchtune/modules/transforms/test_tile_crop.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import torch

from torchtune.modules.transforms import tile_crop
from torchtune.modules.transforms import TileCrop


class TestTransforms:
Expand Down Expand Up @@ -55,7 +55,7 @@ def test_tile_crop(self, params):
image = torch.rand(*image_size) # Create a random image tensor

if status == "Passed":
tiles = tile_crop(image, tile_size)
tiles = TileCrop()(image, tile_size)
expected_output_shape = params["expected_output_shape"]
assert (
tiles.shape == expected_output_shape
Expand All @@ -73,7 +73,7 @@ def test_tile_crop(self, params):

elif status == "Failed":
with pytest.raises(Exception) as exc_info:
tile_crop(image, tile_size)
TileCrop()(image, tile_size)
expected_error = params["error"]
actual_error = str(exc_info.value)
assert (
Expand Down
242 changes: 169 additions & 73 deletions torchtune/models/clip/_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,131 @@
# LICENSE file in the root directory of this source tree.

import logging
import math
from dataclasses import dataclass
from typing import Any, List, Mapping, Optional, Tuple

import torch
import torchvision
from PIL import Image

from torchtune.modules.transforms import (
find_supported_resolutions,
get_canvas_best_fit,
resize_with_pad,
tile_crop,
TileCrop,
)
from torchvision.transforms import v2
from torchvision.transforms._functional_tensor import resize

from torchvision.transforms.v2 import functional as F

logger = logging.getLogger(__name__)


@dataclass
class ImageTransformConfig:
    """
    Configuration values consumed by ``CLIPImageTransform`` and ``CLIPImageTransformCore``.

    Attributes:
        image_mean (Optional[List[float]]): Mean values of each channel, used for normalization.
            Should be the same used for the pre-trained model. Must be given together with
            ``image_std`` (``CLIPImageTransform`` asserts both-or-neither). Default None.
        image_std (Optional[List[float]]): Standard deviation values of each channel, used for normalization.
            Should be the same used for the pre-trained model. Must be given together with
            ``image_mean``. Default None.
        possible_resolutions (Optional[List[Tuple[int, int]]]): List of possible resolutions as tuples (height, width),
            where each tuple represents a possible canvas to fit the image into when calling ``get_canvas_best_fit``.
            If None, this will be calculated using max_num_tiles and tile_size. Default None.
        tile_size (int): Size of the tiles to divide the image into. Default 224.
        max_num_tiles (Optional[int]): Only used if possible_resolutions is NOT given.
            Maximum number of tiles to break an image into.
            This will be used to generate possible_resolutions,
            e.g. [(224, 224), (224, 448), (448, 224)] if max_num_tiles = 2 and tile_size = 224.
            Default 4.
        resample (str): Resampling method used when resizing images. Supports any enum of
            ``torchvision.transforms.InterpolationMode``, e.g. "nearest", "nearest_exact", "bilinear", "bicubic".
            Default 'bilinear'.
        resize_to_max_canvas (bool): If True, the image will be upscaled without distortion to fit the largest possible
            resolution from possible_resolutions.
            If False, it will pick the resolution that minimizes downscaling, including no downscaling at all.
            In this case, the image will only be upscaled if its size < tile_size. Default False.
        max_upscaling_size (Optional[int]): Limit the upscaling performed on the image. Default None.
            NOTE(review): ``CLIPImageTransform.__init__`` overwrites this field (None when
            ``resize_to_max_canvas`` is True, otherwise ``tile_size``), so a user-supplied value
            appears to be ignored on that path -- confirm this is intended.
        normalize (bool): If True, image will be normalized using image_mean and image_std. Default True.
        max_image_size (int): Max image size, used to bound the tensor sizes for export
            (upper bound checked on the resize target and canvas dimensions). Default 5000.
    """

    image_mean: Optional[List[float]] = None
    image_std: Optional[List[float]] = None
    possible_resolutions: Optional[List[Tuple[int, int]]] = None
    tile_size: int = 224
    max_num_tiles: Optional[int] = 4
    resample: str = "bilinear"
    resize_to_max_canvas: bool = False
    max_upscaling_size: Optional[int] = None
    normalize: bool = True
    max_image_size: int = 5000


class CLIPImageTransformCore(torch.nn.Module):
    """
    Performs the core tensor transformations involved in CLIPImageTransform:
        1. Resize the image to target_size.
        2. Pad the image (bottom and right) to canvas_size.
        3. Normalize the image using config.image_mean and config.image_std.
        4. Reshape the image tensor into [n, channels, config.tile_size, config.tile_size].

    Args:
        config (ImageTransformConfig): configuration read at call time; this module
            uses ``resample``, ``max_image_size``, ``normalize``, ``image_mean``,
            ``image_std`` and ``tile_size``.
    """

    def __init__(self, config: ImageTransformConfig):
        super().__init__()
        self.config = config

    def forward(
        self, image: torch.Tensor, target_size: torch.Tensor, canvas_size: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Resize, pad, optionally normalize and tile a single image.

        Args:
            image (torch.Tensor): image as a 3D tensor in form [C, H, W].
            target_size (torch.Tensor): tensor of shape [2] containing the
                target_height and target_width for resize.
            canvas_size (torch.Tensor): tensor of shape [2] containing the
                canvas_height and canvas_width for padding.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: the tiles produced by ``TileCrop``
                with ``config.tile_size``, and the aspect ratio tensor
                (``canvas_size // config.tile_size``, same shape as canvas_size).
        """
        # Resize. The _check calls bound the data-dependent sizes, which the
        # config documents as needed for export.
        target_h, target_w = target_size.tolist()
        torch._check(target_h > 1)
        torch._check(target_h <= self.config.max_image_size)
        torch._check(target_w > 1)
        torch._check(target_w <= self.config.max_image_size)

        image = resize(
            image,
            size=[target_h, target_w],
            interpolation=self.config.resample,
            antialias=True,
        )

        # Pad.
        canvas_h, canvas_w = canvas_size.tolist()
        torch._check(canvas_h > 1)
        torch._check(canvas_h <= self.config.max_image_size)
        torch._check(canvas_w > 1)
        torch._check(canvas_w <= self.config.max_image_size)

        # torchvision orders 4-element padding as [left, top, right, bottom].
        # Pad only on the right and bottom so the resized image stays anchored
        # at the top-left corner and the result is exactly (canvas_h, canvas_w).
        padding = [0, 0, canvas_w - target_w, canvas_h - target_h]
        output = v2.Pad(padding=padding)(image)

        # Normalize. Skipped when mean/std are absent: the config defaults are
        # normalize=True with image_mean=image_std=None, which would otherwise
        # pass None into normalize().
        should_normalize = (
            self.config.normalize
            and self.config.image_mean is not None
            and self.config.image_std is not None
        )
        if should_normalize:
            output = v2.functional.normalize(
                output, self.config.image_mean, self.config.image_std
            )

        # Reshape.
        tiles = TileCrop()(output, self.config.tile_size)

        # Calculate aspect ratio (number of tiles along height and width).
        aspect_ratio = canvas_size // self.config.tile_size

        return tiles, aspect_ratio


class CLIPImageTransform:
"""
This class accepts images of any size and dynamically resizes, pads, normalizes and tiles it
Expand All @@ -36,42 +143,23 @@ class CLIPImageTransform:
For example, if an input image is of size 300x800, and we want to allow
a maximum of 16 image tiles, with side 224px, then:
If ``resize_to_max_canvas=False``, then:
best_resolution = (448, 896) -> smallest canvas, up to 16 tiles, that doesn't require downscaling
If ``config.resize_to_max_canvas=False``, then:
canvas_size = (448, 896) -> smallest canvas, up to 16 tiles, that doesn't require downscaling
image is NOT resized
image is padded (300, 800) -> 448,896
Image is tiled 2x4, for a final output shape of (8, 3, 224, 224)
If ``resize_to_max_canvas=True``, then:
best_resolution = (448, 1344) # canvas that allows maximum upscaling, with minimum padding, up to 16 tiles
If ``config.resize_to_max_canvas=True``, then:
canvas_size = (448, 1344) # canvas that allows maximum upscaling, with minimum padding, up to 16 tiles
image is resized without distortion (300,800) -> (448, 1194) #448 is the limiting side for the resize
image is padded (448, 1194) -> (448, 1344)
Image is tiled 2x5, for a final output shape of (10, 3, 224, 224)
Args:
image_mean (Optional[List[float]]): Mean values of each channel, used for normalization.
Should be the same used for the pre-trained model. If None, no normalization is performed. Default None.
image_std (Optional[List[float]]): Standard deviation values of each channel, used for normalization.
Should be the same used for the pre-trained model. If None, no normalization is performed. Default None.
possible_resolutions (Optional[List[Tuple[int, int]]]): List of possible resolutions as tuples (height, width).
where each tuple represents a possible canvas to fit the image into when calling ``get_canvas_best_fit``.
If None, this will be calculated using max_num_tiles and tile_size. Default None.
tile_size (int): Size of the tiles to divide the image into. Default 224.
max_num_tiles (Optional[int]): Only used if possible_resolutions is NOT given.
Maximum number of tiles to break an image into.
This will be used to generate possible_resolutions,
e.g. [(224, 224), (224, 448), (448, 224)] if max_num_tiles = 2 and tile_size = 224.
Default 4.
resample (str): Resampling method used when resizing images. Supports any enum of
``torchvision.transforms.InterpolationMode``, e.g. "nearest", "nearest_exact", "bilinear", "bicubic".
Default 'bilinear'.
resize_to_max_canvas (bool): "If True, the image will be upscaled without distortion to fit the largest possible
resolution from possible_resolutions.
If False, it will pick the resolution that minimizes downscaling, including no downscaling at all.
In this case, the image will only be upscaled if it's size < tile_size. Default False.
config (Optional[ImageTransformConfig]): config values for the image transform. If None, uses the default values.
Examples:
>>> image_transform = CLIPImageTransform(
>>> config = ImageTransformConfig(
... image_mean=None,
... image_std=None,
... tile_size=224,
Expand All @@ -80,6 +168,7 @@ class CLIPImageTransform:
... resample="bilinear",
... resize_to_max_canvas=True,
...)
>>> image_transform = CLIPImageTransform(config=config)
>>> # create random image
>>> image = (np.random.rand(100,200,3) * 255).astype(np.uint8)
>>> image = PIL.Image.fromarray(image)
Expand All @@ -92,48 +181,43 @@ class CLIPImageTransform:

def __init__(
self,
image_mean: Optional[List[float]] = None,
image_std: Optional[List[float]] = None,
possible_resolutions: Optional[List[Tuple[int, int]]] = None,
tile_size: int = 224,
max_num_tiles: Optional[int] = 4,
resample: str = "bilinear",
resize_to_max_canvas: bool = False,
config: Optional[ImageTransformConfig] = None,
) -> None:
if config is None:
config = ImageTransformConfig()

# get_canvas_best_fit
assert (
possible_resolutions is not None or max_num_tiles is not None
), f"Either possible_resolutions or max_num_tiles must be given. Got {possible_resolutions=} and {max_num_tiles=}"
config.possible_resolutions is not None or config.max_num_tiles is not None
), (
"Either possible_resolutions or max_num_tiles must be given."
+ f"Got {config.possible_resolutions=} and {config.max_num_tiles=}"
)

# If possible_resolutions are not given, then calculate possible ones based on max_num_tiles
if not possible_resolutions and max_num_tiles:
if not config.possible_resolutions and config.max_num_tiles:
possible_resolutions = find_supported_resolutions(
max_num_tiles=max_num_tiles, tile_size=tile_size
max_num_tiles=config.max_num_tiles, tile_size=config.tile_size
)
config.possible_resolutions = torch.tensor(possible_resolutions).reshape(
-1, 2
)
else:
possible_resolutions = possible_resolutions

self.possible_resolutions = torch.tensor(possible_resolutions).reshape(-1, 2)
logger.info(
f"Found possible_resolutions: {self.possible_resolutions}. Will fit the images into the canvas with best fit."
f"Found possible_resolutions: {config.possible_resolutions}. Will fit the images into the canvas with best fit."
)

self.resize_to_max_canvas = resize_to_max_canvas

# normalize
assert (image_mean is None) == (
image_std is None
), f"Need to provide both or none of image_mean and image_std. Got {image_mean=} and {image_std=}"
self.image_mean = image_mean
self.image_std = image_std
assert (config.image_mean is None) == (
config.image_std is None
), f"Need to provide both or none of image_mean and image_std. Got {config.image_mean=} and {config.image_std=}"

# resize_with_pad
self.max_upscaling_size = None if resize_to_max_canvas else tile_size
self.resample = torchvision.transforms.InterpolationMode[resample.upper()]
config.max_upscaling_size = (
None if config.resize_to_max_canvas else config.tile_size
)

# tile_crop
self.tile_size = tile_size
self.config = config

def __call__(self, *, image: Image.Image, **kwargs) -> Mapping[str, Any]:

Expand All @@ -144,35 +228,47 @@ def __call__(self, *, image: Image.Image, **kwargs) -> Mapping[str, Any]:
F.grayscale_to_rgb_image(F.to_image(image)), scale=True
)

# Find the best canvas to fit the image without distortion
best_resolution = get_canvas_best_fit(
# Find the best canvas to fit the image without distortion.
# This finds the resolution of the best tile arrangement given the
# image dimensions, aspect ratio, and config.possible_resolutions (derived from config.max_num_tiles when not given).
canvas_size = get_canvas_best_fit(
image=image_tensor,
possible_resolutions=self.possible_resolutions,
resize_to_max_canvas=self.resize_to_max_canvas,
possible_resolutions=self.config.possible_resolutions,
resize_to_max_canvas=self.config.resize_to_max_canvas,
)

# resize without distortion + pad to fit best_resolution
image_tensor = resize_with_pad(
image=image_tensor,
target_size=best_resolution,
resample=self.resample,
max_upscaling_size=self.max_upscaling_size,
)
image_height, image_width = image_tensor.shape[-2:]

# Normalize
if self.image_mean and self.image_std:
image_tensor = F.normalize(
image_tensor, mean=self.image_mean, std=self.image_std
# If target_size requires upscaling, we might want to limit the upscaling to max_upscaling_size.
if self.config.max_upscaling_size is not None:
target_height = min(
max(image_height, self.config.max_upscaling_size), canvas_size[0]
)
target_width = min(
max(image_width, self.config.max_upscaling_size), canvas_size[1]
)
target_size = (target_height, target_width)
else:
target_size = canvas_size

# Calculate the target size; the largest aspect ratio preserving size that can fit within the
# canvas size.
scale_h = target_size[0] / image_height
scale_w = target_size[1] / image_width

# Divide the image into equally sized tiles
image_tensor = tile_crop(image=image_tensor, tile_size=self.tile_size)
new_target_height = min(math.floor(image_height * scale_w), target_size[0])
new_target_width = min(math.floor(image_width * scale_h), target_size[1])

aspect_ratio = torch.tensor(best_resolution).reshape(-1) // self.tile_size
# Call CLIPImageTransformCore to perform resize, pad, normalize and reshape transforms.
tiles, aspect_ratio = CLIPImageTransformCore(self.config)(
image=image_tensor,
target_size=torch.tensor([new_target_height, new_target_width]),
canvas_size=torch.tensor([canvas_size[0], canvas_size[1]]),
)

kwargs.update(
{
"image": image_tensor,
"image": tiles,
"aspect_ratio": aspect_ratio,
}
)
Expand Down
Loading

0 comments on commit c4055f5

Please sign in to comment.