Skip to content

Commit

Permalink
Merge pull request #1182 from bghira/feature/sage-attention
Browse files Browse the repository at this point in the history
Add SageAttention for substantial training speed-up
  • Loading branch information
bghira authored Dec 2, 2024
2 parents e876e74 + 964e065 commit de9f988
Show file tree
Hide file tree
Showing 19 changed files with 327 additions and 38 deletions.
46 changes: 43 additions & 3 deletions OPTIONS.md
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,14 @@ Carefully answer the questions and use bf16 mixed precision training when prompt

Note that the first several steps of training will be slower than usual because of compilation occuring in the background.

### `--attention_mechanism`

Setting `sageattention` or `xformers` here will allow the use of other memory-efficient attention mechanisms for the forward pass during training and inference, potentially resulting in major performance improvement.

Using `sageattention` enables the use of [SageAttention](https://github.com/thu-ml/SageAttention) on NVIDIA CUDA equipment (sorry, AMD and Apple users).

In simple terms, this will quantise the attention calculations for lower compute and memory overhead, **massively** speeding up training while minimally impacting quality.

---

## 📰 Publishing
Expand Down Expand Up @@ -452,7 +460,8 @@ usage: train.py [-h] [--snr_gamma SNR_GAMMA] [--use_soft_min_snr]
[--lr_scheduler {linear,sine,cosine,cosine_with_restarts,polynomial,constant,constant_with_warmup}]
[--lr_warmup_steps LR_WARMUP_STEPS]
[--lr_num_cycles LR_NUM_CYCLES] [--lr_power LR_POWER]
[--use_ema] [--ema_device {cpu,accelerator}] [--ema_cpu_only]
[--use_ema] [--ema_device {cpu,accelerator}]
[--ema_validation {none,ema_only,comparison}] [--ema_cpu_only]
[--ema_foreach_disable]
[--ema_update_interval EMA_UPDATE_INTERVAL]
[--ema_decay EMA_DECAY] [--non_ema_revision NON_EMA_REVISION]
Expand All @@ -473,8 +482,9 @@ usage: train.py [-h] [--snr_gamma SNR_GAMMA] [--use_soft_min_snr]
[--model_card_safe_for_work] [--logging_dir LOGGING_DIR]
[--benchmark_base_model] [--disable_benchmark]
[--evaluation_type {clip,none}]
[--pretrained_evaluation_model_name_or_path pretrained_evaluation_model_name_or_path]
[--pretrained_evaluation_model_name_or_path PRETRAINED_EVALUATION_MODEL_NAME_OR_PATH]
[--validation_on_startup] [--validation_seed_source {gpu,cpu}]
[--validation_lycoris_strength VALIDATION_LYCORIS_STRENGTH]
[--validation_torch_compile]
[--validation_torch_compile_mode {max-autotune,reduce-overhead,default}]
[--validation_guidance_skip_layers VALIDATION_GUIDANCE_SKIP_LAYERS]
Expand Down Expand Up @@ -509,6 +519,7 @@ usage: train.py [-h] [--snr_gamma SNR_GAMMA] [--use_soft_min_snr]
[--text_encoder_2_precision {no_change,int8-quanto,int4-quanto,int2-quanto,int8-torchao,nf4-bnb,fp8-quanto,fp8uz-quanto}]
[--text_encoder_3_precision {no_change,int8-quanto,int4-quanto,int2-quanto,int8-torchao,nf4-bnb,fp8-quanto,fp8uz-quanto}]
[--local_rank LOCAL_RANK]
[--attention_mechanism {diffusers,xformers,sageattention,sageattention-int8-fp16-triton,sageattention-int8-fp16-cuda,sageattention-int8-fp8-cuda}]
[--enable_xformers_memory_efficient_attention]
[--set_grads_to_none] [--noise_offset NOISE_OFFSET]
[--noise_offset_probability NOISE_OFFSET_PROBABILITY]
Expand Down Expand Up @@ -1137,12 +1148,21 @@ options:
cosine_with_restarts scheduler.
--lr_power LR_POWER Power factor of the polynomial scheduler.
--use_ema Whether to use EMA (exponential moving average) model.
Works with LoRA, Lycoris, and full training.
--ema_device {cpu,accelerator}
The device to use for the EMA model. If set to
'accelerator', the EMA model will be placed on the
accelerator. This provides the fastest EMA update
times, but is not ultimately necessary for EMA to
function.
--ema_validation {none,ema_only,comparison}
When 'none' is set, no EMA validation will be done.
When using 'ema_only', the validations will rely
mostly on the EMA weights. When using 'comparison'
(default) mode, the validations will first run on the
checkpoint before also running for the EMA weights. In
comparison mode, the resulting images will be provided
side-by-side.
--ema_cpu_only When using EMA, the shadow model is moved to the
accelerator before we update its parameters. When
provided, this option will disable the moving of the
Expand Down Expand Up @@ -1248,7 +1268,7 @@ options:
function. The default is to use no evaluator, and
'clip' will use a CLIP model to evaluate the resulting
model's performance during validations.
--pretrained_evaluation_model_name_or_path pretrained_evaluation_model_name_or_path
--pretrained_evaluation_model_name_or_path PRETRAINED_EVALUATION_MODEL_NAME_OR_PATH
Optionally provide a custom model to use for ViT
evaluations. The default is currently clip-vit-large-
patch14-336, allowing for lower patch sizes (greater
Expand All @@ -1264,6 +1284,12 @@ options:
validation errors. If so, please set
SIMPLETUNER_LOG_LEVEL=DEBUG and submit debug.log to a
new Github issue report.
--validation_lycoris_strength VALIDATION_LYCORIS_STRENGTH
When inferencing for validations, the Lycoris model
will by default be run at its training strength, 1.0.
However, this value can be increased to a value of
around 1.3 or 1.5 to get a stronger effect from the
model.
--validation_torch_compile
Supply `--validation_torch_compile=true` to enable the
use of torch.compile() on the validation pipeline. For
Expand Down Expand Up @@ -1453,6 +1479,20 @@ options:
quantisation (Apple Silicon, NVIDIA, AMD).
--local_rank LOCAL_RANK
For distributed training: local_rank
--attention_mechanism {diffusers,xformers,sageattention,sageattention-int8-fp16-triton,sageattention-int8-fp16-cuda,sageattention-int8-fp8-cuda}
On NVIDIA CUDA devices, alternative flash attention
implementations are offered, with the default being
native pytorch SDPA. SageAttention has multiple
backends to select from. The recommended value,
'sageattention', guesses what would be the 'best'
option for SageAttention on your hardware (usually
this is the int8-fp16-cuda backend). However, manually
setting this value to int8-fp16-triton may provide
better averages for per-step training and inference
performance while the cuda backend may provide the
highest maximum speed (with also a lower minimum
speed). NOTE: SageAttention training quality has not
been validated.
--enable_xformers_memory_efficient_attention
Whether or not to use xformers.
--set_grads_to_none Save more memory by using setting grads to None
Expand Down
23 changes: 18 additions & 5 deletions configure.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,20 @@ def configure_env():
).lower()
== "y"
)
report_to_str = ""

env_contents["--attention_mechanism"] = "diffusers"
use_sageattention = (
prompt_user(
"Would you like to use SageAttention? This is an experimental option that can greatly speed up training. (y/[n])",
"n",
).lower()
== "y"
)
if use_sageattention:
env_contents["--attention_mechanism"] = "sageattention"

# properly disable wandb/tensorboard/comet_ml etc by default
report_to_str = "none"
if report_to_wandb or report_to_tensorboard:
tracker_project_name = prompt_user(
"Enter the name of your Weights & Biases project", f"{model_type}-training"
Expand All @@ -440,17 +453,17 @@ def configure_env():
f"simpletuner-{model_type}",
)
env_contents["--tracker_run_name"] = tracker_run_name
report_to_str = None
if report_to_wandb:
report_to_str = "wandb"
if report_to_tensorboard:
if report_to_wandb:
if report_to_str != "none":
# report to both WandB and Tensorboard if the user wanted.
report_to_str += ","
else:
# remove 'none' from the option
report_to_str = ""
report_to_str += "tensorboard"
if report_to_str:
env_contents["--report_to"] = report_to_str
env_contents["--report_to"] = report_to_str

print_config(env_contents, extra_args)

Expand Down
8 changes: 8 additions & 0 deletions documentation/LYCORIS.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,14 @@ Mandatory fields:

For more information on LyCORIS, please refer to the [documentation in the library](https://github.com/KohakuBlueleaf/LyCORIS/tree/main/docs).

## Potential problems

When using Lycoris on SDXL, it's noted that training the FeedForward modules may break the model and send loss into `NaN` (Not-a-Number) territory.

This seems to be potentially exacerbated when using SageAttention, making it all but guaranteed that the model will immediately fail.

The solution is to remove the `FeedForward` modules from the lycoris config and train only the `Attention` blocks.

## LyCORIS Inference Example

Here is a simple FLUX.1-dev inference script showing how to wrap your unet or transformer with create_lycoris_from_weights and then use it for inference.
Expand Down
10 changes: 10 additions & 0 deletions documentation/quickstart/FLUX.md
Original file line number Diff line number Diff line change
Expand Up @@ -414,9 +414,18 @@ Currently, the lowest VRAM utilisation (9090M) can be attained with:
- DeepSpeed: disabled / unconfigured
- PyTorch: 2.6 Nightly (Sept 29th build)
- Using `--quantize_via=cpu` to avoid outOfMemory error during startup on <=16G cards.
- With `--attention_mechanism=sageattention` to further reduce VRAM by 0.1GB and improve training speed.

Speed was approximately 1.4 iterations per second on a 4090.

### SageAttention

When using `--attention_mechanism=sageattention`, quantised operations are performed during SDPA calculations.

In simpler terms, this can very slightly improve VRAM usage while substantially speeding up training.

**Note**: This isn't compatible with _every_ configuration, but it's worth trying.

### NF4-quantised training

In simplest terms, NF4 is a 4bit-_ish_ representation of the model, which means training has serious stability concerns to address.
Expand All @@ -428,6 +437,7 @@ In early tests, the following holds true:
- NF4, AdamW8bit, and a higher batch size all help to overcome the stability issues, at the cost of more time spent training or VRAM used
- Upping the resolution from 512px to 1024px slows training down from, for example, 1.4 seconds per step to 3.5 seconds per step (batch size of 1, 4090)
- Anything that's difficult to train on int8 or bf16 becomes harder in NF4
- It's less compatible with options like SageAttention

NF4 does not work with torch.compile, so whatever you get for speed is what you get.

Expand Down
8 changes: 8 additions & 0 deletions documentation/quickstart/SD3.md
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,14 @@ These options have been known to keep SD3.5 in-tact for as long as possible:
- DeepSpeed: disabled / unconfigured
- PyTorch: 2.5

### SageAttention

When using `--attention_mechanism=sageattention`, quantised operations are performed during SDPA calculations.

In simpler terms, this can very slightly improve VRAM usage while substantially speeding up training.

**Note**: This isn't compatible with _every_ configuration, but it's worth trying.

### Masked loss

If you are training a subject or style and would like to mask one or the other, see the [masked loss training](/documentation/DREAMBOOTH.md#masked-loss) section of the Dreambooth guide.
Expand Down
8 changes: 8 additions & 0 deletions documentation/quickstart/SIGMA.md
Original file line number Diff line number Diff line change
Expand Up @@ -220,3 +220,11 @@ For more information, see the [dataloader](/documentation/DATALOADER.md) and [tu
### CLIP score tracking

If you wish to enable evaluations to score the model's performance, see [this document](/documentation/evaluation/CLIP_SCORES.md) for information on configuring and interpreting CLIP scores.

### SageAttention

When using `--attention_mechanism=sageattention`, quantised operations are performed during SDPA calculations.

In simpler terms, this can very slightly improve VRAM usage while substantially speeding up training.

**Note**: This isn't compatible with _every_ configuration, but it's worth trying.
44 changes: 39 additions & 5 deletions helpers/configuration/cmd_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -1131,7 +1131,7 @@ def get_argument_parser():
" When using 'ema_only', the validations will rely mostly on the EMA weights."
" When using 'comparison' (default) mode, the validations will first run on the checkpoint before also running for"
" the EMA weights. In comparison mode, the resulting images will be provided side-by-side."
)
),
)
parser.add_argument(
"--ema_cpu_only",
Expand Down Expand Up @@ -1708,6 +1708,28 @@ def get_argument_parser():
default=-1,
help="For distributed training: local_rank",
)
parser.add_argument(
"--attention_mechanism",
type=str,
choices=[
"diffusers",
"xformers",
"sageattention",
"sageattention-int8-fp16-triton",
"sageattention-int8-fp16-cuda",
"sageattention-int8-fp8-cuda",
],
default="diffusers",
help=(
"On NVIDIA CUDA devices, alternative flash attention implementations are offered, with the default being native pytorch SDPA."
" SageAttention has multiple backends to select from."
" The recommended value, 'sageattention', guesses what would be the 'best' option for SageAttention on your hardware"
" (usually this is the int8-fp16-cuda backend). However, manually setting this value to int8-fp16-triton"
" may provide better averages for per-step training and inference performance while the cuda backend"
" may provide the highest maximum speed (with also a lower minimum speed). NOTE: SageAttention training quality"
" has not been validated."
),
)
parser.add_argument(
"--enable_xformers_memory_efficient_attention",
action="store_true",
Expand Down Expand Up @@ -2418,7 +2440,7 @@ def parse_cmdline_args(input_args=None):
args.lycoris_config, os.R_OK
):
raise ValueError(
f"Could not find the JSON configuration file at {args.lycoris_config}"
f"Could not find the JSON configuration file at '{args.lycoris_config}'"
)
import json

Expand All @@ -2438,11 +2460,11 @@ def parse_cmdline_args(input_args=None):
elif "standard" == args.lora_type.lower():
if hasattr(args, "lora_init_type") and args.lora_init_type is not None:
if torch.backends.mps.is_available() and args.lora_init_type == "loftq":
logger.error(
error_log(
"Apple MPS cannot make use of LoftQ initialisation. Overriding to 'default'."
)
elif args.is_quantized and args.lora_init_type == "loftq":
logger.error(
error_log(
"LoftQ initialisation is not supported with quantised models. Overriding to 'default'."
)
else:
Expand All @@ -2451,7 +2473,7 @@ def parse_cmdline_args(input_args=None):
)
if args.use_dora:
if "quanto" in args.base_model_precision:
logger.error(
error_log(
"Quanto does not yet support DoRA training in PEFT. Disabling DoRA. 😴"
)
args.use_dora = False
Expand Down Expand Up @@ -2488,4 +2510,16 @@ def parse_cmdline_args(input_args=None):
logger.error(f"Could not load skip layers: {e}")
raise

if args.enable_xformers_memory_efficient_attention:
if args.attention_mechanism != "xformers":
warning_log(
"The option --enable_xformers_memory_efficient_attention is deprecated. Please use --attention_mechanism=xformers instead."
)
args.attention_mechanism = "xformers"

if args.attention_mechanism != "diffusers" and not torch.cuda.is_available():
warning_log(
"For non-CUDA systems, only Diffusers attention mechanism is officially supported."
)

return args
15 changes: 13 additions & 2 deletions helpers/kolors/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -1249,11 +1249,22 @@ def denoising_value_valid(dnv):
# unscale/denormalize the latents
latents = latents / self.vae.config.scaling_factor

if hasattr(torch.nn.functional, "scaled_dot_product_attention_sdpa"):
# we have SageAttention loaded. fallback to SDPA for decode.
torch.nn.functional.scaled_dot_product_attention = (
torch.nn.functional.scaled_dot_product_attention_sdpa
)

image = self.vae.decode(
latents.to(device=self.vae.device, dtype=self.vae.dtype),
return_dict=False,
latents.to(device=self.vae.device), return_dict=False
)[0]

if hasattr(torch.nn.functional, "scaled_dot_product_attention_sdpa"):
# reenable SageAttention for training.
torch.nn.functional.scaled_dot_product_attention = (
torch.nn.functional.scaled_dot_product_attention_sage
)

# cast back to fp16 if needed
if needs_upcasting:
self.vae.to(dtype=torch.float16)
Expand Down
19 changes: 14 additions & 5 deletions helpers/legacy/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -1160,18 +1160,27 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)

has_nsfw_concept = None
if not output_type == "latent":
if hasattr(torch.nn.functional, "scaled_dot_product_attention_sdpa"):
# we have SageAttention loaded. fallback to SDPA for decode.
torch.nn.functional.scaled_dot_product_attention = (
torch.nn.functional.scaled_dot_product_attention_sdpa
)

image = self.vae.decode(
latents.to(self.vae.dtype) / self.vae.config.scaling_factor,
latents.to(dtype=self.vae.dtype) / self.vae.config.scaling_factor,
return_dict=False,
generator=generator,
)[0]
image, has_nsfw_concept = self.run_safety_checker(
image, device, prompt_embeds.dtype
)

if hasattr(torch.nn.functional, "scaled_dot_product_attention_sdpa"):
# reenable SageAttention for training.
torch.nn.functional.scaled_dot_product_attention = (
torch.nn.functional.scaled_dot_product_attention_sage
)
else:
image = latents
has_nsfw_concept = None

if has_nsfw_concept is None:
do_denormalize = [True] * image.shape[0]
Expand Down
Loading

0 comments on commit de9f988

Please sign in to comment.