From 82aacec453244034285b40bb120f67e49fc0a694 Mon Sep 17 00:00:00 2001 From: bghira Date: Fri, 13 Dec 2024 07:01:53 -0600 Subject: [PATCH] sana: update quickstart to mention bf16 weights --- documentation/quickstart/SANA.md | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/documentation/quickstart/SANA.md b/documentation/quickstart/SANA.md index 964f4d76..2dd355b9 100644 --- a/documentation/quickstart/SANA.md +++ b/documentation/quickstart/SANA.md @@ -12,16 +12,19 @@ Sana is very lightweight and might not even need full gradient checkpointing ena Sana is a strange architecture relative to other models that are trainable by SimpleTuner; -- It requires FP16 training, unlike other models, this **will not work** with BF16 -- It will not be happy with model quantisation due to the need to run in FP16; most quantisation methods require the use of BF16 - - NF4 looks like it might work, but hasn't been fully tested -- SageAttention does not work with Sana due to the shapes inside the model +- Initially, unlike other models, Sana required fp16 training and would crash out with bf16 + - Model authors at NVIDIA were gracious enough to follow-up with bf16-compatible weights for fine-tuning +- Quantisation might be more sensitive on this model family due to the issues with bf16/fp16 +- SageAttention does not work with Sana (yet) due to its head_dim shape that is currently unsupported - The loss value when training Sana is very high, and it might need a much lower learning rate than other models (eg. `1e-5` or thereabouts) +- Training might hit NaN values, and it's not clear why this happens Gradient checkpointing can free VRAM, but slows down training. A chart of test results from a 4090 with 5800X3D: ![image](https://github.com/user-attachments/assets/310bf099-a077-4378-acf4-f60b4b82fdc4) +SimpleTuner's Sana modeling code allows the specification of `--gradient_checkpointing_interval` to checkpoint every _n_ blocks and attain the results seen in the above chart. + ### Prerequisites Make sure that you have python installed; SimpleTuner does well with 3.10 or 3.11. **Python 3.12 should not be used**. @@ -117,8 +120,9 @@ There, you will possibly need to modify the following variables: - `validation_num_inference_steps` - Use somewhere around 50 for the best quality, though you can accept less if you're happy with the results. - `use_ema` - setting this to `true` will greatly help obtain a more smoothed result alongside your main trained checkpoint. -- `optimizer` - Since Sana requires fp16, some optimisers like `adamw_bf16` will not work with it. You can use `optimi-lion`, `optimi-stableadamw` or others you are familiar with instead. -- `mixed_precision` - This gets overridden to `no` for you anyway, since we rely on fp16 training. +- `optimizer` - You can use any optimiser you are comfortable and familiar with, but we will use `optimi-adamw` for this example. +- `mixed_precision` - It's recommended to set this to `bf16` for the most efficient training configuration, or `no` (but will consume more memory and be slower). + - A value of `fp16` is not recommended here but may be required for certain Sana finetunes (and introduces other new issues to enable this) - `gradient_checkpointing` - Disabling this will go the fastest, but limits your batch sizes. It is required to enable this to get the lowest VRAM usage. - `gradient_checkpointing_interval` - If `gradient_checkpointing` feels like overkill on your GPU, you could set this to a value of 2 or higher to only checkpoint every _n_ blocks. A value of 2 would checkpoint half of the blocks, and 3 would be one-third.