
NVLabs Sana Quickstart

In this example, we'll be full-rank training the NVLabs Sana model.

Hardware requirements

Sana is very lightweight and might not even need full gradient checkpointing enabled on a 24G card, which means it trains very quickly!

  • the absolute minimum is about 12G VRAM, though this guide might not help you get there fully
  • a realistic minimum is a single 3090 or V100 GPU
  • ideally multiple 4090, A6000, L40S, or better

Sana is a strange architecture relative to other models that are trainable by SimpleTuner:

  • Initially, unlike other models, Sana required fp16 training and would crash out with bf16
    • Model authors at NVIDIA were gracious enough to follow up with bf16-compatible weights for fine-tuning
  • Quantisation might be more sensitive on this model family due to the issues with bf16/fp16
  • SageAttention does not work with Sana (yet) because its head_dim shape is currently unsupported
  • The loss value when training Sana is very high, and it might need a much lower learning rate than other models (e.g. 1e-5 or thereabouts)
  • Training might hit NaN values, and it's not clear why this happens

Gradient checkpointing can free VRAM, but slows down training. A chart of test results from a 4090 with 5800X3D:

[Chart: gradient checkpointing test results on a 4090 with a 5800X3D]

SimpleTuner's Sana modeling code accepts --gradient_checkpointing_interval to checkpoint every n blocks, which produces the results seen in the above chart.
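To make the "every n blocks" idea concrete, here is a minimal Python sketch of which blocks would be checkpointed for a given interval. It assumes the blocks whose index is a multiple of the interval are the ones recomputed during the backward pass; checkpointed_blocks is a hypothetical helper, not SimpleTuner's implementation, and the 20-block stack is chosen only for illustration.

```python
def checkpointed_blocks(num_blocks: int, interval: int = 1) -> list[int]:
    """Indices of blocks that would be checkpointed (recomputed in backward).

    Hypothetical model: every `interval`-th block, starting at block 0,
    is checkpointed; the others keep their activations in memory.
    """
    if interval < 1:
        raise ValueError("interval must be at least 1")
    return [i for i in range(num_blocks) if i % interval == 0]


# A 20-block stack, chosen purely for illustration.
for n in (1, 2, 3):
    blocks = checkpointed_blocks(20, n)
    print(f"interval={n}: {len(blocks)} of 20 blocks checkpointed")
```

With 20 blocks this checkpoints 20, 10 and 7 blocks respectively; 7 is only roughly one-third because 20 is not a multiple of 3.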

Prerequisites

Make sure that you have Python installed; SimpleTuner works well with 3.10 or 3.11. Python 3.12 should not be used.

You can check this by running:

python --version
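If you prefer to check from inside Python, the sketch below performs the same check against the versions recommended above. is_supported_python is a hypothetical helper written for this guide, not part of SimpleTuner.

```python
import sys

# Versions this guide recommends; 3.12 should not be used.
SUPPORTED_VERSIONS = {(3, 10), (3, 11)}


def is_supported_python(version_info=sys.version_info) -> bool:
    """Return True if the (major, minor) version is 3.10 or 3.11."""
    return tuple(version_info[:2]) in SUPPORTED_VERSIONS


if __name__ == "__main__":
    current = ".".join(map(str, sys.version_info[:3]))
    print(f"Python {current}:", "OK" if is_supported_python() else "not recommended")
```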

If you don't have python 3.11 installed on Ubuntu, you can try the following:

apt -y install python3.11 python3.11-venv

Container image dependencies

For Vast, RunPod, and TensorDock (among others), the following will work on a CUDA 12.2-12.4 image:

apt -y install nvidia-cuda-toolkit libgl1-mesa-glx

If libgl1-mesa-glx is not found, you might need to use libgl1-mesa-dri instead. Your mileage may vary.

Installation

Clone the SimpleTuner repository and set up the python venv:

git clone --branch=release https://github.com/bghira/SimpleTuner.git

cd SimpleTuner

# If python --version shows 3.11, you can use the 'python' command here instead.
python3.11 -m venv .venv

source .venv/bin/activate

pip install -U poetry pip

# Necessary on some systems to prevent it from deciding it knows better than us.
poetry config virtualenvs.create false

Note: We're currently installing the release branch here; the main branch may contain experimental features that might have better results or lower memory use.

Then install the dependencies. The exact command depends on your system; on Linux, run:

# Linux
poetry install

Setting up the environment

To run SimpleTuner, you will need to set up a configuration file, the dataset and model directories, and a dataloader configuration file.

Configuration file

An experimental script, configure.py, may allow you to entirely skip this section through an interactive step-by-step configuration. It contains some safety features that help avoid common pitfalls.

Note: This doesn't configure your dataloader. You will still have to do that manually, later.

To run it:

python configure.py

⚠️ For users located in countries where Hugging Face Hub is not readily accessible, you should add HF_ENDPOINT=https://hf-mirror.com to your ~/.bashrc or ~/.zshrc depending on which $SHELL your system uses.

If you prefer to manually configure:

Copy config/config.json.example to config/config.json:

cp config/config.json.example config/config.json

In that file, you may need to modify the following variables:

  • model_type - Set this to full.

  • model_family - Set this to sana.

  • pretrained_model_name_or_path - Set this to terminusresearch/sana-1.6b-1024px

  • output_dir - Set this to the directory where you want to store your checkpoints and validation images. It's recommended to use a full path here.

  • train_batch_size - for a 24G card with full gradient checkpointing, this can be as high as 6.

  • validation_resolution - This Sana checkpoint is a 1024px model, so you should set this to 1024x1024 or one of Sana's other supported resolutions.

    • Other resolutions may be specified using commas to separate them: 1024x1024,1280x768,2048x2048
  • validation_guidance - Use whatever you are used to selecting at inference time for Sana.

  • validation_num_inference_steps - Use around 50 for the best quality, though you can use fewer if you're happy with the results.

  • use_ema - Setting this to true keeps an exponential moving average (EMA) of the weights, which can help produce a smoother result alongside your main trained checkpoint.

  • optimizer - You can use any optimiser you are comfortable and familiar with, but we will use optimi-adamw for this example.

  • mixed_precision - bf16 is recommended for the most efficient training configuration; no is also possible, but will consume more memory and be slower.

    • fp16 is not recommended here, but it may be required for certain Sana finetunes, and enabling it introduces other issues
  • gradient_checkpointing - Disabling this is fastest, but limits your batch size. Enabling it is required for the lowest VRAM usage.

  • gradient_checkpointing_interval - If gradient_checkpointing feels like overkill on your GPU, you could set this to a value of 2 or higher to only checkpoint every n blocks. A value of 2 would checkpoint half of the blocks, and 3 would be one-third.
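If you'd rather script these edits than make them by hand, the sketch below applies the values listed above to a copied config/config.json. It assumes the keys in config.json carry the same -- prefix as the --user_prompt_library entry shown later in this guide; the default output_dir of output/sana is a hypothetical placeholder. validation_guidance and gradient_checkpointing_interval are left out because this guide doesn't prescribe values for them.

```python
import json
from pathlib import Path

# Values from the list above. The "--" key prefix is an assumption based on
# the "--user_prompt_library" entry used later in this guide.
SANA_OVERRIDES = {
    "--model_type": "full",
    "--model_family": "sana",
    "--pretrained_model_name_or_path": "terminusresearch/sana-1.6b-1024px",
    "--train_batch_size": 6,  # 24G card with full gradient checkpointing
    "--validation_resolution": "1024x1024",
    "--validation_num_inference_steps": 50,
    "--use_ema": True,
    "--optimizer": "optimi-adamw",
    "--mixed_precision": "bf16",
    "--gradient_checkpointing": True,
}


def merge_overrides(config: dict, overrides: dict, output_dir: str) -> dict:
    """Return a copy of config with overrides applied and an absolute output_dir."""
    merged = {**config, **overrides}
    merged["--output_dir"] = str(Path(output_dir).resolve())  # full path, as recommended
    return merged


def update_config_file(path: str = "config/config.json", output_dir: str = "output/sana") -> None:
    """Read config.json (copied from the example), apply overrides, write it back."""
    config_path = Path(path)
    config = json.loads(config_path.read_text())
    merged = merge_overrides(config, SANA_OVERRIDES, output_dir)
    config_path.write_text(json.dumps(merged, indent=2) + "\n")
```

Run update_config_file() after the cp step; any keys from the example file that aren't overridden are kept as-is.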

Multi-GPU users can reference this document for information on configuring the number of GPUs to use.

Validation prompts

Inside config/config.json is the "primary validation prompt", which is typically the main instance_prompt you are training on for your single subject or style. Additionally, a JSON file may be created that contains extra prompts to run through during validations.

The example config file config/user_prompt_library.json.example contains the following format:

{
  "nickname": "the prompt goes here",
  "another_nickname": "another prompt goes here"
}

The nicknames are used as the filenames for the validation images, so keep them short and compatible with your filesystem.

To point the trainer to this prompt library, add it to config.json as a new entry:

  "--user_prompt_library": "config/user_prompt_library.json",

A set of diverse prompts will help determine whether the model is collapsing as it trains. In this example, the word <token> should be replaced with your subject name (instance_prompt).

{
    "anime_<token>": "a breathtaking anime-style portrait of <token>, capturing her essence with vibrant colors and expressive features",
    "chef_<token>": "a high-quality, detailed photograph of <token> as a sous-chef, immersed in the art of culinary creation",
    "just_<token>": "a lifelike and intimate portrait of <token>, showcasing her unique personality and charm",
    "cinematic_<token>": "a cinematic, visually stunning photo of <token>, emphasizing her dramatic and captivating presence",
    "elegant_<token>": "an elegant and timeless portrait of <token>, exuding grace and sophistication",
    "adventurous_<token>": "a dynamic and adventurous photo of <token>, captured in an exciting, action-filled moment",
    "mysterious_<token>": "a mysterious and enigmatic portrait of <token>, shrouded in shadows and intrigue",
    "vintage_<token>": "a vintage-style portrait of <token>, evoking the charm and nostalgia of a bygone era",
    "artistic_<token>": "an artistic and abstract representation of <token>, blending creativity with visual storytelling",
    "futuristic_<token>": "a futuristic and cutting-edge portrayal of <token>, set against a backdrop of advanced technology",
    "woman": "a beautifully crafted portrait of a woman, highlighting her natural beauty and unique features",
    "man": "a powerful and striking portrait of a man, capturing his strength and character",
    "boy": "a playful and spirited portrait of a boy, capturing youthful energy and innocence",
    "girl": "a charming and vibrant portrait of a girl, emphasizing her bright personality and joy",
    "family": "a heartwarming and cohesive family portrait, showcasing the bonds and connections between loved ones"
}
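Substituting <token> by hand is error-prone, so here is a small sketch that fills in your subject name and sanitises the nicknames into filesystem-friendly filenames. build_prompt_library and write_prompt_library are hypothetical helpers, and the allowed character set (letters, digits, _, . and -) is an assumption, not a SimpleTuner rule.

```python
import json
import re
from pathlib import Path


def build_prompt_library(template: dict[str, str], subject: str) -> dict[str, str]:
    """Replace <token> with the subject name in both nicknames and prompts."""
    library: dict[str, str] = {}
    for nickname, prompt in template.items():
        name = nickname.replace("<token>", subject)
        # Nicknames become filenames: collapse anything outside [A-Za-z0-9_.-] to "_".
        name = re.sub(r"[^A-Za-z0-9_.-]+", "_", name)
        if name in library:
            raise ValueError(f"nickname collision after sanitising: {name}")
        library[name] = prompt.replace("<token>", subject)
    return library


def write_prompt_library(library: dict[str, str], path: str = "config/user_prompt_library.json") -> None:
    """Write the library as JSON, creating the parent directory if needed."""
    out = Path(path)
    out.parent.mkdir(parents=True, exist_ok=True)
    out.write_text(json.dumps(library, indent=4) + "\n")
```

For example, loading the template above with json.load and calling write_prompt_library(build_prompt_library(template, "jane doe")) would produce nicknames such as anime_jane_doe.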

ℹ️ Sana uses an unusual text encoder configuration, which means shorter prompts may produce very poor results.

CLIP score tracking

If you wish to enable evaluations to score the model's performance, see this document for information on configuring and interpreting CLIP scores.

Sana time schedule shifting

Flow-matching models such as Flux, Sana, and SD3 have a property called "shift" that allows us to shift the trained portion of the timestep schedule using a simple decimal value.

Auto-shift

A commonly recommended approach, following several recent works, is to enable resolution-dependent timestep shift with --flux_schedule_auto_shift, which uses higher shift values for larger images and lower shift values for smaller images. This results in stable but potentially mediocre training results.
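As a rough illustration of how a resolution-dependent shift can be computed, the sketch below interpolates a value mu linearly over the latent sequence length (which grows with image resolution) and returns exp(mu) as the shift. The constants are illustrative defaults modelled on the calculate_shift helper in diffusers' Flux pipeline; whether SimpleTuner's --flux_schedule_auto_shift uses exactly these values for Sana is an assumption.

```python
import math


def resolution_dependent_shift(
    image_seq_len: int,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
) -> float:
    """Interpolate mu linearly in sequence length and return shift = exp(mu).

    Longer latent sequences (larger images) get a larger shift.
    """
    slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    intercept = base_shift - slope * base_seq_len
    mu = slope * image_seq_len + intercept
    return math.exp(mu)


for seq_len in (256, 1024, 4096):
    print(seq_len, round(resolution_dependent_shift(seq_len), 3))
```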

Manual specification

Thanks to General Awareness from Discord for the following examples.

When using a --flux_schedule_shift value of 0.1 (a very low value), only the finer details of the image are affected:

[Image: example results with --flux_schedule_shift=0.1]

When using a --flux_schedule_shift value of 4.0 (a very high value), the large compositional features, and potentially the colour space, of the model are affected:

[Image: example results with --flux_schedule_shift=4.0]
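The effect of a manual shift value can be seen with the shift formula popularised by SD3, sketched below. It assumes the flow-matching convention where t = 1 is pure noise, and that --flux_schedule_shift applies this formula. Values above 1 push sampled timesteps toward noise (composition), while values below 1 push them toward the clean image (fine detail), which is consistent with the examples described above.

```python
def shift_timestep(t: float, shift: float) -> float:
    """Shift a flow-matching timestep t in [0, 1], where t = 1 is pure noise."""
    if not 0.0 <= t <= 1.0:
        raise ValueError("t must be in [0, 1]")
    if shift <= 0:
        raise ValueError("shift must be positive")
    # shift > 1 moves t toward 1 (noisier); shift < 1 moves it toward 0.
    return shift * t / (1 + (shift - 1) * t)


if __name__ == "__main__":
    for s in (0.1, 1.0, 4.0):
        shifted = [round(shift_timestep(t, s), 3) for t in (0.25, 0.5, 0.75)]
        print(f"shift={s}: {shifted}")
```

The endpoints 0 and 1 are fixed for any positive shift; only the distribution of timesteps in between changes.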

Dataset considerations

⚠️ Image quality for training is more important for Sana than for most other models, as it will absorb the artifacts in your images first, and then learn the concept/subject.

It's crucial to have a substantial dataset to train your model on. The dataset must meet a minimum size: it needs at least train_batch_size * gradient_accumulation_steps images, and more images than vae_batch_size. A dataset that is too small will not be usable.

ℹ️ With few enough images, you might see a message no images detected in dataset - increasing the repeats value will overcome this limitation.
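The size rules above can be written as a quick check. This sketch assumes each image contributes (1 + repeats) samples and treats the dataset as a single pool of images; dataset_is_usable and min_repeats are hypothetical helpers, not SimpleTuner functions.

```python
def effective_samples(num_images: int, repeats: int = 0) -> int:
    """Samples per epoch, assuming each image is seen (1 + repeats) times."""
    return num_images * (1 + repeats)


def dataset_is_usable(
    num_images: int,
    train_batch_size: int,
    gradient_accumulation_steps: int,
    vae_batch_size: int,
    repeats: int = 0,
) -> bool:
    """Apply the two minimum-size rules described above."""
    samples = effective_samples(num_images, repeats)
    return (
        samples >= train_batch_size * gradient_accumulation_steps
        and samples > vae_batch_size
    )


def min_repeats(
    num_images: int,
    train_batch_size: int,
    gradient_accumulation_steps: int,
    vae_batch_size: int,
) -> int:
    """Smallest repeats value that makes the dataset usable under these rules."""
    if num_images < 1:
        raise ValueError("the dataset needs at least one image")
    repeats = 0
    while not dataset_is_usable(
        num_images, train_batch_size, gradient_accumulation_steps, vae_batch_size, repeats
    ):
        repeats += 1
    return repeats
```

For instance, three images with a batch size of 6 would need at least one repeat under these assumptions.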

Depending on the dataset you have, you will need to set up your dataset directory and dataloader configuration file differently. In this example, we will be using pseudo-camera-10k as the dataset.

Create a --data_backend_config (config/multidatabackend.json) document containing this:

[
  {
    "id": "pseudo-camera-10k-sana",
    "type": "local",
    "crop": true,
    "crop_aspect": "square",
    "crop_style": "center",
    "resolution": 512,
    "minimum_image_size": 512,
    "maximum_image_size": 512,
    "target_downsample_size": 512,
    "resolution_type": "pixel_area",
    "cache_dir_vae": "cache/vae/sana/pseudo-camera-10k",
    "instance_data_dir": "datasets/pseudo-camera-10k",
    "disabled": false,
    "skip_file_discovery": "",
    "caption_strategy": "filename",
    "metadata_backend": "discovery",
    "repeats": 0,
    "is_regularisation_data": true
  },
  {
    "id": "dreambooth-subject",
    "type": "local",
    "crop": false,
    "resolution": 1024,
    "minimum_image_size": 1024,
    "maximum_image_size": 1024,
    "target_downsample_size": 1024,
    "resolution_type": "pixel_area",
    "cache_dir_vae": "cache/vae/sana/dreambooth-subject",
    "instance_data_dir": "datasets/dreambooth-subject",
    "caption_strategy": "instanceprompt",
    "instance_prompt": "the name of your subject goes here",
    "metadata_backend": "discovery",
    "repeats": 1000
  },
  {
    "id": "dreambooth-subject-512",
    "type": "local",
    "crop": false,
    "resolution": 512,
    "minimum_image_size": 512,
    "maximum_image_size": 512,
    "target_downsample_size": 512,
    "resolution_type": "pixel_area",
    "cache_dir_vae": "cache/vae/sana/dreambooth-subject-512",
    "instance_data_dir": "datasets/dreambooth-subject",
    "caption_strategy": "instanceprompt",
    "instance_prompt": "the name of your subject goes here",
    "metadata_backend": "discovery",
    "repeats": 1000
  },
  {
    "id": "text-embeds",
    "type": "local",
    "dataset_type": "text_embeds",
    "default": true,
    "cache_dir": "cache/text/sana",
    "disabled": false,
    "write_batch_size": 128
  }
]

ℹ️ Running 512px and 1024px datasets concurrently is supported, and could result in better convergence for Sana.
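Before starting a run, it can help to sanity-check the backend list. The sketch below flags duplicate ids, a missing or doubled default text-embed backend, image backends without an instance_data_dir, and image backends that share a cache_dir_vae (the example above gives the 512px and 1024px Dreambooth entries separate caches even though they read the same images). These rules are inferred from the example rather than taken from SimpleTuner's own validation, and entries without a dataset_type are assumed to be image datasets.

```python
import json
from collections import Counter
from pathlib import Path


def check_backends(backends: list[dict]) -> list[str]:
    """Return human-readable problems; an empty list means no issues were found."""
    problems = []

    # Every backend needs a unique id.
    ids = Counter(b.get("id") for b in backends)
    problems += [f"duplicate id: {i}" for i, n in ids.items() if n > 1]

    # Exactly one text-embed backend should be marked as the default.
    defaults = [
        b for b in backends
        if b.get("dataset_type") == "text_embeds" and b.get("default")
    ]
    if len(defaults) != 1:
        problems.append(f"expected one default text_embeds backend, found {len(defaults)}")

    # Entries without a dataset_type are assumed to be image datasets; disabled ones are skipped.
    images = [
        b for b in backends
        if b.get("dataset_type", "image") == "image" and not b.get("disabled")
    ]
    for b in images:
        if not b.get("instance_data_dir"):
            problems.append(f"{b.get('id')}: missing instance_data_dir")
    caches = Counter(b.get("cache_dir_vae") for b in images)
    problems += [f"shared cache_dir_vae: {c}" for c, n in caches.items() if n > 1]
    return problems


def check_file(path: str = "config/multidatabackend.json") -> list[str]:
    """Load the dataloader config from disk and check it."""
    return check_backends(json.loads(Path(path).read_text()))
```

Calling check_file() from the SimpleTuner directory returns an empty list when none of these checks fire.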

Then, create a datasets directory:

mkdir -p datasets
pushd datasets
    huggingface-cli download --repo-type=dataset bghira/pseudo-camera-10k --local-dir=pseudo-camera-10k
    mkdir dreambooth-subject
    # place your images into dreambooth-subject/ now
popd

This will download about 10k photograph samples to your datasets/pseudo-camera-10k directory, which will be automatically created for you.

Your Dreambooth images should go into the datasets/dreambooth-subject directory.

Log in to WandB and Hugging Face Hub

You'll want to log in to WandB and HF Hub before beginning training, especially if you're using --push_to_hub and --report_to=wandb.

If you're going to be pushing items to a Git LFS repository manually, you should also run git config --global credential.helper store.

Run the following commands:

wandb login

and

huggingface-cli login

Follow the instructions to log in to both services.

Executing the training run

From the SimpleTuner directory, run:

./train.sh

This will begin by caching the text embeds and VAE outputs to disk.

For more information, see the dataloader and tutorial documents.

Notes & troubleshooting tips

Lowest VRAM config

Currently, the lowest VRAM utilisation can be attained with:

  • OS: Ubuntu Linux 24
  • GPU: A single NVIDIA CUDA device (10G, 12G)
  • System memory: approximately 50G
  • Base model precision: bnb-nf4
  • Optimiser: Lion 8Bit Paged, bnb-lion8bit-paged
  • Resolution: 1024px
  • Batch size: 1, zero gradient accumulation steps
  • DeepSpeed: disabled / unconfigured
  • PyTorch: 2.5.1
  • Use --quantize_via=cpu to avoid an out-of-memory error during startup on cards with 16G or less.
  • Enable --gradient_checkpointing

NOTE: Pre-caching of VAE embeds and text encoder outputs may use more memory and still OOM. If so, text encoder quantisation can be enabled. VAE tiling may not work for Sana at this time.

Speed was approximately 1.4 iterations per second on a 4090.
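Iteration speed converts directly into wall-clock time. A minimal sketch, assuming one iteration per training step and ignoring caching and validation time; estimate_hours is a hypothetical helper:

```python
def estimate_hours(max_train_steps: int, iterations_per_second: float) -> float:
    """Rough training time, assuming one iteration per step and no overhead."""
    if iterations_per_second <= 0:
        raise ValueError("iterations_per_second must be positive")
    return max_train_steps / iterations_per_second / 3600


print(f"{estimate_hours(10_000, 1.4):.2f} hours")
```

At 1.4 iterations per second, 10,000 steps come to roughly two hours.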

Masked loss

If you are training a subject or style and would like to mask one or the other, see the masked loss training section of the Dreambooth guide.

Quantisation

Not tested thoroughly (yet).

Learning rates

LoRA (--lora_type=standard)

Not supported.

LoKr (--lora_type=lycoris)

  • Mild learning rates are better for LoKr (1e-4 with AdamW, 2e-5 with Lion)
  • Other algorithms need more exploration.
  • The effect of setting is_regularisation_data with Sana is unknown (not tested)

Image artifacts

Sana has an unknown response to image artifacts.

It's not currently known whether any common training artifacts will be produced or what the cause of these might be.

If any image quality issues arise, please open an issue on GitHub.

Aspect bucketing

This model has an unknown response to aspect bucketed data. Experimentation will be helpful.