sana: use bf16 weights and update class names to latest PR #1213

Merged 4 commits on Dec 13, 2024
2 changes: 1 addition & 1 deletion configure.py
@@ -36,7 +36,7 @@
"terminus": "ptx0/terminus-xl-velocity-v2",
"sd3": "stabilityai/stable-diffusion-3.5-large",
"legacy": "stabilityai/stable-diffusion-2-1-base",
"sana": "Efficient-Large-Model/Sana_1600M_1024px_diffusers",
"sana": "terminusresearch/sana-1.6b-1024px",
}

default_cfg = {
16 changes: 10 additions & 6 deletions documentation/quickstart/SANA.md
@@ -12,16 +12,19 @@ Sana is very lightweight and might not even need full gradient checkpointing ena

Sana is a strange architecture relative to other models that are trainable by SimpleTuner;

- It requires FP16 training, unlike other models, this **will not work** with BF16
- It will not be happy with model quantisation due to the need to run in FP16; most quantisation methods require the use of BF16
- NF4 looks like it might work, but hasn't been fully tested
- SageAttention does not work with Sana due to the shapes inside the model
- Initially, unlike other models, Sana required fp16 training and would crash out with bf16
- Model authors at NVIDIA were gracious enough to follow-up with bf16-compatible weights for fine-tuning
- Quantisation might be more sensitive on this model family due to the issues with bf16/fp16
- SageAttention does not work with Sana (yet) due to its head_dim shape that is currently unsupported
- The loss value when training Sana is very high, and it might need a much lower learning rate than other models (eg. `1e-5` or thereabouts)
- Training might hit NaN values, and it's not clear why this happens

Gradient checkpointing can free VRAM, but slows down training. A chart of test results from a 4090 with 5800X3D:

![image](https://github.com/user-attachments/assets/310bf099-a077-4378-acf4-f60b4b82fdc4)

SimpleTuner's Sana modeling code allows the specification of `--gradient_checkpointing_interval` to checkpoint every _n_ blocks and attain the results seen in the above chart.

### Prerequisites

Make sure that you have python installed; SimpleTuner does well with 3.10 or 3.11. **Python 3.12 should not be used**.
@@ -117,8 +120,9 @@ There, you will possibly need to modify the following variables:
- `validation_num_inference_steps` - Use somewhere around 50 for the best quality, though you can accept less if you're happy with the results.
- `use_ema` - setting this to `true` will greatly help obtain a more smoothed result alongside your main trained checkpoint.

- `optimizer` - Since Sana requires fp16, some optimisers like `adamw_bf16` will not work with it. You can use `optimi-lion`, `optimi-stableadamw` or others you are familiar with instead.
- `mixed_precision` - This gets overridden to `no` for you anyway, since we rely on fp16 training.
- `optimizer` - You can use any optimiser you are comfortable and familiar with, but we will use `optimi-adamw` for this example.
- `mixed_precision` - It's recommended to set this to `bf16` for the most efficient training configuration, or `no` (but will consume more memory and be slower).
- A value of `fp16` is not recommended here but may be required for certain Sana finetunes (and introduces other new issues to enable this)
- `gradient_checkpointing` - Disabling this will go the fastest, but limits your batch sizes. It is required to enable this to get the lowest VRAM usage.
- `gradient_checkpointing_interval` - If `gradient_checkpointing` feels like overkill on your GPU, you could set this to a value of 2 or higher to only checkpoint every _n_ blocks. A value of 2 would checkpoint half of the blocks, and 3 would be one-third.
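
As a rough illustration of the interval setting described above, the sketch below lists which blocks would be checkpointed for a given interval. The helper name `checkpointed_blocks`, the block count of 20, and the rule "checkpoint block *i* when *i* is a multiple of the interval" are assumptions made for this example; SimpleTuner's actual Sana modeling code may select blocks differently.

```python
def checkpointed_blocks(num_blocks: int, interval: int = 1) -> list[int]:
    """Return the indices of blocks that would be checkpointed.

    Hypothetical helper, not part of SimpleTuner.
    """
    if interval < 1:
        raise ValueError("interval must be >= 1")
    # Assumption: block i is checkpointed when i is a multiple of the interval,
    # so interval=1 checkpoints every block and interval=2 every other block.
    return [i for i in range(num_blocks) if i % interval == 0]


if __name__ == "__main__":
    total = 20  # arbitrary block count, chosen only for illustration
    for interval in (1, 2, 3):
        picked = checkpointed_blocks(total, interval)
        print(f"interval={interval}: {len(picked)} of {total} blocks checkpointed")
```

Fewer checkpointed blocks means less recomputation during the backward pass, at the cost of holding more activations in VRAM.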

2 changes: 1 addition & 1 deletion helpers/caching/vae.py
@@ -254,7 +254,7 @@ def discover_all_files(self):

def init_vae(self):
if StateTracker.get_args().model_family == "sana":
from diffusers import DCAE as AutoencoderClass
from diffusers import AutoencoderDC as AutoencoderClass
else:
from diffusers import AutoencoderKL as AutoencoderClass
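
The class selection in this hunk can be sketched in isolation as follows. The function names `autoencoder_class_name` and `load_autoencoder_class` are hypothetical and do not exist in SimpleTuner; the sketch also assumes an installed `diffusers` release that exports `AutoencoderDC`, which is the name this PR switches to.

```python
import importlib


def autoencoder_class_name(model_family: str) -> str:
    """Pick the autoencoder class name for a model family (hypothetical helper)."""
    # Sana uses AutoencoderDC; every other family falls back to AutoencoderKL,
    # mirroring the if/else in init_vae.
    return "AutoencoderDC" if model_family == "sana" else "AutoencoderKL"


def load_autoencoder_class(model_family: str):
    """Resolve the class lazily so diffusers is only imported when needed."""
    diffusers = importlib.import_module("diffusers")
    # Raises AttributeError if the installed diffusers still uses an older class name.
    return getattr(diffusers, autoencoder_class_name(model_family))
```

Resolving the class by name makes a renamed export fail with a clear `AttributeError` at load time rather than at module import.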

11 changes: 4 additions & 7 deletions helpers/configuration/cmd_args.py
@@ -1657,13 +1657,13 @@ def get_argument_parser():
"--mixed_precision",
type=str,
default="bf16",
choices=["bf16", "no"],
choices=["bf16", "fp16", "no"],
help=(
"SimpleTuner only supports bf16 training. Bf16 requires PyTorch >="
" 1.10. on an Nvidia Ampere or later GPU, and PyTorch 2.3 or newer for Apple Silicon."
" Default to the value of accelerate config of the current system or the"
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
" Sana requires a value of 'no'."
" fp16 is offered as an experimental option, but is not recommended as it is less-tested and you will likely encounter errors."
),
)
parser.add_argument(
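
To see how the widened `choices` list behaves, here is a minimal standalone parser that reproduces only the `--mixed_precision` argument. The function name `build_parser` and the shortened help string are assumptions for this sketch; the real parser in `cmd_args.py` defines many more arguments.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Build a parser with only the --mixed_precision flag (illustrative sketch)."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="bf16",
        # After this PR, fp16 is accepted alongside bf16 and no.
        choices=["bf16", "fp16", "no"],
        help="Precision mode: bf16 (default), fp16 (experimental) or no.",
    )
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args(["--mixed_precision", "fp16"])
    print(args.mixed_precision)
```

Any value outside the list makes `argparse` print a usage error and exit, so a typo such as `fp32` is rejected before training starts.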
@@ -2451,14 +2451,11 @@ def parse_cmdline_args(input_args=None, exit_on_error: bool = False):
args.weight_dtype = (
torch.bfloat16
if (
(args.mixed_precision == "bf16" or torch.backends.mps.is_available())
args.mixed_precision == "bf16"
or (args.base_model_default_dtype == "bf16" and args.is_quantized)
)
else torch.float32
else torch.float16 if args.mixed_precision == "fp16" else torch.float32
)
if args.model_family == "sana":
# god fucking help us, but bf16 does not work with Sana
args.weight_dtype = torch.float16
args.disable_accelerator = os.environ.get("SIMPLETUNER_DISABLE_ACCELERATOR", False)

if "lycoris" == args.lora_type.lower():
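
The dtype selection on the new side of this hunk can be restated as a small pure function. This is a sketch, not the project's code: `resolve_weight_dtype` is a hypothetical name, and plain strings stand in for `torch.bfloat16`, `torch.float16` and `torch.float32` so that the example runs without PyTorch.

```python
def resolve_weight_dtype(
    mixed_precision: str,
    base_model_default_dtype: str,
    is_quantized: bool,
) -> str:
    """Mirror the post-PR weight dtype logic using dtype names (hypothetical helper)."""
    # bf16 wins when requested explicitly, or when a quantized model's
    # base dtype is bf16.
    if mixed_precision == "bf16" or (
        base_model_default_dtype == "bf16" and is_quantized
    ):
        return "bfloat16"
    # fp16 is now honoured for any model family instead of being forced for Sana.
    if mixed_precision == "fp16":
        return "float16"
    return "float32"


if __name__ == "__main__":
    for mode in ("bf16", "fp16", "no"):
        print(mode, "->", resolve_weight_dtype(mode, "fp32", False))
```

Two behaviours from the old side are gone in this version: Apple MPS availability no longer forces bf16, and the Sana-specific override to fp16 is removed.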
Empty file added helpers/models/sana/__init__.py
Empty file.