Commit: Export preprocess

lucylq committed Aug 1, 2024
1 parent 898670f commit f5296a9
Showing 8 changed files with 188 additions and 333 deletions.
3 changes: 1 addition & 2 deletions docs/source/api_ref_modules.rst
@@ -94,7 +94,6 @@ Functions used for preprocessing images.
:nosignatures:

transforms.get_canvas_best_fit
transforms.resize_with_pad
transforms.tile_crop
transforms.find_supported_resolutions
transforms.VisionCrossAttentionMask
transforms.TileCrop
84 changes: 0 additions & 84 deletions tests/torchtune/modules/transforms/test_resize_with_pad.py

This file was deleted.

8 changes: 4 additions & 4 deletions tests/torchtune/modules/transforms/test_tile_crop.py
@@ -8,7 +8,7 @@

import torch

from torchtune.modules.transforms import tile_crop
from torchtune.modules.transforms import TileCrop


class TestTransforms:
@@ -51,11 +51,11 @@ def test_tile_crop(self, params):
image_size = params["image_size"]
tile_size = params["tile_size"]
status = params["status"]

tile_crop = TileCrop(tile_size)
image = torch.rand(*image_size) # Create a random image tensor

if status == "Passed":
tiles = tile_crop(image, tile_size)
tiles = tile_crop(image)
expected_output_shape = params["expected_output_shape"]
assert (
tiles.shape == expected_output_shape
@@ -73,7 +73,7 @@ def test_tile_crop(self, params):

elif status == "Failed":
with pytest.raises(Exception) as exc_info:
tile_crop(image, tile_size)
tile_crop(image)
expected_error = params["error"]
actual_error = str(exc_info.value)
assert (
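Based on the updated test above, TileCrop now takes the tile size at construction time and only the image at call time, replacing the old tile_crop(image, tile_size) function. A minimal usage sketch, assuming a 448x448 input and a tile size of 224 (hypothetical values; the expected output shape follows the [n, channels, tile_size, tile_size] convention described in the commit):

import torch
from torchtune.modules.transforms import TileCrop

tile_crop = TileCrop(224)         # old API: tile_crop(image, tile_size=224)
image = torch.rand(3, 448, 448)   # [C, H, W]; H and W divisible by the tile size
tiles = tile_crop(image)          # expected shape: [4, 3, 224, 224] (a 2x2 grid of tiles)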
150 changes: 130 additions & 20 deletions torchtune/models/clip/_transforms.py
@@ -5,24 +5,115 @@
# LICENSE file in the root directory of this source tree.

import logging
import math
from typing import Any, List, Mapping, Optional, Tuple

import torch
import torchvision
from PIL import Image

from torchtune.modules.transforms import (
find_supported_resolutions,
get_canvas_best_fit,
resize_with_pad,
tile_crop,
TileCrop,
)
from torchvision.transforms import v2
from torchvision.transforms._functional_tensor import resize

from torchvision.transforms.v2 import functional as F

logger = logging.getLogger(__name__)


class CLIPImageTransformCore(torch.nn.Module):
def __init__(
self,
resample: str,
image_mean: Optional[List[float]],
image_std: Optional[List[float]],
tile_size: int,
max_num_tiles: int,
antialias: bool,
):
super().__init__()
self.resample = resample
self.image_mean = image_mean
self.image_std = image_std
self.tile_size = tile_size
self.max_num_tiles = max_num_tiles
self.antialias = antialias
self.tile_crop = TileCrop(tile_size)

def check_variable_bounds_for_export(
self, vars: List[int], lower: int, upper: int
) -> None:
"""
Performs torch._check to confirm each value is within the specified lower and upper bounds.
Note: this is only needed to export the model; for eager-mode usage, please disregard.
The check mitigates data-dependent errors that may occur during torch.export: it installs a
deferred runtime assert instead of a compile-time guard. Data-dependent errors usually occur
in models with data-dependent control flow, e.g. via .item(), tolist(), nonzero(). For more
context: https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit
"""
for var in vars:
torch._check(var >= lower)
torch._check(var <= upper)

def forward(
self, image: torch.Tensor, target_size: torch.Tensor, canvas_size: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Performs the core transformations involved in CLIPImageTransform:
1. Resize the image to target_size.
2. Pad the image to canvas_size.
3. Normalize the image using image_mean and image_std.
4. Reshape the image tensor into [n, channels, tile_size, tile_size].
Args:
image (torch.Tensor): image as a 3D tensor in form [C, H, W].
target_size (torch.Tensor): tensor of shape (2,) containing the target_height and target_width for resize.
canvas_size (torch.Tensor): tensor of shape (2,) containing the canvas_height and canvas_width for padding.
Returns:
Tuple[torch.Tensor, torch.Tensor]: Image tensor of shape [n, channels, tile_size, tile_size]
and aspect ratio tensor of shape [1, 2].
"""

target_h, target_w = target_size.tolist()
canvas_h, canvas_w = canvas_size.tolist()

# Checks to allow the model to export via torch.export.
self.check_variable_bounds_for_export(
[target_h, target_w, canvas_h, canvas_w],
2,
self.tile_size * self.max_num_tiles,
)

# Resize.
image = resize(
image,
size=[target_h, target_w],
interpolation=self.resample,
antialias=self.antialias,
)

# Pad, such that the image is on the top-left and padded on the right-bottom.
padding = [0, canvas_h - target_h, 0, canvas_w - target_w]
output = v2.Pad(padding=padding)(image)

# Normalize.
if self.image_mean is not None and self.image_std is not None:
output = v2.functional.normalize(output, self.image_mean, self.image_std)

# Reshape.
tiles = self.tile_crop(output)

# Calculate aspect ratio.
aspect_ratio = canvas_size // self.tile_size

return tiles, aspect_ratio


class CLIPImageTransform:
"""
This class accepts images of any size and dynamically resizes, pads, normalizes and tiles it
@@ -120,6 +120,7 @@ def __init__(
)

self.resize_to_max_canvas = resize_to_max_canvas
self.max_num_tiles = max_num_tiles

# normalize
assert (image_mean is None) == (
@@ -128,13 +128,22 @@
self.image_mean = image_mean
self.image_std = image_std

# resize_with_pad
# resize
self.max_upscaling_size = None if resize_to_max_canvas else tile_size
self.resample = torchvision.transforms.InterpolationMode[resample.upper()]
self.resample = resample

# tile_crop
self.tile_size = tile_size

self.core_transform = CLIPImageTransformCore(
resample=self.resample,
image_mean=self.image_mean,
image_std=self.image_std,
tile_size=self.tile_size,
max_num_tiles=self.max_num_tiles,
antialias=True,
)

def __call__(self, *, image: Image.Image, **kwargs) -> Mapping[str, Any]:

assert isinstance(image, Image.Image), "Input image must be a PIL image."
@@ -151,28 +151,37 @@ def __call__(self, *, image: Image.Image, **kwargs) -> Mapping[str, Any]:
resize_to_max_canvas=self.resize_to_max_canvas,
)

# resize without distortion + pad to fit best_resolution
image_tensor = resize_with_pad(
image=image_tensor,
target_size=best_resolution,
resample=self.resample,
max_upscaling_size=self.max_upscaling_size,
)
image_height, image_width = image_tensor.shape[-2:]

# Normalize
if self.image_mean and self.image_std:
image_tensor = F.normalize(
image_tensor, mean=self.image_mean, std=self.image_std
# If target_size requires upscaling, we might want to limit the upscaling to max_upscaling_size.
if self.max_upscaling_size is not None:
target_height = min(
max(image_height, self.max_upscaling_size), best_resolution[0]
)
target_width = min(
max(image_width, self.max_upscaling_size), best_resolution[1]
)
target_size = (target_height, target_width)
else:
target_size = best_resolution

# Divide the image into equally sized tiles
image_tensor = tile_crop(image=image_tensor, tile_size=self.tile_size)
# Calculate the largest aspect ratio preserving size that fits best_resolution.
scale_h = target_size[0] / image_height
scale_w = target_size[1] / image_width

aspect_ratio = torch.tensor(best_resolution).reshape(-1) // self.tile_size
new_target_height = min(math.floor(image_height * scale_w), target_size[0])
new_target_width = min(math.floor(image_width * scale_h), target_size[1])

# Call CLIPImageTransformCore to perform the resize, pad, normalize and reshape transforms.
tiles, aspect_ratio = self.core_transform(
image=image_tensor,
target_size=torch.tensor([new_target_height, new_target_width]),
canvas_size=torch.tensor(best_resolution),
)

kwargs.update(
{
"image": image_tensor,
"image": tiles,
"aspect_ratio": aspect_ratio,
}
)
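The new __call__ path computes the resize target and canvas on the Python side before delegating the tensor ops to CLIPImageTransformCore. A rough worked sketch of that arithmetic, assuming a 500x300 input, tile_size=224, resize_to_max_canvas=False, and a best_resolution of (448, 448) returned by get_canvas_best_fit (all values hypothetical):

import math

tile_size = 224
image_height, image_width = 500, 300
best_resolution = (448, 448)          # assumed result of get_canvas_best_fit
max_upscaling_size = tile_size        # because resize_to_max_canvas=False

# Clamp the target between the original size, the upscaling limit, and the canvas.
target_height = min(max(image_height, max_upscaling_size), best_resolution[0])  # 448
target_width = min(max(image_width, max_upscaling_size), best_resolution[1])    # 300

# Largest aspect-ratio-preserving size that still fits the target.
scale_h = target_height / image_height  # 0.896
scale_w = target_width / image_width    # 1.0
new_target_height = min(math.floor(image_height * scale_w), target_height)  # 500 -> 448
new_target_width = min(math.floor(image_width * scale_h), target_width)     # 300 -> 268

# CLIPImageTransformCore then resizes to (448, 268), pads to the (448, 448) canvas,
# normalizes, and splits the result into 224x224 tiles (here [4, 3, 224, 224]),
# with aspect_ratio = canvas_size // tile_size = [2, 2].
# Before resizing, it also runs torch._check on each of these sizes (lower bound 2,
# upper bound tile_size * max_num_tiles) so that torch.export installs deferred
# runtime asserts for the data-dependent .tolist() values instead of failing to trace.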
9 changes: 3 additions & 6 deletions torchtune/modules/transforms/__init__.py
@@ -9,16 +9,13 @@
find_supported_resolutions,
get_canvas_best_fit,
)
from torchtune.modules.transforms.vision_utils.resize_with_pad import ( # noqa
resize_with_pad,
)
from torchtune.modules.transforms.vision_utils.tile_crop import tile_crop # noqa

from torchtune.modules.transforms.vision_utils.tile_crop import TileCrop # noqa

__all__ = [
"TileCrop",
"Transform",
"get_canvas_best_fit",
"resize_with_pad",
"tile_crop",
"find_supported_resolutions",
"VisionCrossAttentionMask",
]