From 57ec923caf48aef4f6ece059e59fa386c54d6835 Mon Sep 17 00:00:00 2001 From: bghira Date: Sat, 30 Nov 2024 09:46:36 -0600 Subject: [PATCH 01/18] flux: use sage attention if available --- helpers/models/flux/attention.py | 83 ++++++++++++++++++++++++++++++ helpers/models/flux/transformer.py | 15 +++++- 2 files changed, 97 insertions(+), 1 deletion(-) diff --git a/helpers/models/flux/attention.py b/helpers/models/flux/attention.py index 9009c858..edf2d05b 100644 --- a/helpers/models/flux/attention.py +++ b/helpers/models/flux/attention.py @@ -8,6 +8,12 @@ from flash_attn_interface import flash_attn_func except: pass +try: + from sageattention import sageattn + + F.scaled_dot_product_attention = sageattn +except: + pass def fa3_sdpa( @@ -98,6 +104,83 @@ def __call__( return hidden_states +class FluxSingleSageAttnProcessor3_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." + ) + + def __call__( + self, + attn, + hidden_states: Tensor, + encoder_hidden_states: Tensor = None, + attention_mask: FloatTensor = None, + image_rotary_emb: Tensor = None, + ) -> Tensor: + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view( + batch_size, channel, height * width + ).transpose(1, 2) + + batch_size, _, _ = ( + hidden_states.shape + if encoder_hidden_states is None + else encoder_hidden_states.shape + ) + + query = attn.to_q(hidden_states) + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # Apply RoPE if needed + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb) + key = apply_rotary_emb(key, image_rotary_emb) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + # hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) + hidden_states = fa3_sdpa(query, key, value) + hidden_states = rearrange(hidden_states, "B H L D -> B L (H D)") + + hidden_states = hidden_states.transpose(1, 2).reshape( + batch_size, -1, attn.heads * head_dim + ) + hidden_states = hidden_states.to(query.dtype) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape( + batch_size, channel, height, width + ) + + return hidden_states + + class FluxAttnProcessor3_0: """Attention processor used typically in processing the SD3-like self-attention projections.""" diff --git a/helpers/models/flux/transformer.py b/helpers/models/flux/transformer.py index 7a9c80f2..806bcc27 100644 --- a/helpers/models/flux/transformer.py +++ b/helpers/models/flux/transformer.py @@ -47,6 +47,15 @@ except: pass +is_sage_attn_available = False +try: + from sageattention import sageattn + + is_sage_attn_available = True + +except: 
+ pass + from helpers.models.flux.attention import ( FluxSingleAttnProcessor3_0, FluxAttnProcessor3_0, @@ -217,7 +226,11 @@ def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0): ) primary_device = torch.cuda.get_device_properties(rank) if primary_device.major == 9 and primary_device.minor == 0: - if is_flash_attn_available: + if is_sage_attn_available: + if rank == 0: + print("Using SageAttention for H100 GPU (Single block)") + processor = FluxSingleSageAttnProcessor() + elif is_flash_attn_available: if rank == 0: print("Using FlashAttention3_0 for H100 GPU (Single block)") processor = FluxSingleAttnProcessor3_0() From 68c5a5b6d7e1a49a4f6395c3b0406c882a8ddb14 Mon Sep 17 00:00:00 2001 From: bghira Date: Sat, 30 Nov 2024 09:48:27 -0600 Subject: [PATCH 02/18] update deepspeed call to ensure compliance with new transformers version --- helpers/training/ema.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/helpers/training/ema.py b/helpers/training/ema.py index d2153796..7ec7b991 100644 --- a/helpers/training/ema.py +++ b/helpers/training/ema.py @@ -267,7 +267,7 @@ def step(self, parameters: Iterable[torch.nn.Parameter], global_step: int = None context_manager = contextlib.nullcontext if ( is_transformers_available() - and transformers.deepspeed.is_deepspeed_zero3_enabled() + and transformers.integrations.deepspeed.is_deepspeed_zero3_enabled() ): import deepspeed @@ -309,7 +309,7 @@ def step(self, parameters: Iterable[torch.nn.Parameter], global_step: int = None for s_param, param in zip(self.shadow_params, parameters): if ( is_transformers_available() - and transformers.deepspeed.is_deepspeed_zero3_enabled() + and transformers.integrations.deepspeed.is_deepspeed_zero3_enabled() ): context_manager = deepspeed.zero.GatheredParameters( param, modifier_rank=None @@ -374,7 +374,9 @@ def cuda(self, device=None): def cpu(self): return self.to(device="cpu") - def state_dict(self, destination=None, prefix="", keep_vars=False, exclude_params: bool = False): + def state_dict( + self, destination=None, prefix="", keep_vars=False, exclude_params: bool = False + ): r""" Returns a dictionary containing a whole state of the EMA model. """ From c8fe4e5e8f6e40775136eaa589650bdbb9a96e22 Mon Sep 17 00:00:00 2001 From: bghira Date: Mon, 2 Dec 2024 10:36:01 -0600 Subject: [PATCH 03/18] configurator should offer option to enable SageAttention for user --- configure.py | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/configure.py b/configure.py index f14e13b9..b5c26d28 100644 --- a/configure.py +++ b/configure.py @@ -429,7 +429,20 @@ def configure_env(): ).lower() == "y" ) - report_to_str = "" + + env_contents["--attention_mechanism"] = "diffusers" + use_sageattention = ( + prompt_user( + "Would you like to use SageAttention? This is an experimental option that can greatly speed up training. 
(y/[n])", + "n", + ).lower() + == "y" + ) + if use_sageattention: + env_contents["--attention_mechanism"] = "sageattention" + + # properly disable wandb/tensorboard/comet_ml etc by default + report_to_str = "none" if report_to_wandb or report_to_tensorboard: tracker_project_name = prompt_user( "Enter the name of your Weights & Biases project", f"{model_type}-training" @@ -440,17 +453,17 @@ def configure_env(): f"simpletuner-{model_type}", ) env_contents["--tracker_run_name"] = tracker_run_name - report_to_str = None if report_to_wandb: report_to_str = "wandb" if report_to_tensorboard: - if report_to_wandb: + if report_to_str != "none": + # report to both WandB and Tensorboard if the user wanted. report_to_str += "," else: + # remove 'none' from the option report_to_str = "" report_to_str += "tensorboard" - if report_to_str: - env_contents["--report_to"] = report_to_str + env_contents["--report_to"] = report_to_str print_config(env_contents, extra_args) From ab02d128f3a0bad4bea1cba559bf23c9506c30d4 Mon Sep 17 00:00:00 2001 From: bghira Date: Mon, 2 Dec 2024 10:36:26 -0600 Subject: [PATCH 04/18] add --attention_mechanism option for sageattention --- helpers/configuration/cmd_args.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/helpers/configuration/cmd_args.py b/helpers/configuration/cmd_args.py index 7ea0dc61..1308c50d 100644 --- a/helpers/configuration/cmd_args.py +++ b/helpers/configuration/cmd_args.py @@ -1131,7 +1131,7 @@ def get_argument_parser(): " When using 'ema_only', the validations will rely mostly on the EMA weights." " When using 'comparison' (default) mode, the validations will first run on the checkpoint before also running for" " the EMA weights. In comparison mode, the resulting images will be provided side-by-side." - ) + ), ) parser.add_argument( "--ema_cpu_only", @@ -1708,6 +1708,15 @@ def get_argument_parser(): default=-1, help="For distributed training: local_rank", ) + parser.add_argument( + "--attention_mechanism", + type=str, + choices=["diffusers", "xformers", "sageattention"], + default="diffusers", + help=( + "On NVIDIA CUDA devices, we can use Xformers or SageAttention as an alternative to Pytorch SDPA (Diffusers)." 
+ ), + ) parser.add_argument( "--enable_xformers_memory_efficient_attention", action="store_true", From 2a5e028ed37361df570034efd509f96cb04e0d10 Mon Sep 17 00:00:00 2001 From: bghira Date: Mon, 2 Dec 2024 10:37:02 -0600 Subject: [PATCH 05/18] flux: remove model-specific sageattention code --- helpers/models/flux/attention.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/helpers/models/flux/attention.py b/helpers/models/flux/attention.py index edf2d05b..2e880d6e 100644 --- a/helpers/models/flux/attention.py +++ b/helpers/models/flux/attention.py @@ -8,12 +8,6 @@ from flash_attn_interface import flash_attn_func except: pass -try: - from sageattention import sageattn - - F.scaled_dot_product_attention = sageattn -except: - pass def fa3_sdpa( From e6b1919110bf3df2595fb4fff0829461c37c2b38 Mon Sep 17 00:00:00 2001 From: bghira Date: Mon, 2 Dec 2024 10:37:34 -0600 Subject: [PATCH 06/18] sageattention cannot be enabled concurrently to xformers --- helpers/training/default_settings/safety_check.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/helpers/training/default_settings/safety_check.py b/helpers/training/default_settings/safety_check.py index f586972c..01444a32 100644 --- a/helpers/training/default_settings/safety_check.py +++ b/helpers/training/default_settings/safety_check.py @@ -116,4 +116,13 @@ def safety_check(args, accelerator): logger.error( f"--flux_schedule_auto_shift cannot be combined with --flux_schedule_shift. Please set --flux_schedule_shift to 0 if you want to train with --flux_schedule_auto_shift." ) - sys.exit(1) \ No newline at end of file + sys.exit(1) + + if ( + args.enable_xformers_memory_efficient_attention + and args.attention_mechanism == "sageattention" + ): + logger.error( + f"--enable_xformers_memory_efficient_attention is only compatible with --attention_mechanism=diffusers. Please set --attention_mechanism=diffusers to enable this feature or disable xformers to use alternative attention mechanisms." + ) + sys.exit(1) From 72e7d9ec29027c476a08d6e6ea116ea9bc83d842 Mon Sep 17 00:00:00 2001 From: bghira Date: Mon, 2 Dec 2024 10:38:29 -0600 Subject: [PATCH 07/18] sageattention should overwrite sdpa at startup if enabled --- helpers/training/trainer.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/helpers/training/trainer.py b/helpers/training/trainer.py index a510027d..2aa15894 100644 --- a/helpers/training/trainer.py +++ b/helpers/training/trainer.py @@ -1636,7 +1636,22 @@ def move_models(self, destination: str = "accelerator"): target_device, dtype=self.config.weight_dtype ) ) - if ( + + if self.config.attention_mechanism == "sageattention": + # we'll try and load SageAttention and overload pytorch's sdpa function. + try: + from sageattention import sageattn + + torch.nn.functional.scaled_dot_product_attention = sageattn + logger.warning( + "Using SageAttention for flash attention mechanism. This is an experimental option, and you may receive unexpected or poor results. To disable SageAttention, remove or set --attention_mechanism to a different value." + ) + except ImportError: + logger.error( + "Could not import SageAttention. 
Please install it to use this --attention_mechanism=sageattention" + ) + sys.exit(1) + elif ( self.config.enable_xformers_memory_efficient_attention and self.config.model_family not in [ From bbf20d0eff804ce5ebff6cadece736782204637a Mon Sep 17 00:00:00 2001 From: bghira Date: Mon, 2 Dec 2024 10:39:59 -0600 Subject: [PATCH 08/18] flux: remove more model-specific sageattn code --- helpers/models/flux/attention.py | 77 ------------------------------ helpers/models/flux/transformer.py | 15 +----- 2 files changed, 1 insertion(+), 91 deletions(-) diff --git a/helpers/models/flux/attention.py b/helpers/models/flux/attention.py index 2e880d6e..9009c858 100644 --- a/helpers/models/flux/attention.py +++ b/helpers/models/flux/attention.py @@ -98,83 +98,6 @@ def __call__( return hidden_states -class FluxSingleSageAttnProcessor3_0: - r""" - Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). - """ - - def __init__(self): - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError( - "AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." - ) - - def __call__( - self, - attn, - hidden_states: Tensor, - encoder_hidden_states: Tensor = None, - attention_mask: FloatTensor = None, - image_rotary_emb: Tensor = None, - ) -> Tensor: - input_ndim = hidden_states.ndim - - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view( - batch_size, channel, height * width - ).transpose(1, 2) - - batch_size, _, _ = ( - hidden_states.shape - if encoder_hidden_states is None - else encoder_hidden_states.shape - ) - - query = attn.to_q(hidden_states) - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - if attn.norm_q is not None: - query = attn.norm_q(query) - if attn.norm_k is not None: - key = attn.norm_k(key) - - # Apply RoPE if needed - if image_rotary_emb is not None: - query = apply_rotary_emb(query, image_rotary_emb) - key = apply_rotary_emb(key, image_rotary_emb) - - # the output of sdp = (batch, num_heads, seq_len, head_dim) - # TODO: add support for attn.scale when we move to Torch 2.1 - # hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) - hidden_states = fa3_sdpa(query, key, value) - hidden_states = rearrange(hidden_states, "B H L D -> B L (H D)") - - hidden_states = hidden_states.transpose(1, 2).reshape( - batch_size, -1, attn.heads * head_dim - ) - hidden_states = hidden_states.to(query.dtype) - - if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape( - batch_size, channel, height, width - ) - - return hidden_states - - class FluxAttnProcessor3_0: """Attention processor used typically in processing the SD3-like self-attention projections.""" diff --git a/helpers/models/flux/transformer.py b/helpers/models/flux/transformer.py index 806bcc27..7a9c80f2 100644 --- a/helpers/models/flux/transformer.py +++ b/helpers/models/flux/transformer.py @@ -47,15 +47,6 @@ except: pass -is_sage_attn_available = False -try: - from sageattention import sageattn - - is_sage_attn_available = True - -except: - 
pass - from helpers.models.flux.attention import ( FluxSingleAttnProcessor3_0, FluxAttnProcessor3_0, @@ -226,11 +217,7 @@ def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0): ) primary_device = torch.cuda.get_device_properties(rank) if primary_device.major == 9 and primary_device.minor == 0: - if is_sage_attn_available: - if rank == 0: - print("Using SageAttention for H100 GPU (Single block)") - processor = FluxSingleSageAttnProcessor() - elif is_flash_attn_available: + if is_flash_attn_available: if rank == 0: print("Using FlashAttention3_0 for H100 GPU (Single block)") processor = FluxSingleAttnProcessor3_0() From a86c256d4a97819e2475a41d1f1d1f9593e17a4f Mon Sep 17 00:00:00 2001 From: bghira Date: Mon, 2 Dec 2024 10:47:45 -0600 Subject: [PATCH 09/18] deprecate --enable_xformers_memory_efficient_attention in favour of --attention_mechanism=xformers --- helpers/configuration/cmd_args.py | 18 +++++++++++++++--- helpers/publishing/metadata.py | 4 +++- helpers/training/trainer.py | 7 +++++-- 3 files changed, 23 insertions(+), 6 deletions(-) diff --git a/helpers/configuration/cmd_args.py b/helpers/configuration/cmd_args.py index 1308c50d..c15cdc01 100644 --- a/helpers/configuration/cmd_args.py +++ b/helpers/configuration/cmd_args.py @@ -2447,11 +2447,11 @@ def parse_cmdline_args(input_args=None): elif "standard" == args.lora_type.lower(): if hasattr(args, "lora_init_type") and args.lora_init_type is not None: if torch.backends.mps.is_available() and args.lora_init_type == "loftq": - logger.error( + error_log( "Apple MPS cannot make use of LoftQ initialisation. Overriding to 'default'." ) elif args.is_quantized and args.lora_init_type == "loftq": - logger.error( + error_log( "LoftQ initialisation is not supported with quantised models. Overriding to 'default'." ) else: @@ -2460,7 +2460,7 @@ def parse_cmdline_args(input_args=None): ) if args.use_dora: if "quanto" in args.base_model_precision: - logger.error( + error_log( "Quanto does not yet support DoRA training in PEFT. Disabling DoRA. 😴" ) args.use_dora = False @@ -2497,4 +2497,16 @@ def parse_cmdline_args(input_args=None): logger.error(f"Could not load skip layers: {e}") raise + if args.enable_xformers_memory_efficient_attention: + if args.attention_mechanism != "xformers": + warning_log( + "The option --enable_xformers_memory_efficient_attention is deprecated. Please use --attention_mechanism=xformers instead." + ) + args.attention_mechanism = "xformers" + + if args.attention_mechanism != "diffusers" and not torch.cuda.is_available(): + warning_log( + "For non-CUDA systems, only Diffusers attention mechanism is officially supported." 
+ ) + return args diff --git a/helpers/publishing/metadata.py b/helpers/publishing/metadata.py index 92bd3568..c761f27e 100644 --- a/helpers/publishing/metadata.py +++ b/helpers/publishing/metadata.py @@ -119,6 +119,7 @@ def ema_info(args): return ema_information return "" + def lycoris_download_info(): """output a function to download the adapter""" output_fn = """ @@ -556,7 +557,8 @@ def save_model_card( - Optimizer: {StateTracker.get_args().optimizer}{optimizer_config if optimizer_config is not None else ''} - Trainable parameter precision: {'Pure BF16' if torch.backends.mps.is_available() or StateTracker.get_args().mixed_precision == "bf16" else 'FP32'} - Caption dropout probability: {StateTracker.get_args().caption_dropout_probability * 100}% -{'- Xformers: Enabled' if StateTracker.get_args().enable_xformers_memory_efficient_attention else ''} +{'- Xformers: Enabled' if StateTracker.get_args().attention_mechanism == 'xformers' else ''} +{'- SageAttention: Enabled' if StateTracker.get_args().attention_mechanism == 'sageattention' else ''} {lora_info(args=StateTracker.get_args())} ## Datasets diff --git a/helpers/training/trainer.py b/helpers/training/trainer.py index 2aa15894..637006d1 100644 --- a/helpers/training/trainer.py +++ b/helpers/training/trainer.py @@ -1652,7 +1652,7 @@ def move_models(self, destination: str = "accelerator"): ) sys.exit(1) elif ( - self.config.enable_xformers_memory_efficient_attention + self.config.attention_mechanism == "xformers" and self.config.model_family not in [ "sd3", @@ -1676,11 +1676,14 @@ def move_models(self, destination: str = "accelerator"): raise ValueError( "xformers is not available. Make sure it is installed correctly" ) - elif self.config.enable_xformers_memory_efficient_attention: + elif self.config.attention_mechanism == "xformers": logger.warning( "xformers is not enabled, as it is incompatible with this model type." + " Falling back to diffusers attention mechanism (Pytorch SDPA)." + " Alternatively, provide --attention_mechanism=sageattention for a more efficient option on CUDA systems." ) self.config.enable_xformers_memory_efficient_attention = False + self.config.attention_mechanism = "diffusers" if self.config.controlnet: self.controlnet.train() From 73c6f1b89d487f73dd72be08e5e60e6503943dec Mon Sep 17 00:00:00 2001 From: bghira Date: Mon, 2 Dec 2024 11:44:44 -0600 Subject: [PATCH 10/18] hackish awful workaround for VAE decode in SageAttention --- helpers/models/flux/pipeline.py | 15 +++++++++++-- helpers/models/omnigen/pipeline.py | 17 ++++++++++++--- helpers/models/pixart/pipeline.py | 17 +++++++++++++-- helpers/models/sd3/pipeline.py | 34 ++++++++++++++++++++++++++++-- helpers/models/sdxl/pipeline.py | 12 +++++++++++ 5 files changed, 86 insertions(+), 9 deletions(-) diff --git a/helpers/models/flux/pipeline.py b/helpers/models/flux/pipeline.py index ba7eaa44..1d152def 100644 --- a/helpers/models/flux/pipeline.py +++ b/helpers/models/flux/pipeline.py @@ -906,11 +906,22 @@ def __call__( latents = ( latents / self.vae.config.scaling_factor ) + self.vae.config.shift_factor + if hasattr(torch.nn.functional, "scaled_dot_product_attention_sdpa"): + # we have SageAttention loaded. fallback to SDPA for decode. 
+ torch.nn.functional.scaled_dot_product_attention = ( + torch.nn.functional.scaled_dot_product_attention_sdpa + ) image = self.vae.decode( - latents.to(device=self.vae.device, dtype=self.vae.dtype), - return_dict=False, + latents.to(dtype=self.vae.dtype), return_dict=False )[0] + + if hasattr(torch.nn.functional, "scaled_dot_product_attention_sdpa"): + # reenable SageAttention for training. + torch.nn.functional.scaled_dot_product_attention = ( + torch.nn.functional.scaled_dot_product_attention_sage + ) + image = self.image_processor.postprocess(image, output_type=output_type) # Offload all models diff --git a/helpers/models/omnigen/pipeline.py b/helpers/models/omnigen/pipeline.py index fadbaf05..5693967f 100644 --- a/helpers/models/omnigen/pipeline.py +++ b/helpers/models/omnigen/pipeline.py @@ -345,9 +345,20 @@ def __call__( ) else: samples = samples / self.vae.config.scaling_factor - samples = self.vae.decode( - samples.to(dtype=self.vae.dtype, device=self.vae.device) - ).sample + + if hasattr(torch.nn.functional, "scaled_dot_product_attention_sdpa"): + # we have SageAttention loaded. fallback to SDPA for decode. + torch.nn.functional.scaled_dot_product_attention = ( + torch.nn.functional.scaled_dot_product_attention_sdpa + ) + + image = self.vae.decode(latents.to(dtype=self.vae.dtype), return_dict=False)[0] + + if hasattr(torch.nn.functional, "scaled_dot_product_attention_sdpa"): + # reenable SageAttention for training. + torch.nn.functional.scaled_dot_product_attention = ( + torch.nn.functional.scaled_dot_product_attention_sage + ) if self.model_cpu_offload: self.vae.to("cpu") diff --git a/helpers/models/pixart/pipeline.py b/helpers/models/pixart/pipeline.py index 6efb54cd..244df2e2 100644 --- a/helpers/models/pixart/pipeline.py +++ b/helpers/models/pixart/pipeline.py @@ -1231,11 +1231,24 @@ def denoising_value_valid(dnv): callback(step_idx, t, latents) if not output_type == "latent": + if hasattr(torch.nn.functional, "scaled_dot_product_attention_sdpa"): + # we have SageAttention loaded. fallback to SDPA for decode. + torch.nn.functional.scaled_dot_product_attention = ( + torch.nn.functional.scaled_dot_product_attention_sdpa + ) + image = self.vae.decode( - latents.to(device=self.vae.device, dtype=self.vae.dtype) - / self.vae.config.scaling_factor, + latents.to(dtype=self.vae.dtype) / self.vae.config.scaling_factor, return_dict=False, + generator=generator, )[0] + + if hasattr(torch.nn.functional, "scaled_dot_product_attention_sdpa"): + # reenable SageAttention for training. + torch.nn.functional.scaled_dot_product_attention = ( + torch.nn.functional.scaled_dot_product_attention_sage + ) + if use_resolution_binning: image = self.image_processor.resize_and_crop_tensor( image, orig_width, orig_height diff --git a/helpers/models/sd3/pipeline.py b/helpers/models/sd3/pipeline.py index 653c2a6a..ed25c954 100644 --- a/helpers/models/sd3/pipeline.py +++ b/helpers/models/sd3/pipeline.py @@ -1097,7 +1097,22 @@ def __call__( latents / self.vae.config.scaling_factor ) + self.vae.config.shift_factor - image = self.vae.decode(latents.to(self.vae.dtype), return_dict=False)[0] + if hasattr(torch.nn.functional, "scaled_dot_product_attention_sdpa"): + # we have SageAttention loaded. fallback to SDPA for decode. 
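+                # `scaled_dot_product_attention_sdpa` / `scaled_dot_product_attention_sage`
+                # are stashed on torch.nn.functional by the trainer's move_models()
+                # (helpers/training/trainer.py) whenever a SageAttention backend is selected.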
+ torch.nn.functional.scaled_dot_product_attention = ( + torch.nn.functional.scaled_dot_product_attention_sdpa + ) + + image = self.vae.decode( + latents.to(dtype=self.vae.dtype), return_dict=False + )[0] + + if hasattr(torch.nn.functional, "scaled_dot_product_attention_sdpa"): + # reenable SageAttention for training. + torch.nn.functional.scaled_dot_product_attention = ( + torch.nn.functional.scaled_dot_product_attention_sage + ) + image = self.image_processor.postprocess(image, output_type=output_type) # Offload all models @@ -2053,7 +2068,22 @@ def __call__( latents / self.vae.config.scaling_factor ) + self.vae.config.shift_factor - image = self.vae.decode(latents.to(self.vae.dtype), return_dict=False)[0] + if hasattr(torch.nn.functional, "scaled_dot_product_attention_sdpa"): + # we have SageAttention loaded. fallback to SDPA for decode. + torch.nn.functional.scaled_dot_product_attention = ( + torch.nn.functional.scaled_dot_product_attention_sdpa + ) + + image = self.vae.decode( + latents.to(dtype=self.vae.dtype), return_dict=False + )[0] + + if hasattr(torch.nn.functional, "scaled_dot_product_attention_sdpa"): + # reenable SageAttention for training. + torch.nn.functional.scaled_dot_product_attention = ( + torch.nn.functional.scaled_dot_product_attention_sage + ) + image = self.image_processor.postprocess(image, output_type=output_type) # Offload all models diff --git a/helpers/models/sdxl/pipeline.py b/helpers/models/sdxl/pipeline.py index bfd93f29..10b01216 100644 --- a/helpers/models/sdxl/pipeline.py +++ b/helpers/models/sdxl/pipeline.py @@ -1488,10 +1488,22 @@ def __call__( else: latents = latents / self.vae.config.scaling_factor + if hasattr(torch.nn.functional, "scaled_dot_product_attention_sdpa"): + # we have SageAttention loaded. fallback to SDPA for decode. + torch.nn.functional.scaled_dot_product_attention = ( + torch.nn.functional.scaled_dot_product_attention_sdpa + ) + image = self.vae.decode( latents.to(dtype=self.vae.dtype), return_dict=False )[0] + if hasattr(torch.nn.functional, "scaled_dot_product_attention_sdpa"): + # reenable SageAttention for training. + torch.nn.functional.scaled_dot_product_attention = ( + torch.nn.functional.scaled_dot_product_attention_sage + ) + # cast back to fp16 if needed if needs_upcasting: self.vae.to(dtype=torch.float16) From d406e3a817e57411df8afd9153e0c834a6a3fbb8 Mon Sep 17 00:00:00 2001 From: bghira Date: Mon, 2 Dec 2024 11:45:34 -0600 Subject: [PATCH 11/18] add more sageattention API choices --- helpers/configuration/cmd_args.py | 11 +++++++-- helpers/training/trainer.py | 40 ++++++++++++++++++++++++++----- 2 files changed, 43 insertions(+), 8 deletions(-) diff --git a/helpers/configuration/cmd_args.py b/helpers/configuration/cmd_args.py index c15cdc01..a0617f80 100644 --- a/helpers/configuration/cmd_args.py +++ b/helpers/configuration/cmd_args.py @@ -1711,7 +1711,14 @@ def get_argument_parser(): parser.add_argument( "--attention_mechanism", type=str, - choices=["diffusers", "xformers", "sageattention"], + choices=[ + "diffusers", + "xformers", + "sageattention", + "sageattention-int8-fp16-triton", + "sageattention-int8-fp16-cuda", + "sageattention-int8-fp8-cuda", + ], default="diffusers", help=( "On NVIDIA CUDA devices, we can use Xformers or SageAttention as an alternative to Pytorch SDPA (Diffusers)." 
@@ -2427,7 +2434,7 @@ def parse_cmdline_args(input_args=None): args.lycoris_config, os.R_OK ): raise ValueError( - f"Could not find the JSON configuration file at {args.lycoris_config}" + f"Could not find the JSON configuration file at '{args.lycoris_config}'" ) import json diff --git a/helpers/training/trainer.py b/helpers/training/trainer.py index 637006d1..a78ede95 100644 --- a/helpers/training/trainer.py +++ b/helpers/training/trainer.py @@ -1637,19 +1637,47 @@ def move_models(self, destination: str = "accelerator"): ) ) - if self.config.attention_mechanism == "sageattention": + if "sageattention" in self.config.attention_mechanism: # we'll try and load SageAttention and overload pytorch's sdpa function. try: - from sageattention import sageattn + from sageattention import ( + sageattn, + sageattn_qk_int8_pv_fp16_triton, + sageattn_qk_int8_pv_fp16_cuda, + sageattn_qk_int8_pv_fp8_cuda, + ) + + sageattn_functions = { + "sageattention": sageattn, + "sageattention-int8-fp16-triton": sageattn_qk_int8_pv_fp16_triton, + "sageattention-int8-fp16-cuda": sageattn_qk_int8_pv_fp16_cuda, + "sageattention-int8-fp8-cuda": sageattn_qk_int8_pv_fp8_cuda, + } + # store the old SDPA for validations to use during VAE decode + setattr( + torch.nn.functional, + "scaled_dot_product_attention_sdpa", + torch.nn.functional.scaled_dot_product_attention, + ) + torch.nn.functional.scaled_dot_product_attention = ( + sageattn_functions.get( + self.config.attention_mechanism, "sageattention" + ) + ) + setattr( + torch.nn.functional, + "scaled_dot_product_attention_sage", + torch.nn.functional.scaled_dot_product_attention, + ) - torch.nn.functional.scaled_dot_product_attention = sageattn logger.warning( - "Using SageAttention for flash attention mechanism. This is an experimental option, and you may receive unexpected or poor results. To disable SageAttention, remove or set --attention_mechanism to a different value." + f"Using {self.config.attention_mechanism} for flash attention mechanism. This is an experimental option, and you may receive unexpected or poor results. To disable SageAttention, remove or set --attention_mechanism to a different value." ) - except ImportError: + except ImportError as e: logger.error( - "Could not import SageAttention. Please install it to use this --attention_mechanism=sageattention" + "Could not import SageAttention. Please install it to use this --attention_mechanism=sageattention." ) + logger.error(repr(e)) sys.exit(1) elif ( self.config.attention_mechanism == "xformers" From cb286c2dcca43694e6a296f44ca140bc040dc9e6 Mon Sep 17 00:00:00 2001 From: bghira Date: Mon, 2 Dec 2024 11:47:42 -0600 Subject: [PATCH 12/18] SD 1.5/2.x fix for SageAttention decode --- helpers/legacy/pipeline.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/helpers/legacy/pipeline.py b/helpers/legacy/pipeline.py index 01b0ffe8..c186b48d 100644 --- a/helpers/legacy/pipeline.py +++ b/helpers/legacy/pipeline.py @@ -1160,18 +1160,27 @@ def __call__( step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) + has_nsfw_concept = None if not output_type == "latent": + if hasattr(torch.nn.functional, "scaled_dot_product_attention_sdpa"): + # we have SageAttention loaded. fallback to SDPA for decode. 
+ torch.nn.functional.scaled_dot_product_attention = ( + torch.nn.functional.scaled_dot_product_attention_sdpa + ) + image = self.vae.decode( - latents.to(self.vae.dtype) / self.vae.config.scaling_factor, + latents.to(dtype=self.vae.dtype) / self.vae.config.scaling_factor, return_dict=False, generator=generator, )[0] - image, has_nsfw_concept = self.run_safety_checker( - image, device, prompt_embeds.dtype - ) + + if hasattr(torch.nn.functional, "scaled_dot_product_attention_sdpa"): + # reenable SageAttention for training. + torch.nn.functional.scaled_dot_product_attention = ( + torch.nn.functional.scaled_dot_product_attention_sage + ) else: image = latents - has_nsfw_concept = None if has_nsfw_concept is None: do_denormalize = [True] * image.shape[0] From 11b3cf175607c61e4f09ab989eb648b0698fec51 Mon Sep 17 00:00:00 2001 From: bghira Date: Mon, 2 Dec 2024 11:48:24 -0600 Subject: [PATCH 13/18] set attention mechanism to the default for tests --- tests/test_model_card.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_model_card.py b/tests/test_model_card.py index 1c596be5..2e2ed5b1 100644 --- a/tests/test_model_card.py +++ b/tests/test_model_card.py @@ -65,6 +65,7 @@ def setUp(self): self.args.flux_guidance_value = 1.0 self.args.t5_padding = "unmodified" self.args.enable_xformers_memory_efficient_attention = False + self.args.attention_mechanism = "diffusers" def test_model_imports(self): self.args.lora_type = "standard" From d1c227cc71af063615ff1357fa9c196765ddf1df Mon Sep 17 00:00:00 2001 From: bghira Date: Mon, 2 Dec 2024 12:14:56 -0600 Subject: [PATCH 14/18] add sageattention to OPTIONS doc, update recommendations in --help output --- OPTIONS.md | 38 ++++++++++++++++++++++++++++--- helpers/configuration/cmd_args.py | 8 ++++++- 2 files changed, 42 insertions(+), 4 deletions(-) diff --git a/OPTIONS.md b/OPTIONS.md index edeea098..c3d59c8d 100644 --- a/OPTIONS.md +++ b/OPTIONS.md @@ -452,7 +452,8 @@ usage: train.py [-h] [--snr_gamma SNR_GAMMA] [--use_soft_min_snr] [--lr_scheduler {linear,sine,cosine,cosine_with_restarts,polynomial,constant,constant_with_warmup}] [--lr_warmup_steps LR_WARMUP_STEPS] [--lr_num_cycles LR_NUM_CYCLES] [--lr_power LR_POWER] - [--use_ema] [--ema_device {cpu,accelerator}] [--ema_cpu_only] + [--use_ema] [--ema_device {cpu,accelerator}] + [--ema_validation {none,ema_only,comparison}] [--ema_cpu_only] [--ema_foreach_disable] [--ema_update_interval EMA_UPDATE_INTERVAL] [--ema_decay EMA_DECAY] [--non_ema_revision NON_EMA_REVISION] @@ -473,8 +474,9 @@ usage: train.py [-h] [--snr_gamma SNR_GAMMA] [--use_soft_min_snr] [--model_card_safe_for_work] [--logging_dir LOGGING_DIR] [--benchmark_base_model] [--disable_benchmark] [--evaluation_type {clip,none}] - [--pretrained_evaluation_model_name_or_path pretrained_evaluation_model_name_or_path] + [--pretrained_evaluation_model_name_or_path PRETRAINED_EVALUATION_MODEL_NAME_OR_PATH] [--validation_on_startup] [--validation_seed_source {gpu,cpu}] + [--validation_lycoris_strength VALIDATION_LYCORIS_STRENGTH] [--validation_torch_compile] [--validation_torch_compile_mode {max-autotune,reduce-overhead,default}] [--validation_guidance_skip_layers VALIDATION_GUIDANCE_SKIP_LAYERS] @@ -509,6 +511,7 @@ usage: train.py [-h] [--snr_gamma SNR_GAMMA] [--use_soft_min_snr] [--text_encoder_2_precision {no_change,int8-quanto,int4-quanto,int2-quanto,int8-torchao,nf4-bnb,fp8-quanto,fp8uz-quanto}] [--text_encoder_3_precision {no_change,int8-quanto,int4-quanto,int2-quanto,int8-torchao,nf4-bnb,fp8-quanto,fp8uz-quanto}] [--local_rank 
LOCAL_RANK] + [--attention_mechanism {diffusers,xformers,sageattention,sageattention-int8-fp16-triton,sageattention-int8-fp16-cuda,sageattention-int8-fp8-cuda}] [--enable_xformers_memory_efficient_attention] [--set_grads_to_none] [--noise_offset NOISE_OFFSET] [--noise_offset_probability NOISE_OFFSET_PROBABILITY] @@ -1137,12 +1140,21 @@ options: cosine_with_restarts scheduler. --lr_power LR_POWER Power factor of the polynomial scheduler. --use_ema Whether to use EMA (exponential moving average) model. + Works with LoRA, Lycoris, and full training. --ema_device {cpu,accelerator} The device to use for the EMA model. If set to 'accelerator', the EMA model will be placed on the accelerator. This provides the fastest EMA update times, but is not ultimately necessary for EMA to function. + --ema_validation {none,ema_only,comparison} + When 'none' is set, no EMA validation will be done. + When using 'ema_only', the validations will rely + mostly on the EMA weights. When using 'comparison' + (default) mode, the validations will first run on the + checkpoint before also running for the EMA weights. In + comparison mode, the resulting images will be provided + side-by-side. --ema_cpu_only When using EMA, the shadow model is moved to the accelerator before we update its parameters. When provided, this option will disable the moving of the @@ -1248,7 +1260,7 @@ options: function. The default is to use no evaluator, and 'clip' will use a CLIP model to evaluate the resulting model's performance during validations. - --pretrained_evaluation_model_name_or_path pretrained_evaluation_model_name_or_path + --pretrained_evaluation_model_name_or_path PRETRAINED_EVALUATION_MODEL_NAME_OR_PATH Optionally provide a custom model to use for ViT evaluations. The default is currently clip-vit-large- patch14-336, allowing for lower patch sizes (greater @@ -1264,6 +1276,12 @@ options: validation errors. If so, please set SIMPLETUNER_LOG_LEVEL=DEBUG and submit debug.log to a new Github issue report. + --validation_lycoris_strength VALIDATION_LYCORIS_STRENGTH + When inferencing for validations, the Lycoris model + will by default be run at its training strength, 1.0. + However, this value can be increased to a value of + around 1.3 or 1.5 to get a stronger effect from the + model. --validation_torch_compile Supply `--validation_torch_compile=true` to enable the use of torch.compile() on the validation pipeline. For @@ -1453,6 +1471,20 @@ options: quantisation (Apple Silicon, NVIDIA, AMD). --local_rank LOCAL_RANK For distributed training: local_rank + --attention_mechanism {diffusers,xformers,sageattention,sageattention-int8-fp16-triton,sageattention-int8-fp16-cuda,sageattention-int8-fp8-cuda} + On NVIDIA CUDA devices, alternative flash attention + implementations are offered, with the default being + native pytorch SDPA. SageAttention has multiple + backends to select from. The recommended value, + 'sageattention', guesses what would be the 'best' + option for SageAttention on your hardware (usually + this is the int8-fp16-cuda backend). However, manually + setting this value to int8-fp16-triton may provide + better averages for per-step training and inference + performance while the cuda backend may provide the + highest maximum speed (with also a lower minimum + speed). NOTE: SageAttention training quality has not + been validated. --enable_xformers_memory_efficient_attention Whether or not to use xformers. 
--set_grads_to_none Save more memory by using setting grads to None diff --git a/helpers/configuration/cmd_args.py b/helpers/configuration/cmd_args.py index a0617f80..fce256e3 100644 --- a/helpers/configuration/cmd_args.py +++ b/helpers/configuration/cmd_args.py @@ -1721,7 +1721,13 @@ def get_argument_parser(): ], default="diffusers", help=( - "On NVIDIA CUDA devices, we can use Xformers or SageAttention as an alternative to Pytorch SDPA (Diffusers)." + "On NVIDIA CUDA devices, alternative flash attention implementations are offered, with the default being native pytorch SDPA." + " SageAttention has multiple backends to select from." + " The recommended value, 'sageattention', guesses what would be the 'best' option for SageAttention on your hardware" + " (usually this is the int8-fp16-cuda backend). However, manually setting this value to int8-fp16-triton" + " may provide better averages for per-step training and inference performance while the cuda backend" + " may provide the highest maximum speed (with also a lower minimum speed). NOTE: SageAttention training quality" + " has not been validated." ), ) parser.add_argument( From 0976b5ef22705078bc666a32ce671f2919c27a03 Mon Sep 17 00:00:00 2001 From: bghira Date: Mon, 2 Dec 2024 12:32:52 -0600 Subject: [PATCH 15/18] kolors: enable vae decode hack for sageattention --- helpers/kolors/pipeline.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/helpers/kolors/pipeline.py b/helpers/kolors/pipeline.py index 964a8e95..2f395e4d 100644 --- a/helpers/kolors/pipeline.py +++ b/helpers/kolors/pipeline.py @@ -1249,11 +1249,22 @@ def denoising_value_valid(dnv): # unscale/denormalize the latents latents = latents / self.vae.config.scaling_factor + if hasattr(torch.nn.functional, "scaled_dot_product_attention_sdpa"): + # we have SageAttention loaded. fallback to SDPA for decode. + torch.nn.functional.scaled_dot_product_attention = ( + torch.nn.functional.scaled_dot_product_attention_sdpa + ) + image = self.vae.decode( - latents.to(device=self.vae.device, dtype=self.vae.dtype), - return_dict=False, + latents.to(device=self.vae.device), return_dict=False )[0] + if hasattr(torch.nn.functional, "scaled_dot_product_attention_sdpa"): + # reenable SageAttention for training. + torch.nn.functional.scaled_dot_product_attention = ( + torch.nn.functional.scaled_dot_product_attention_sage + ) + # cast back to fp16 if needed if needs_upcasting: self.vae.to(dtype=torch.float16) From c98d0f7a9d49c08a146732c562d22d57e5c989f1 Mon Sep 17 00:00:00 2001 From: bghira Date: Mon, 2 Dec 2024 12:40:42 -0600 Subject: [PATCH 16/18] add sageattention to the quickstart docs and a specific section under "Precision" in the OPTIONS doc --- OPTIONS.md | 8 ++++++++ documentation/LYCORIS.md | 8 ++++++++ documentation/quickstart/FLUX.md | 10 ++++++++++ documentation/quickstart/SD3.md | 8 ++++++++ documentation/quickstart/SIGMA.md | 8 ++++++++ 5 files changed, 42 insertions(+) diff --git a/OPTIONS.md b/OPTIONS.md index c3d59c8d..ddc32383 100644 --- a/OPTIONS.md +++ b/OPTIONS.md @@ -109,6 +109,14 @@ Carefully answer the questions and use bf16 mixed precision training when prompt Note that the first several steps of training will be slower than usual because of compilation occuring in the background. +### `--attention_mechanism` + +Setting `sageattention` or `xformers` here will allow the use of other memory-efficient attention mechanisms for the forward pass during training and inference, potentially resulting in major performance improvement. 
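+Under the hood, selecting a SageAttention backend simply swaps PyTorch's SDPA entry point for the SageAttention kernel at startup. A rough sketch of what the trainer does (the real code in `helpers/training/trainer.py` also keeps a reference to the stock SDPA so that VAE decodes during validation can temporarily fall back to it):
+
+```python
+import torch
+from sageattention import sageattn
+
+# keep the stock implementation around for code paths SageAttention cannot handle
+torch.nn.functional.scaled_dot_product_attention_sdpa = (
+    torch.nn.functional.scaled_dot_product_attention
+)
+# every scaled_dot_product_attention call from here on runs SageAttention
+torch.nn.functional.scaled_dot_product_attention = sageattn
+torch.nn.functional.scaled_dot_product_attention_sage = sageattn
+```
+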
+ +Using `sageattention` enables the use of [SageAttention](https://github.com/thu-ml/SageAttention) on NVIDIA CUDA equipment (sorry, AMD and Apple users). + +In simple terms, this will quantise the attention calculations for lower compute and memory overhead, **massively** speeding up training while minimally impacting quality. + --- ## 📰 Publishing diff --git a/documentation/LYCORIS.md b/documentation/LYCORIS.md index 64aba565..853d3ec6 100644 --- a/documentation/LYCORIS.md +++ b/documentation/LYCORIS.md @@ -59,6 +59,14 @@ Mandatory fields: For more information on LyCORIS, please refer to the [documentation in the library](https://github.com/KohakuBlueleaf/LyCORIS/tree/main/docs). +## Potential problems + +When using Lycoris on SDXL, it's noted that training the FeedForward modules may break the model and send loss into `NaN` (Not-a-Number) territory. + +This seems to be potentially exacerbated when using SageAttention, making it all but guaranteed that the model will immediately fail. + +The solution is to remove the `FeedForward` modules from the lycoris config and train only the `Attention` blocks. + ## LyCORIS Inference Example Here is a simple FLUX.1-dev inference script showing how to wrap your unet or transformer with create_lycoris_from_weights and then use it for inference. diff --git a/documentation/quickstart/FLUX.md b/documentation/quickstart/FLUX.md index ddb86e1d..189b1a72 100644 --- a/documentation/quickstart/FLUX.md +++ b/documentation/quickstart/FLUX.md @@ -414,9 +414,18 @@ Currently, the lowest VRAM utilisation (9090M) can be attained with: - DeepSpeed: disabled / unconfigured - PyTorch: 2.6 Nightly (Sept 29th build) - Using `--quantize_via=cpu` to avoid outOfMemory error during startup on <=16G cards. +- With `--attention_mechanism=sageattention` to further reduce VRAM by 0.1GB and improve training speed. Speed was approximately 1.4 iterations per second on a 4090. +### SageAttention + +When using `--attention_mechanism=sageattention`, quantised operations are performed during SDPA calculations. + +In simpler terms, this can very slightly improve VRAM usage while substantially speeding up training. + +**Note**: This isn't compatible with _every_ configuration, but it's worth trying. + ### NF4-quantised training In simplest terms, NF4 is a 4bit-_ish_ representation of the model, which means training has serious stability concerns to address. @@ -428,6 +437,7 @@ In early tests, the following holds true: - NF4, AdamW8bit, and a higher batch size all help to overcome the stability issues, at the cost of more time spent training or VRAM used - Upping the resolution from 512px to 1024px slows training down from, for example, 1.4 seconds per step to 3.5 seconds per step (batch size of 1, 4090) - Anything that's difficult to train on int8 or bf16 becomes harder in NF4 +- It's less compatible with options like SageAttention NF4 does not work with torch.compile, so whatever you get for speed is what you get. diff --git a/documentation/quickstart/SD3.md b/documentation/quickstart/SD3.md index 5d646562..764bf64a 100644 --- a/documentation/quickstart/SD3.md +++ b/documentation/quickstart/SD3.md @@ -339,6 +339,14 @@ These options have been known to keep SD3.5 in-tact for as long as possible: - DeepSpeed: disabled / unconfigured - PyTorch: 2.5 +### SageAttention + +When using `--attention_mechanism=sageattention`, quantised operations are performed during SDPA calculations. + +In simpler terms, this can very slightly improve VRAM usage while substantially speeding up training. 
+ +**Note**: This isn't compatible with _every_ configuration, but it's worth trying. + ### Masked loss If you are training a subject or style and would like to mask one or the other, see the [masked loss training](/documentation/DREAMBOOTH.md#masked-loss) section of the Dreambooth guide. diff --git a/documentation/quickstart/SIGMA.md b/documentation/quickstart/SIGMA.md index cf3389aa..aadd46f2 100644 --- a/documentation/quickstart/SIGMA.md +++ b/documentation/quickstart/SIGMA.md @@ -220,3 +220,11 @@ For more information, see the [dataloader](/documentation/DATALOADER.md) and [tu ### CLIP score tracking If you wish to enable evaluations to score the model's performance, see [this document](/documentation/evaluation/CLIP_SCORES.md) for information on configuring and interpreting CLIP scores. + +### SageAttention + +When using `--attention_mechanism=sageattention`, quantised operations are performed during SDPA calculations. + +In simpler terms, this can very slightly improve VRAM usage while substantially speeding up training. + +**Note**: This isn't compatible with _every_ configuration, but it's worth trying. \ No newline at end of file From 8719e738fea3964868d5bcdbffa57e6faddb9c2b Mon Sep 17 00:00:00 2001 From: bghira Date: Mon, 2 Dec 2024 13:05:48 -0600 Subject: [PATCH 17/18] sd3 skip_layer_guidance fix from upstream for num images per prompt > 1 --- helpers/models/sd3/pipeline.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/helpers/models/sd3/pipeline.py b/helpers/models/sd3/pipeline.py index ed25c954..70698f6e 100644 --- a/helpers/models/sd3/pipeline.py +++ b/helpers/models/sd3/pipeline.py @@ -999,11 +999,14 @@ def __call__( continue # expand the latents if we are doing classifier free guidance + # added fix from: https://github.com/huggingface/diffusers/pull/10086/files + # to allow for num_images_per_prompt > 1 latent_model_input = ( torch.cat([latents] * 2) - if self.do_classifier_free_guidance and skip_guidance_layers is None + if self.do_classifier_free_guidance else latents ) + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latent_model_input.shape[0]) @@ -1033,6 +1036,8 @@ def __call__( else False ) if skip_guidance_layers is not None and should_skip_layers: + timestep = t.expand(latents.shape[0]) + latent_model_input = latents noise_pred_skip_layers = self.transformer( hidden_states=latent_model_input.to( device=self.transformer.device, From 964e065336eb384e1ad90ad05297990a6d91d17a Mon Sep 17 00:00:00 2001 From: bghira Date: Mon, 2 Dec 2024 13:41:07 -0600 Subject: [PATCH 18/18] disable nf4 + sageattention --- helpers/training/default_settings/safety_check.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/helpers/training/default_settings/safety_check.py b/helpers/training/default_settings/safety_check.py index 01444a32..518afd63 100644 --- a/helpers/training/default_settings/safety_check.py +++ b/helpers/training/default_settings/safety_check.py @@ -126,3 +126,9 @@ def safety_check(args, accelerator): f"--enable_xformers_memory_efficient_attention is only compatible with --attention_mechanism=diffusers. Please set --attention_mechanism=diffusers to enable this feature or disable xformers to use alternative attention mechanisms." ) sys.exit(1) + + if "nf4" in args.base_model_precision: + logger.error( + f"{args.base_model_precision} is not supported with SageAttention. Please select from int8 or fp8, or, disable quantisation to use SageAttention." + ) + sys.exit(1)
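A minimal sketch of this NF4 guard with the SageAttention condition written out explicitly; this is an assumption based on the commit subject ("disable nf4 + sageattention"), which suggests NF4 should only be rejected when a SageAttention backend is actually selected:

```python
if "sageattention" in args.attention_mechanism and "nf4" in args.base_model_precision:
    logger.error(
        f"{args.base_model_precision} is not supported with SageAttention."
        " Please select int8 or fp8 quantisation, or disable quantisation entirely."
    )
    sys.exit(1)
```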