sana: update quickstart to mention bf16 weights
bghira committed Dec 13, 2024
1 parent e3791b9 commit 82aacec
Showing 1 changed file with 10 additions and 6 deletions.
documentation/quickstart/SANA.md
@@ -12,16 +12,19 @@ Sana is very lightweight and might not even need full gradient checkpointing ena

Sana is a strange architecture relative to other models that are trainable by SimpleTuner;

-- It requires FP16 training, unlike other models, this **will not work** with BF16
-- It will not be happy with model quantisation due to the need to run in FP16; most quantisation methods require the use of BF16
-- NF4 looks like it might work, but hasn't been fully tested
-- SageAttention does not work with Sana due to the shapes inside the model
+- Initially, unlike other models, Sana required fp16 training and would crash out with bf16
+- Model authors at NVIDIA were gracious enough to follow up with bf16-compatible weights for fine-tuning
+- Quantisation might be more sensitive on this model family due to the issues with bf16/fp16
+- SageAttention does not work with Sana (yet) due to its head_dim shape that is currently unsupported
- The loss value when training Sana is very high, and it might need a much lower learning rate than other models (e.g. `1e-5` or thereabouts)
- Training might hit NaN values, and it's not clear why this happens

Gradient checkpointing can free VRAM, but slows down training. A chart of test results from a 4090 with 5800X3D:

![image](https://github.com/user-attachments/assets/310bf099-a077-4378-acf4-f60b4b82fdc4)

SimpleTuner's Sana modeling code allows the specification of `--gradient_checkpointing_interval` to checkpoint every _n_ blocks and attain the results seen in the above chart.
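
For intuition, checkpointing only every _n_-th block recomputes fewer activations than checkpointing every block, so it runs faster while still saving some VRAM. A minimal, generic sketch of the idea in plain PyTorch (an illustration only, not SimpleTuner's actual Sana modeling code):

```python
# Generic sketch of "checkpoint every n blocks"; not SimpleTuner's implementation.
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

class BlockStack(nn.Module):
    def __init__(self, blocks: nn.ModuleList, checkpoint_interval: int = 0):
        super().__init__()
        self.blocks = blocks
        self.checkpoint_interval = checkpoint_interval  # 0 disables checkpointing

    def forward(self, x):
        for i, block in enumerate(self.blocks):
            if self.checkpoint_interval and i % self.checkpoint_interval == 0:
                # This block's activations are discarded and recomputed during
                # the backward pass, saving VRAM at the cost of extra compute.
                x = checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)
        return x

# An interval of 2 checkpoints half the blocks; 3 checkpoints one-third of them.
stack = BlockStack(nn.ModuleList(nn.Linear(64, 64) for _ in range(12)),
                   checkpoint_interval=2)
loss = stack(torch.randn(4, 64, requires_grad=True)).mean()
loss.backward()
```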

### Prerequisites

Make sure that you have python installed; SimpleTuner does well with 3.10 or 3.11. **Python 3.12 should not be used**.
@@ -117,8 +120,9 @@ There, you will possibly need to modify the following variables:
- `validation_num_inference_steps` - Use somewhere around 50 for the best quality, though you can accept less if you're happy with the results.
- `use_ema` - setting this to `true` will greatly help obtain a smoother result alongside your main trained checkpoint (see the sketch after this list).

-- `optimizer` - Since Sana requires fp16, some optimisers like `adamw_bf16` will not work with it. You can use `optimi-lion`, `optimi-stableadamw` or others you are familiar with instead.
-- `mixed_precision` - This gets overridden to `no` for you anyway, since we rely on fp16 training.
+- `optimizer` - You can use any optimiser you are familiar with, but we will use `optimi-adamw` for this example.
+- `mixed_precision` - It's recommended to set this to `bf16` for the most efficient training configuration, or `no` (which will consume more memory and be slower).
+- A value of `fp16` is not recommended here, but it may be required for certain Sana finetunes (and enabling it introduces other issues).
- `gradient_checkpointing` - Disabling this is fastest, but limits your batch sizes. Enabling it is required for the lowest VRAM usage.
- `gradient_checkpointing_interval` - If `gradient_checkpointing` feels like overkill on your GPU, you could set this to a value of 2 or higher to only checkpoint every _n_ blocks. A value of 2 would checkpoint half of the blocks, and 3 would be one-third.
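
As for `use_ema` mentioned above: it keeps an exponential moving average of the trained weights and evaluates/saves that smoothed copy alongside the main checkpoint. A minimal, generic sketch of the idea in plain PyTorch (an illustration only, not SimpleTuner's own EMA implementation; the `decay` value is just a typical example):

```python
# Generic EMA-of-weights sketch; not SimpleTuner's EMA implementation.
import copy
import torch
from torch import nn

class EMA:
    def __init__(self, model: nn.Module, decay: float = 0.999):
        self.decay = decay
        self.shadow = copy.deepcopy(model).eval()  # smoothed copy of the weights
        for p in self.shadow.parameters():
            p.requires_grad_(False)

    @torch.no_grad()
    def update(self, model: nn.Module):
        # shadow = decay * shadow + (1 - decay) * current weights
        for ema_p, p in zip(self.shadow.parameters(), model.parameters()):
            ema_p.lerp_(p, 1.0 - self.decay)

model = nn.Linear(8, 8)
ema = EMA(model, decay=0.999)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

for _ in range(10):
    loss = model(torch.randn(4, 8)).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    ema.update(model)  # one EMA update per optimiser step

# Validate / export with `ema.shadow` to get the smoothed result.
```

A higher decay gives a smoother but more slowly-updating EMA copy of the weights.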
