Merge pull request #1213 from bghira/sana/bf16-weights-fixes
sana: use bf16 weights and update class names to latest PR
bghira authored Dec 13, 2024
2 parents 636cd7f + 82aacec commit 4eb7aee
Showing 15 changed files with 581 additions and 44 deletions.
2 changes: 1 addition & 1 deletion configure.py
@@ -36,7 +36,7 @@
"terminus": "ptx0/terminus-xl-velocity-v2",
"sd3": "stabilityai/stable-diffusion-3.5-large",
"legacy": "stabilityai/stable-diffusion-2-1-base",
"sana": "Efficient-Large-Model/Sana_1600M_1024px_diffusers",
"sana": "terminusresearch/sana-1.6b-1024px",
}

default_cfg = {
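The new default points SimpleTuner at a checkpoint published with bf16 weights. As a hedged illustration only (the exact contents of the `terminusresearch/sana-1.6b-1024px` repository and its compatibility with a recent diffusers release exposing `SanaPipeline` are assumptions), loading it directly with diffusers might look like:

```python
# Hedged sketch: pull the new default Sana checkpoint in bf16 via diffusers.
# Assumes the repo follows the standard diffusers pipeline layout.
import torch
from diffusers import SanaPipeline

pipe = SanaPipeline.from_pretrained(
    "terminusresearch/sana-1.6b-1024px",
    torch_dtype=torch.bfloat16,
)
pipe.to("cuda")

image = pipe(
    prompt="a photo of a corgi wearing a tiny wizard hat",
    num_inference_steps=28,
).images[0]
image.save("sana_test.png")
```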
16 changes: 10 additions & 6 deletions documentation/quickstart/SANA.md
@@ -12,16 +12,19 @@ Sana is very lightweight and might not even need full gradient checkpointing ena

Sana is a strange architecture relative to other models that are trainable by SimpleTuner:

- It requires FP16 training, unlike other models, this **will not work** with BF16
- It will not be happy with model quantisation due to the need to run in FP16; most quantisation methods require the use of BF16
- NF4 looks like it might work, but hasn't been fully tested
- SageAttention does not work with Sana due to the shapes inside the model
- Unlike other models, Sana initially required fp16 training and would crash when run with bf16
- Model authors at NVIDIA were gracious enough to follow up with bf16-compatible weights for fine-tuning
- Quantisation might be more sensitive on this model family due to the bf16/fp16 issues
- SageAttention does not work with Sana (yet) because its head_dim shape is currently unsupported
- The loss value when training Sana is very high, and it might need a much lower learning rate than other models (e.g. `1e-5` or thereabouts)
- Training might hit NaN values, and it's not clear why this happens

Gradient checkpointing can free VRAM but slows down training. Below is a chart of test results from a 4090 with a 5800X3D:

![image](https://github.com/user-attachments/assets/310bf099-a077-4378-acf4-f60b4b82fdc4)

SimpleTuner's Sana modeling code allows the specification of `--gradient_checkpointing_interval` to checkpoint every _n_ blocks and attain the results seen in the above chart.
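Conceptually, interval checkpointing wraps only every _n_-th transformer block in activation checkpointing instead of all of them. The snippet below is a generic sketch of that idea, not SimpleTuner's actual modeling code (`blocks` and `hidden_states` are placeholder names):

```python
import torch
from torch.utils.checkpoint import checkpoint

def run_transformer_blocks(blocks, hidden_states, interval: int = 2):
    # Checkpoint every `interval`-th block; the rest run normally and keep
    # their activations in memory, trading some VRAM savings back for speed.
    for i, block in enumerate(blocks):
        if interval > 0 and i % interval == 0:
            hidden_states = checkpoint(block, hidden_states, use_reentrant=False)
        else:
            hidden_states = block(hidden_states)
    return hidden_states
```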

### Prerequisites

Make sure that you have python installed; SimpleTuner does well with 3.10 or 3.11. **Python 3.12 should not be used**.
@@ -117,8 +120,9 @@ There, you will possibly need to modify the following variables:
- `validation_num_inference_steps` - Use somewhere around 50 for the best quality, though you can accept less if you're happy with the results.
- `use_ema` - setting this to `true` will greatly help obtain a smoother result alongside your main trained checkpoint.

- `optimizer` - Since Sana requires fp16, some optimisers like `adamw_bf16` will not work with it. You can use `optimi-lion`, `optimi-stableadamw` or others you are familiar with instead.
- `mixed_precision` - This gets overridden to `no` for you anyway, since we rely on fp16 training.
- `optimizer` - You can use any optimiser you are familiar with, but we will use `optimi-adamw` for this example.
- `mixed_precision` - Setting this to `bf16` is recommended for the most efficient training configuration; `no` also works, but consumes more memory and trains more slowly (a consolidated example follows this list).
  - A value of `fp16` is not recommended here; it may be required for certain Sana finetunes, but enabling it introduces other issues.
- `gradient_checkpointing` - Disabling this is fastest, but limits your batch sizes. Enabling it is required for the lowest VRAM usage.
- `gradient_checkpointing_interval` - If `gradient_checkpointing` feels like overkill on your GPU, you could set this to a value of 2 or higher to only checkpoint every _n_ blocks. A value of 2 would checkpoint half of the blocks, and 3 would be one-third.
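
Pulled together, the options above might land in a training configuration like the following sketch. The key names mirror the flags discussed here, but treat the exact file name, location, and format as assumptions about your local SimpleTuner setup:

```python
import json

# Hypothetical config covering the options discussed above; adjust the
# values to your own hardware and dataset before using anything like this.
config = {
    "model_family": "sana",
    "optimizer": "optimi-adamw",
    "mixed_precision": "bf16",
    "gradient_checkpointing": True,
    "gradient_checkpointing_interval": 2,
    "use_ema": True,
    "validation_num_inference_steps": 50,
    "learning_rate": 1e-5,
}

with open("config/config.json", "w") as f:
    json.dump(config, f, indent=2)
```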

2 changes: 1 addition & 1 deletion helpers/caching/vae.py
@@ -254,7 +254,7 @@ def discover_all_files(self):

def init_vae(self):
if StateTracker.get_args().model_family == "sana":
from diffusers import DCAE as AutoencoderClass
from diffusers import AutoencoderDC as AutoencoderClass
else:
from diffusers import AutoencoderKL as AutoencoderClass

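For reference, `AutoencoderDC` is the diffusers class for Sana's DC-AE autoencoder. A minimal sketch of the selection the branch above performs (not SimpleTuner's exact loading path; the `subfolder="vae"` layout and bf16 dtype are assumptions):

```python
import torch
from diffusers import AutoencoderDC, AutoencoderKL

def load_autoencoder(model_family: str, pretrained_model_name_or_path: str):
    # Sana ships a DC-AE ("deep compression") autoencoder instead of the
    # usual KL autoencoder used by SDXL/SD3-style models.
    autoencoder_cls = AutoencoderDC if model_family == "sana" else AutoencoderKL
    return autoencoder_cls.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="vae",
        torch_dtype=torch.bfloat16,
    )
```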
11 changes: 4 additions & 7 deletions helpers/configuration/cmd_args.py
@@ -1657,13 +1657,13 @@ def get_argument_parser():
"--mixed_precision",
type=str,
default="bf16",
choices=["bf16", "no"],
choices=["bf16", "fp16", "no"],
help=(
"SimpleTuner only supports bf16 training. Bf16 requires PyTorch >="
" 1.10. on an Nvidia Ampere or later GPU, and PyTorch 2.3 or newer for Apple Silicon."
" Default to the value of accelerate config of the current system or the"
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
" Sana requires a value of 'no'."
" fp16 is offered as an experimental option, but is not recommended as it is less-tested and you will likely encounter errors."
),
)
parser.add_argument(
@@ -2451,14 +2451,11 @@ def parse_cmdline_args(input_args=None, exit_on_error: bool = False):
args.weight_dtype = (
torch.bfloat16
if (
(args.mixed_precision == "bf16" or torch.backends.mps.is_available())
args.mixed_precision == "bf16"
or (args.base_model_default_dtype == "bf16" and args.is_quantized)
)
else torch.float32
else torch.float16 if args.mixed_precision == "fp16" else torch.float32
)
if args.model_family == "sana":
# god fucking help us, but bf16 does not work with Sana
args.weight_dtype = torch.float16
args.disable_accelerator = os.environ.get("SIMPLETUNER_DISABLE_ACCELERATOR", False)

if "lycoris" == args.lora_type.lower():
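The net effect of the change above is a three-way dtype choice. Distilled into a standalone helper (hypothetical, not part of SimpleTuner's API):

```python
import torch

def resolve_weight_dtype(mixed_precision: str,
                         base_model_default_dtype: str = "bf16",
                         is_quantized: bool = False) -> torch.dtype:
    # bf16 remains the default/preferred path; fp16 is the new experimental
    # option; anything else falls back to full fp32 weights.
    if mixed_precision == "bf16" or (base_model_default_dtype == "bf16" and is_quantized):
        return torch.bfloat16
    if mixed_precision == "fp16":
        return torch.float16
    return torch.float32
```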
Empty file added helpers/models/sana/__init__.py
Empty file.