
[RFC] Image Generation Dataset #2140

Open · wants to merge 2 commits into main

Conversation

@calvinpelletier (Contributor) commented on Dec 10, 2024

Overview

This is an RFC regarding how we should support datasets for finetuning text-conditioned image generation models.

A basic data pipeline for this would be:

  1. Load the JSON/CSV/TSV/Parquet/LMDB/etc. file containing the image paths/urls and captions
  2. For each pair:
    • load/download the image
    • resize the image and optionally randomly augment it (horizontal flip, etc.) and normalize it
    • optionally randomly augment the caption (rearrange caption parts, etc.)
    • tokenize the caption using the model's tokenizer
  3. Collate into a batch

At a broad level, this fits well into our current TorchTune data ecosystem (except we wouldn't use the "list of Message objects" abstraction, which would change how we interact with the model's tokenizer).

In TorchTune, a simple version would look something like this:

```yaml
dataset:
    _component_: torchtune.datasets.img_caption_dataset
    path: ~/my_dataset/data.tsv
    img_transform:
        resize: [256, 256]
        center_crop: true
        horizontal_flip: 0.5
    caption_transform:
        drop: 0.05
        shuffle_parts: 0.1
tokenizer:
    _component_: torchtune.models.flux.FluxTransform
    clip_tokenizer_path: ...
    t5_tokenizer_path: ...
    t5_max_seq_len: 256
```
```python
from io import BytesIO
from pathlib import Path

import datasets
import requests
import torch
from PIL import Image


def img_caption_dataset(
    model_transform: Transform,
    *,
    path: str,
    img_transform: Config,
    caption_transform: Config,
):
    """Builder for an image caption dataset."""
    data = _load_img_text_dataset(path)
    img_transform = _build_torchvision_transforms(img_transform)
    caption_transform = _CaptionTransform(caption_transform)
    return ImgTextDataset(
        data,
        img_transform=img_transform,
        text_transform=caption_transform,
        model_transform=model_transform,
    )


def _load_img_text_dataset(path):
    # No file extension -> assume it's a Hugging Face dataset name.
    if "." not in path:
        return datasets.load_dataset(path, ...)

    path = Path(path).expanduser().resolve()
    if path.suffix == ".tsv":
        data = []
        with open(path, "r") as f:
            for line in f:
                img_path_or_url, text = [x.strip() for x in line.split("\t")]
                data.append((img_path_or_url, text))
        return data

    elif path.suffix == "...":
        ...


def _build_torchvision_transforms(cfg):
    """
    Create a series of torchvision transforms
    (resize, crop, flip, etc.)
    """
    ...


class _CaptionTransform:
    """
    Callable that randomly augments image captions with comma-separated parts
    (shuffle parts, randomly drop entire caption, etc.)
    (or does nothing if disabled)
    """

    def __init__(self, cfg): ...

    def __call__(self, caption: str) -> str: ...


class ImgTextDataset(torch.utils.data.Dataset):
    def __init__(self, data, img_transform, text_transform, model_transform):
        self._data = data
        self._img_transform = img_transform
        self._text_transform = text_transform
        self._model_transform = model_transform

    def __len__(self):
        return len(self._data)

    def __getitem__(self, idx):
        img_path_or_url, text = self._data[idx]
        # Download the image if it's a URL, otherwise load it from disk.
        img = (
            Image.open(BytesIO(requests.get(img_path_or_url).content))
            if img_path_or_url.startswith(("http://", "https://", "ftp://", "ftps://"))
            else Image.open(img_path_or_url)
        )
        img = self._img_transform(img)
        text = self._text_transform(text)
        data_dict = self._model_transform(img, text)
        return data_dict


class FluxTransform(Transform):
    def __init__(self, clip_tokenizer_path, t5_tokenizer_path, t5_max_seq_len):
        ...

    def __call__(self, img, text):
        return {
            'img': (img / 127.5) - 1.0,  # uint8 pixels -> [-1, 1]
            'clip_text_tokens': self._clip_tokenizer(text),
            't5_text_tokens': self._t5_tokenizer(text),
        }
```
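For illustration, here's a minimal sketch of what `_build_torchvision_transforms` might look like, assuming the config keys from the YAML above (not a committed API, just the shape of the idea):

```python
from torchvision.transforms import v2


def _build_torchvision_transforms(cfg):
    # Map the assumed config keys (resize/center_crop/horizontal_flip)
    # onto torchvision v2 transforms.
    transforms = []
    if "resize" in cfg:
        transforms.append(v2.Resize(cfg["resize"]))
    if cfg.get("center_crop"):
        transforms.append(v2.CenterCrop(min(cfg["resize"])))
    if "horizontal_flip" in cfg:
        transforms.append(v2.RandomHorizontalFlip(p=cfg["horizontal_flip"]))
    # Convert PIL -> uint8 tensor so the model transform can normalize it.
    transforms.append(v2.PILToTensor())
    return v2.Compose(transforms)
```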

TODO: Collate

We'll need to generalize our collate functions such that they can handle data outside of the tokens-and-labels format they currently expect. I will update this section after I've looked into this.
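As a starting point, a rough sketch of a generalized collate, assuming each sample is the dict produced by `FluxTransform` above (fixed-size tensors get stacked; variable-length token sequences get padded):

```python
import torch
from torch.nn.utils.rnn import pad_sequence


def img_gen_collate(batch, pad_token_id=0):
    # Images and CLIP tokens are fixed-size, so a plain stack works.
    # T5 token sequences (assumed to be 1D tensors) can vary in length,
    # so pad them to the longest sequence in the batch.
    return {
        "img": torch.stack([s["img"] for s in batch]),
        "clip_text_tokens": torch.stack([s["clip_text_tokens"] for s in batch]),
        "t5_text_tokens": pad_sequence(
            [s["t5_text_tokens"] for s in batch],
            batch_first=True,
            padding_value=pad_token_id,
        ),
    }
```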

Caching/Preprocessing

From what I've seen online, some people finetune image generators on massive datasets, but most people just finetune on very small personal datasets, often 5-100 images. So we should probably add support for various caching/preprocessing options that increase disk/mem usage in order to achieve faster iterations. Some ideas for optional configurations:

  • cache up to N images in each data worker so they don't have to be loaded fresh from disk each epoch (this probably isn't an actual bottleneck, though)
  • in the extreme case of fewer than ~10 images, we could even just keep the whole dataset on each GPU so we don't have to transfer it each step
  • in the case of a web dataset, save up to N downloaded images to local storage for the next epoch
  • provide a script that runs before training and preprocesses the outputs of frozen parts of the model (text tokens, image autoencoder embeddings), saving them to disk so that we don't have to recompute every epoch (see the sketch after this list)
    • tokenization would be negligible, but I bet preprocessing the Flux image encoding would save a lot of time and GPU memory (edit: actually, the T5 text encoder is the part that would benefit the most from preprocessing)
    • this could also be done on the fly, i.e. caching instead of preprocessing: during the first epoch, save the intermediate values to disk and reuse them in all subsequent epochs. But this makes the code much more complicated.
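To make the preprocessing-script idea concrete, a hypothetical one-time pass over the dataset might look like this (the `t5_tokenizer`/`t5_encoder` callables and the on-disk layout are assumptions for illustration):

```python
from pathlib import Path

import torch


@torch.no_grad()
def precompute_t5_embeddings(data, t5_tokenizer, t5_encoder, out_dir: Path):
    # Run the frozen T5 encoder once over every caption and save the result,
    # so training never needs to keep T5 in GPU memory.
    out_dir.mkdir(parents=True, exist_ok=True)
    for i, (_, caption) in enumerate(data):
        embedding = t5_encoder(t5_tokenizer(caption))
        torch.save(embedding, out_dir / f"{i}.pt")
```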

But we should evaluate whether each of these is worth it:

  • how much performance gain would you actually get? and under what circumstances?
  • how much would it complicate the code and the configs?

Dataset Creation

Should we include scripts/utilities for creating the captions? Users will probably often have just a folder with a bunch of images that they want to finetune on. So we could help them turn that folder into a dataset by using some model to automatically caption them. We could even provide our own models for this by distilling the image captioning capabilities of Llama3.2V-90B into several smaller Llama3.2V models, and let the user pick the one that fits on their device.

We'll also want to support adding words/phrases to the caption that tell the model to generate in the style of this dataset. For example, if I'm finetuning a model on images of myself, I'll want to include something like "a photo of cpelletier" in the caption so that the model learns to associate "cpelletier" with my face. This could be supported at the dataset-creation step (i.e. the identifiers are put into the caption data itself, which is simpler), or at the text-transform step (i.e. the identifier is specified in the text transform config, like 'add "in the style of cpelletier" to the end of each caption'). The latter is a bit more complex, but it's nice that you don't have to change the dataset if you want to experiment with different identifiers.
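As a sketch of the text-transform option, a hypothetical `append` key on the caption transform (the key name is made up for illustration) could be handled like this, showing only that piece of `_CaptionTransform`:

```python
class _CaptionTransform:
    def __init__(self, cfg):
        # e.g. cfg = {"append": "in the style of cpelletier"}
        self._append = cfg.get("append")

    def __call__(self, caption: str) -> str:
        # Tack the identifier phrase onto the end of every caption.
        if self._append:
            caption = f"{caption}, {self._append}"
        return caption
```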

User Experience

  • Regarding loading the TSV/Parquet/whatever data file, should we just rely on huggingface's load_dataset like we currently do in SFTDataset? It keeps the code simpler, but it makes the user leave torchtune and go read the huggingface docs, which is overkill if they just have some simple JSON file we could easily load ourselves.
  • In addition to absolute image paths in the data file, we should probably support image paths relative to the dataset folder, because it would be super annoying if you had to regenerate your data file any time you moved the dataset to a new location.
  • There are currently some potentially unnecessary fields in the config. For example, with Flux models, the model determines the image size and the T5 tokenizer sequence length. Is it better to pass this information to the image transform and model transform, respectively (which complicates the code but lowers the chance of user error)? Or is it better to have the user define these values in the dataset config and tokenizer config, respectively (which puts the burden on the user to match what the model expects)?
  • Should we add scripts/utilities for inspecting the dataset? It's nice to see a preview of what a batch looks like, especially when you're messing around with color jitter and other hard-to-configure image augmentations (a rough sketch follows this list).
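A dataset-preview utility could be as simple as the following sketch (assuming images come out of the model transform normalized to [-1, 1], as in `FluxTransform` above):

```python
import torch
import torchvision


def preview_dataset(dataset, n=16, out_path="preview.png"):
    # Pull a few augmented samples and tile them into a grid for eyeballing.
    imgs = [dataset[i]["img"] for i in range(min(n, len(dataset)))]
    grid = torchvision.utils.make_grid(torch.stack(imgs), nrow=4)
    # Undo the [-1, 1] normalization before saving.
    torchvision.utils.save_image((grid + 1) / 2, out_path)
```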

Other

  • Naming of the image-text dataset builders/classes? Maybe the more verbose image_caption_dataset_for_image_generation is better to make it clear that this is NOT for something like finetuning a VLM to do image captioning (although maybe it could be generalized to the point where it can also do lists of Message objects and therefore can be used for whatever purpose).
  • Support multiple captions per image? I can imagine people wanting to generate multiple captions for their images, and randomly selecting one at a time during training to prevent overfitting. It's kind of a caption augmentation, but it's unique to each image, so it would have to be supported at the data level (see the sketch below).
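At the data level, this could be as small as sampling one caption per `__getitem__` call, e.g. (a hypothetical variant of the dataset class above):

```python
import random


class MultiCaptionImgTextDataset(ImgTextDataset):
    def __getitem__(self, idx):
        # Each row stores a list of captions; sample a different one each time.
        img_path_or_url, captions = self._data[idx]
        text = random.choice(captions)
        ...
```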

@gau-nernst (Contributor) commented:

Saw this RFC and want to pen down some of my thoughts, as I'm also building a fine-tuning pipeline for Flux.

  • "preprocessing the Flux image encoding would save a lot of time and GPU memory" - I assume this refers to the FLUX autoencoder, which encodes images from pixel space to latent space. Actually, the FLUX autoencoder is small compared to the rest (only 168MB in BF16) and fast, so caching it is not really necessary.
  • The more problematic one is T5 embeddings. The T5 encoder (xxl variant) alone is huge: 9.5GB in BF16. So T5 embeddings are a good candidate for pre-computation. But the problem is that each T5 embedding is also huge: with size (512, 4096) per prompt, it's 4MB in BF16. That's fine for small datasets, but not quite scalable (though the SD3 paper mentions they cache all T5 embeddings 😆)
  • Regarding generating captions with an LLM, I find it's easier (and better quality) to just use an online service that accepts image inputs. Won't mention competitor names here 😅, but some of them provide a good free tier with API access.
  • "Support multiple captions per image?" - For local datasets, we can just add extra rows (different prompts but the same path to the image). For streaming datasets, then yes, we probably need to handle a row with multiple prompts (and one image).
  • (This will be much harder) Resolution/aspect-ratio bucketing for datasets with diverse resolutions and aspect ratios. For different aspect ratios, random cropping is not ideal since it can truncate important parts of an image. For different resolutions, we don't want to enlarge small images (since they will be low quality/blurry), or shrink large images (since we lose the ability to learn high resolution).
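(For readers unfamiliar with the idea: bucketing means assigning each image to the nearest of a fixed set of resolutions/aspect ratios and batching only within a bucket. A minimal sketch, with a made-up bucket list:)

```python
# Candidate (width, height) buckets; a real list would be larger.
BUCKETS = [(256, 256), (320, 192), (192, 320), (384, 160)]


def nearest_bucket(width: int, height: int) -> tuple[int, int]:
    # Pick the bucket whose aspect ratio is closest to the image's.
    ar = width / height
    return min(BUCKETS, key=lambda wh: abs(wh[0] / wh[1] - ar))
```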

Inline comment on:

```python
    def __call__(self, caption: str) -> str: ...


class ImgTextDataset(torch.utils.data.Dataset):
```
Contributor:

I think this should basically look like the text completion dataset, but with model_transform as you have here and column_map instead of column. I also don't know if we want to anchor this to vision, as diffusion for audio etc. would use the same pattern. I think an optional data_transform would work here instead of img/text.

Author:

I think we should include an image transform here that way we can separate model-independent image augmentations from model-specific ones in the model transform.

Also, I don't think this should be a generic dataset class for diffusion in general. It shouldn't be tied to diffusion at all, and should instead be for any downstream task that uses image-text pairs, e.g. non-diffusion image gen models, image captioning models, image-text joint encoders, etc. There were a lot of papers at NeurIPS this year that were finetuning CLIP; I would expect those to utilize the same ImgTextDataset as finetuning Flux would. If you're doing diffusion for audio, you would use an AudioTextDataset.

Contributor:

I agree with you WRT a dataset class (hence why everything so far essentially returns an SFT dataset). However, I do think there's tremendous value in aligning our dataset builders with specific tasks. It makes it easier to utilize from configs and find datasets to use on the Hub.

Inline comment on:

```python
        return data_dict


class FluxTransform(Transform):
```
Contributor:

A generic diffusion model transform would just take a dict instead of a list of messages, but otherwise be the same.

Author:

There is logic here that is specific to Flux, and I think it should exist within a Flux-specific model transform.

Inline comment on:

```python
        ...


def _build_torchvision_transforms(cfg):
```
Contributor:

This, along with _CaptionTransform, is all within the abstraction of the model transform or data transform, as the user needs. Or is this meant to be an example?

Author:

I think we should separate the data transform logic from the model transform logic, e.g. data augmentations like horizontal flip would be in an image transform that's entirely separate from model logic, and model-specific logic like image normalization would be in the model transform.

Inline comment on:

```python
    )


def _load_img_text_dataset(path):
```
Contributor:

We use huggingface's load_dataset as well as load_image.

Author (Dec 17, 2024):

Regarding load_image: thanks, I'll switch to this. Question though: when the image path is a URL, should we include the option of saving these images to disk so that they don't need to be re-downloaded during the next epoch?

Regarding load_dataset: I address this in the first bullet of the User Experience section. I personally think it's better if we handle simple cases like loading an image-caption TSV ourselves so the user doesn't have to go read the huggingface docs, especially since most image-gen finetuning will be done on small local datasets. But I'm also fine with just relying on huggingface's load_dataset, since that does make our code simpler.

Contributor:

Totally understand your point on not wanting to overcomplicate things, but using load_dataset under the hood makes our lives way easier lol


Inline comment on:

```yaml
dataset:
    _component_: torchtune.datasets.img_caption_dataset
```
Contributor:

Maybe it's just my naiveté, but when I hear "image-caption dataset", I assume it's a dataset for taking an image and generating a caption, which is not the case here.

Hugging Face has a label for these datasets called "Text-to-Image", which I think is a more accurate description. This is also in line with our addition of task-centered dataset builders like the vqa_dataset.

Concretely, I'm proposing changing the default dataset builder for diffusion from img_caption_dataset to text_to_image_dataset.

Author:

I figured that this dataset could be used for any downstream task that uses image+text pairs, like finetuning CLIP for example. Maybe image_text_pair_dataset? Or is it clearer for the user if we name the datasets based on a specific use?

Inline comment on:

```yaml
dataset:
    _component_: torchtune.datasets.img_caption_dataset
    path: ~/my_dataset/data.tsv
```
Contributor:

Is it more common to have (ahem) private data for finetuning diffusion models, or data that might be published on the Hugging Face Hub? That should affect what the first-class citizen is here and what goes in all our examples.

Regardless, if we're using the load_dataset functionality from Hugging Face (like we do for all our other datasets, including image-to-text), why does this not follow the same format, where we specify e.g. TSV as the source and data_files=~/my_dataset/data.tsv?
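(Presumably that would look something like the following; the exact keys are an assumption based on how the SFT builders forward arguments to load_dataset:)

```yaml
dataset:
    _component_: torchtune.datasets.img_caption_dataset
    source: csv
    data_files: ~/my_dataset/data.tsv
    delimiter: "\t"
```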

Inline comment on:

```yaml
        resize: [256, 256]
        center_crop: true
        horizontal_flip: 0.5
    caption_transform:
```
Contributor:

See the above comment, but I would opt for text, not caption, here.

Inline comment on:

```yaml
    caption_transform:
        drop: 0.05
        shuffle_parts: 0.1
tokenizer:
```
Contributor:

Can you show how this would look from code? I know we prefer flattened params for our configs, but if this was built via code, I'd imagine we'd instantiate CLIP and T5 and then pass them to our FluxTransform - right?
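(A hypothetical sketch of the code-side construction being described; the tokenizer classes and arguments here are placeholders, not actual torchtune APIs:)

```python
# Instantiate the tokenizers first, then hand them to the model transform.
clip_tokenizer = CLIPTokenizer(path="...")
t5_tokenizer = T5Tokenizer(path="...", max_seq_len=256)
transform = FluxTransform(clip_tokenizer=clip_tokenizer, t5_tokenizer=t5_tokenizer)
```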

Inline comment on:

```python
    model_transform: Transform,
    *,
    path: str,
    img_transform: Config,
```
Contributor:

I don't think we ever want our builders to see the notion of configs. Configs are just a way to interface with our recipes, but builders should be able to be dropped into place anywhere.


Inline comment on:

```md
# User Experience

- Regarding loading the TSV/Parquet/whatever data file, should we just rely on huggingface's `load_dataset` like we currently do in `SFTDataset`? It keeps the code simpler, but it makes the user leave torchtune and go read the huggingface docs, which is overkill if they just have some simple JSON file we could easily load ourselves.
- In addition to absolute image paths in the data file, we should probably support image paths relative to the dataset folder, because it would be super annoying if you had to regenerate your data file any time to move the dataset to a new location.
```
Contributor:

This is handled via our current image/text dataset utilities.

Inline comment on:

```md
- Regarding loading the TSV/Parquet/whatever data file, should we just rely on huggingface's `load_dataset` like we currently do in `SFTDataset`? It keeps the code simpler, but it makes the user leave torchtune and go read the huggingface docs, which is overkill if they just have some simple JSON file we could easily load ourselves.
- In addition to absolute image paths in the data file, we should probably support image paths relative to the dataset folder, because it would be super annoying if you had to regenerate your data file any time to move the dataset to a new location.
- There's currently some potentially unnecessary fields in the config. For example with Flux models, the model determines the image size and the T5 tokenizer sequence length. Is it better to pass this information to the image transform and model transform, respectively? Which complicates the code but lowers the chance of user error. Or is it better to have the user define these values in the dataset config and tokenizer config, respectively? Which puts the burden on the user to match what the model expects.
- Should we add scripts/utilities for inspecting the dataset? It's nice to see a preview of what a batch looks like, especially when you're messing around with color jitter and other hard-to-configure image augmentations.
```
Contributor:

Definitely a cool feature, but probably a P2 or upon-request-from-users type of thing.


Inline comment on:

```md
# Other
- Naming of the image-text dataset builders/classes? Maybe the more verbose `image_caption_dataset_for_image_generation` is better to make it clear that this is NOT for something like finetuning a VLM to do image captioning (although maybe it could be generalized to the point where it can also do lists of Message objects and therefore can be used for whatever purpose).
- Support multiple captions per image? I can imagine people wanting to generate multiple captions for their images, and randomly selecting one at a time during training to prevent overfitting. It's kinda a caption augmentation but it's unique for each caption so it would have to be supported at the data level.
```
Contributor:

This should be possible to do easily with torchtune, but definitely not OOTB.

Inline comment on:

> In TorchTune, a simple version would look something like this:

```yaml
dataset:
```
Contributor:

Okay one big question: What direction are we trying to go in?

We landed torchdata support, which started a refactor of our datasets into dataset-specific utils rather than an entire builder that essentially just spits back an SFT dataset class. IMO this means less code for the user to worry about and makes hacking easier. In addition, it gives us all the benefits of torchdata.

If we believe torchdata is the right way to go (especially for these more data-intensive use cases), then should this be refactored towards that end?

cc @pbontrager @ebsmothers

Contributor:

The goal of this was to follow the pattern of our current SFT dataset solution so it'd be easier to move in parallel with the torchdata work. By staying close to SFT, it should be trivial to convert this to the torchdata solution once that's finalized.
