From e9614f5521df4a969e9b9360fde1f1a9f0b0d8bf Mon Sep 17 00:00:00 2001 From: bghira Date: Fri, 13 Dec 2024 06:35:14 -0600 Subject: [PATCH 1/4] sana: work with bf16 weights --- helpers/caching/vae.py | 2 +- helpers/configuration/cmd_args.py | 11 +- helpers/models/sana/__init__.py | 0 helpers/models/sana/transformer.py | 504 ++++++++++++++++++ helpers/publishing/metadata.py | 2 +- .../training/default_settings/safety_check.py | 4 +- helpers/training/diffusion_model.py | 20 +- helpers/training/trainer.py | 12 +- helpers/training/validation.py | 3 +- install/apple/poetry.lock | 2 +- install/rocm/poetry.lock | 22 +- install/rocm/pyproject.toml | 2 +- poetry.lock | 2 +- 13 files changed, 549 insertions(+), 37 deletions(-) create mode 100644 helpers/models/sana/__init__.py create mode 100644 helpers/models/sana/transformer.py diff --git a/helpers/caching/vae.py b/helpers/caching/vae.py index c0d3f728..28823e1e 100644 --- a/helpers/caching/vae.py +++ b/helpers/caching/vae.py @@ -254,7 +254,7 @@ def discover_all_files(self): def init_vae(self): if StateTracker.get_args().model_family == "sana": - from diffusers import DCAE as AutoencoderClass + from diffusers import AutoencoderDC as AutoencoderClass else: from diffusers import AutoencoderKL as AutoencoderClass diff --git a/helpers/configuration/cmd_args.py b/helpers/configuration/cmd_args.py index 6c3839ee..97ea5589 100644 --- a/helpers/configuration/cmd_args.py +++ b/helpers/configuration/cmd_args.py @@ -1657,13 +1657,13 @@ def get_argument_parser(): "--mixed_precision", type=str, default="bf16", - choices=["bf16", "no"], + choices=["bf16", "fp16", "no"], help=( "SimpleTuner only supports bf16 training. Bf16 requires PyTorch >=" " 1.10. on an Nvidia Ampere or later GPU, and PyTorch 2.3 or newer for Apple Silicon." " Default to the value of accelerate config of the current system or the" " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." - " Sana requires a value of 'no'." + " fp16 is offered as an experimental option, but is not recommended as it is less-tested and you will likely encounter errors." ), ) parser.add_argument( @@ -2451,14 +2451,11 @@ def parse_cmdline_args(input_args=None, exit_on_error: bool = False): args.weight_dtype = ( torch.bfloat16 if ( - (args.mixed_precision == "bf16" or torch.backends.mps.is_available()) + args.mixed_precision == "bf16" or (args.base_model_default_dtype == "bf16" and args.is_quantized) ) - else torch.float32 + else torch.float16 if args.mixed_precision == "fp16" else torch.float32 ) - if args.model_family == "sana": - # god fucking help us, but bf16 does not work with Sana - args.weight_dtype = torch.float16 args.disable_accelerator = os.environ.get("SIMPLETUNER_DISABLE_ACCELERATOR", False) if "lycoris" == args.lora_type.lower(): diff --git a/helpers/models/sana/__init__.py b/helpers/models/sana/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/helpers/models/sana/transformer.py b/helpers/models/sana/transformer.py new file mode 100644 index 00000000..2243a9b9 --- /dev/null +++ b/helpers/models/sana/transformer.py @@ -0,0 +1,504 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, Optional, Tuple, Union + +import torch +from torch import nn + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.utils import is_torch_version, logging +from diffusers.models.attention_processor import ( + Attention, + AttentionProcessor, + AttnProcessor2_0, + SanaLinearAttnProcessor2_0, +) +from diffusers.models.embeddings import PatchEmbed, PixArtAlphaTextProjection +from diffusers.models.modeling_outputs import Transformer2DModelOutput +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.normalization import AdaLayerNormSingle, RMSNorm + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class GLUMBConv(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + expand_ratio: float = 4, + norm_type: Optional[str] = None, + residual_connection: bool = True, + ) -> None: + super().__init__() + + hidden_channels = int(expand_ratio * in_channels) + self.norm_type = norm_type + self.residual_connection = residual_connection + + self.nonlinearity = nn.SiLU() + self.conv_inverted = nn.Conv2d(in_channels, hidden_channels * 2, 1, 1, 0) + self.conv_depth = nn.Conv2d( + hidden_channels * 2, + hidden_channels * 2, + 3, + 1, + 1, + groups=hidden_channels * 2, + ) + self.conv_point = nn.Conv2d(hidden_channels, out_channels, 1, 1, 0, bias=False) + + self.norm = None + if norm_type == "rms_norm": + self.norm = RMSNorm( + out_channels, eps=1e-5, elementwise_affine=True, bias=True + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + if self.residual_connection: + residual = hidden_states + + hidden_states = self.conv_inverted(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = self.conv_depth(hidden_states) + hidden_states, gate = torch.chunk(hidden_states, 2, dim=1) + hidden_states = hidden_states * self.nonlinearity(gate) + + hidden_states = self.conv_point(hidden_states) + + if self.norm_type == "rms_norm": + # move channel to the last dimension so we apply RMSnorm across channel dimension + hidden_states = self.norm(hidden_states.movedim(1, -1)).movedim(-1, 1) + + if self.residual_connection: + hidden_states = hidden_states + residual + + return hidden_states + + +class SanaTransformerBlock(nn.Module): + r""" + Transformer block introduced in [Sana](https://huggingface.co/papers/2410.10629). + """ + + def __init__( + self, + dim: int = 2240, + num_attention_heads: int = 70, + attention_head_dim: int = 32, + dropout: float = 0.0, + num_cross_attention_heads: Optional[int] = 20, + cross_attention_head_dim: Optional[int] = 112, + cross_attention_dim: Optional[int] = 2240, + attention_bias: bool = True, + norm_elementwise_affine: bool = False, + norm_eps: float = 1e-6, + attention_out_bias: bool = True, + mlp_ratio: float = 2.5, + ) -> None: + super().__init__() + + # 1. 
Self Attention + self.norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=norm_eps) + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=None, + processor=SanaLinearAttnProcessor2_0(), + ) + + # 2. Cross Attention + if cross_attention_dim is not None: + self.norm2 = nn.LayerNorm( + dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps + ) + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim, + heads=num_cross_attention_heads, + dim_head=cross_attention_head_dim, + dropout=dropout, + bias=True, + out_bias=attention_out_bias, + processor=AttnProcessor2_0(), + ) + + # 3. Feed-forward + self.ff = GLUMBConv( + dim, dim, mlp_ratio, norm_type=None, residual_connection=False + ) + + self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + timestep: Optional[torch.LongTensor] = None, + height: int = None, + width: int = None, + ) -> torch.Tensor: + batch_size = hidden_states.shape[0] + + # 1. Modulation + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1) + ).chunk(6, dim=1) + + # 2. Self Attention + norm_hidden_states = self.norm1(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa + norm_hidden_states = norm_hidden_states.to(hidden_states.dtype) + + attn_output = self.attn1(norm_hidden_states) + hidden_states = hidden_states + gate_msa * attn_output + + # 3. Cross Attention + if self.attn2 is not None: + attn_output = self.attn2( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + ) + hidden_states = attn_output + hidden_states + + # 4. Feed-forward + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp + + norm_hidden_states = norm_hidden_states.unflatten(1, (height, width)).permute( + 0, 3, 1, 2 + ) + ff_output = self.ff(norm_hidden_states) + ff_output = ff_output.flatten(2, 3).permute(0, 2, 1) + hidden_states = hidden_states + gate_mlp * ff_output + + return hidden_states + + +class SanaTransformer2DModel(ModelMixin, ConfigMixin): + r""" + A 2D Transformer model introduced in [Sana](https://huggingface.co/papers/2410.10629) family of models. + + Args: + in_channels (`int`, defaults to `32`): + The number of channels in the input. + out_channels (`int`, *optional*, defaults to `32`): + The number of channels in the output. + num_attention_heads (`int`, defaults to `70`): + The number of heads to use for multi-head attention. + attention_head_dim (`int`, defaults to `32`): + The number of channels in each head. + num_layers (`int`, defaults to `20`): + The number of layers of Transformer blocks to use. + num_cross_attention_heads (`int`, *optional*, defaults to `20`): + The number of heads to use for cross-attention. + cross_attention_head_dim (`int`, *optional*, defaults to `112`): + The number of channels in each head for cross-attention. + cross_attention_dim (`int`, *optional*, defaults to `2240`): + The number of channels in the cross-attention output. + caption_channels (`int`, defaults to `2304`): + The number of channels in the caption embeddings. 
+ mlp_ratio (`float`, defaults to `2.5`): + The expansion ratio to use in the GLUMBConv layer. + dropout (`float`, defaults to `0.0`): + The dropout probability. + attention_bias (`bool`, defaults to `False`): + Whether to use bias in the attention layer. + sample_size (`int`, defaults to `32`): + The base size of the input latent. + patch_size (`int`, defaults to `1`): + The size of the patches to use in the patch embedding layer. + norm_elementwise_affine (`bool`, defaults to `False`): + Whether to use elementwise affinity in the normalization layer. + norm_eps (`float`, defaults to `1e-6`): + The epsilon value for the normalization layer. + """ + + _supports_gradient_checkpointing = True + _no_split_modules = ["SanaTransformerBlock", "PatchEmbed"] + + @register_to_config + def __init__( + self, + in_channels: int = 32, + out_channels: Optional[int] = 32, + num_attention_heads: int = 70, + attention_head_dim: int = 32, + num_layers: int = 20, + num_cross_attention_heads: Optional[int] = 20, + cross_attention_head_dim: Optional[int] = 112, + cross_attention_dim: Optional[int] = 2240, + caption_channels: int = 2304, + mlp_ratio: float = 2.5, + dropout: float = 0.0, + attention_bias: bool = False, + sample_size: int = 32, + patch_size: int = 1, + norm_elementwise_affine: bool = False, + norm_eps: float = 1e-6, + ) -> None: + super().__init__() + + out_channels = out_channels or in_channels + inner_dim = num_attention_heads * attention_head_dim + + # 1. Patch Embedding + self.patch_embed = PatchEmbed( + height=sample_size, + width=sample_size, + patch_size=patch_size, + in_channels=in_channels, + embed_dim=inner_dim, + interpolation_scale=None, + pos_embed_type=None, + ) + + # 2. Additional condition embeddings + self.time_embed = AdaLayerNormSingle(inner_dim) + + self.caption_projection = PixArtAlphaTextProjection( + in_features=caption_channels, hidden_size=inner_dim + ) + self.caption_norm = RMSNorm(inner_dim, eps=1e-5, elementwise_affine=True) + + # 3. Transformer blocks + self.transformer_blocks = nn.ModuleList( + [ + SanaTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + num_cross_attention_heads=num_cross_attention_heads, + cross_attention_head_dim=cross_attention_head_dim, + cross_attention_dim=cross_attention_dim, + attention_bias=attention_bias, + norm_elementwise_affine=norm_elementwise_affine, + norm_eps=norm_eps, + mlp_ratio=mlp_ratio, + ) + for _ in range(num_layers) + ] + ) + + # 4. Output blocks + self.scale_shift_table = nn.Parameter( + torch.randn(2, inner_dim) / inner_dim**0.5 + ) + + self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels) + + self.gradient_checkpointing = False + self.gradient_checkpointing_interval = None + + def set_gradient_checkpointing_interval(self, interval: int): + r""" + Sets the gradient checkpointing interval for the model. + + Parameters: + interval (`int`): + The interval at which to checkpoint the gradients. 
+ """ + self.gradient_checkpointing_interval = interval + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors( + name: str, + module: torch.nn.Module, + processors: Dict[str, AttentionProcessor], + ): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor( + self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]] + ): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + timestep: torch.LongTensor, + encoder_attention_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + return_dict: bool = True, + ) -> Union[Tuple[torch.Tensor, ...], Transformer2DModelOutput]: + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. + # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. + # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. 
xformers or classic attn) + if attention_mask is not None and attention_mask.ndim == 2: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: + encoder_attention_mask = ( + 1 - encoder_attention_mask.to(hidden_states.dtype) + ) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + # 1. Input + batch_size, num_channels, height, width = hidden_states.shape + p = self.config.patch_size + post_patch_height, post_patch_width = height // p, width // p + + hidden_states = self.patch_embed(hidden_states) + + timestep, embedded_timestep = self.time_embed( + timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype + ) + + encoder_hidden_states = self.caption_projection(encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states.view( + batch_size, -1, hidden_states.shape[-1] + ) + + encoder_hidden_states = self.caption_norm(encoder_hidden_states) + + # 2. Transformer blocks + use_reentrant = is_torch_version("<=", "1.11.0") + + def create_block_forward(block): + if ( + self.gradient_checkpointing_interval is not None + and self.gradient_checkpointing_interval > 0 + and self.gradient_checkpointing + ): + self._set_gradient_checkpointing( + block, timestep % self.gradient_checkpointing_interval == 0 + ) + if torch.is_grad_enabled() and self.gradient_checkpointing: + return lambda *inputs: torch.utils.checkpoint.checkpoint( + lambda *x: block(*x), *inputs, use_reentrant=use_reentrant + ) + else: + return block + + for block in self.transformer_blocks: + hidden_states = create_block_forward(block)( + hidden_states, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + timestep, + post_patch_height, + post_patch_width, + ) + + # 3. Normalization + shift, scale = ( + self.scale_shift_table[None] + + embedded_timestep[:, None].to(self.scale_shift_table.device) + ).chunk(2, dim=1) + hidden_states = self.norm_out(hidden_states) + + # 4. Modulation + hidden_states = hidden_states * (1 + scale) + shift + hidden_states = self.proj_out(hidden_states) + + # 5. 
Unpatchify + hidden_states = hidden_states.reshape( + batch_size, + post_patch_height, + post_patch_width, + self.config.patch_size, + self.config.patch_size, + -1, + ) + hidden_states = hidden_states.permute(0, 5, 1, 3, 2, 4) + output = hidden_states.reshape( + batch_size, -1, post_patch_height * p, post_patch_width * p + ) + + if not return_dict: + return (output,) + return Transformer2DModelOutput(sample=output) diff --git a/helpers/publishing/metadata.py b/helpers/publishing/metadata.py index 2bfe4104..ea7c00b6 100644 --- a/helpers/publishing/metadata.py +++ b/helpers/publishing/metadata.py @@ -493,7 +493,7 @@ def save_model_card( shortname_idx += 1 args = StateTracker.get_args() yaml_content = f"""--- -license: {licenses[model_family]} +license: {licenses.get(model_family, "other")} base_model: "{base_model}" tags: - {model_family} diff --git a/helpers/training/default_settings/safety_check.py b/helpers/training/default_settings/safety_check.py index 3d92dc5b..6be4df8f 100644 --- a/helpers/training/default_settings/safety_check.py +++ b/helpers/training/default_settings/safety_check.py @@ -141,6 +141,7 @@ def safety_check(args, accelerator): gradient_checkpointing_interval_supported_models = [ "flux", + "sana", "sdxl", ] if args.gradient_checkpointing_interval is not None: @@ -156,6 +157,3 @@ def safety_check(args, accelerator): raise ValueError( "Gradient checkpointing interval must be greater than 0. Please set it to a positive integer." ) - - if args.model_family == "sana": - args.mixed_precision == "no" diff --git a/helpers/training/diffusion_model.py b/helpers/training/diffusion_model.py index 04ea1849..950b0241 100644 --- a/helpers/training/diffusion_model.py +++ b/helpers/training/diffusion_model.py @@ -94,10 +94,6 @@ def load_diffusion_model(args, weight_dtype): subfolder=determine_subfolder(args.pretrained_transformer_subfolder), **pretrained_load_args, ) - if args.gradient_checkpointing_interval is not None: - transformer.set_gradient_checkpointing_interval( - int(args.gradient_checkpointing_interval) - ) elif args.model_family.lower() == "flux" and args.flux_attention_masked_training: from helpers.models.flux.transformer import ( FluxTransformer2DModelWithMasking, @@ -138,7 +134,7 @@ def load_diffusion_model(args, weight_dtype): if "lora" in args.model_type: raise ValueError("SmolDiT does not yet support LoRA training.") elif args.model_family == "sana": - from diffusers import SanaTransformer2DModel + from helpers.models.sana.transformer import SanaTransformer2DModel logger.info("Loading Sana flow-matching diffusion transformer..") transformer = SanaTransformer2DModel.from_pretrained( @@ -183,4 +179,18 @@ def load_diffusion_model(args, weight_dtype): set_checkpoint_interval(int(args.gradient_checkpointing_interval)) + if args.gradient_checkpointing_interval is not None: + if transformer is not None and hasattr( + transformer, "set_gradient_checkpointing_interval" + ): + logger.info("Setting gradient checkpointing interval for transformer..") + transformer.set_gradient_checkpointing_interval( + int(args.gradient_checkpointing_interval) + ) + if unet is not None and hasattr(unet, "set_gradient_checkpointing_interval"): + logger.info("Checking gradient checkpointing interval for U-Net..") + unet.set_gradient_checkpointing_interval( + int(args.gradient_checkpointing_interval) + ) + return unet, transformer diff --git a/helpers/training/trainer.py b/helpers/training/trainer.py index 8c641dd6..8131ee5c 100644 --- a/helpers/training/trainer.py +++ b/helpers/training/trainer.py @@ 
-466,7 +466,7 @@ def init_vae(self, move_to_accelerator: bool = True): "variant": self.config.variant, } if StateTracker.get_args().model_family == "sana": - from diffusers import DCAE as AutoencoderClass + from diffusers import AutoencoderDC as AutoencoderClass else: from diffusers import AutoencoderKL as AutoencoderClass @@ -780,11 +780,7 @@ def init_precision( return if not self.config.disable_accelerator and self.config.is_quantized: - if self.config.model_family == "sana": - # sana hurts, sana pain. sana is a special case. - self.config.base_weight_dtype = torch.float16 - self.config.enable_adamw_bf16 = False - elif self.config.base_model_default_dtype == "fp32": + if self.config.base_model_default_dtype == "fp32": self.config.base_weight_dtype = torch.float32 self.config.enable_adamw_bf16 = False elif self.config.base_model_default_dtype == "bf16": @@ -2173,7 +2169,6 @@ def model_predict( ), encoder_attention_mask=batch["encoder_attention_mask"], timestep=timesteps, - added_cond_kwargs={"resolution": None, "aspect_ratio": None}, return_dict=False, )[0] elif self.config.model_family == "pixart_sigma": @@ -2758,6 +2753,9 @@ def train(self): self._get_trainable_parameters(), self.config.max_grad_norm, ) + elif self.config.use_deepspeed_optimizer: + # deepspeed can only do norm clipping (internally) + pass elif self.config.grad_clip_method == "value": self.grad_norm = self._max_grad_value() self.accelerator.clip_grad_value_( diff --git a/helpers/training/validation.py b/helpers/training/validation.py index d5ad762c..8055e452 100644 --- a/helpers/training/validation.py +++ b/helpers/training/validation.py @@ -518,7 +518,7 @@ def init_vae(self): f"Was the VAE loaded? {precached_vae if precached_vae is None else 'Yes'}" ) if self.args.model_family == "sana": - from diffusers import DCAE as AutoencoderClass + from diffusers import AutoencoderDC as AutoencoderClass else: from diffusers import AutoencoderKL as AutoencoderClass self.vae = precached_vae @@ -1317,6 +1317,7 @@ def validate_prompt( "kolors", "flux", "sd3", + "sana", ]: extra_validation_kwargs["guidance_rescale"] = ( self.args.validation_guidance_rescale diff --git a/install/apple/poetry.lock b/install/apple/poetry.lock index c09f3fc5..5bf30d99 100644 --- a/install/apple/poetry.lock +++ b/install/apple/poetry.lock @@ -575,7 +575,7 @@ training = ["Jinja2", "accelerate (>=0.31.0)", "datasets", "peft (>=0.6.0)", "pr type = "git" url = "https://github.com/lawrence-cj/diffusers" reference = "Sana" -resolved_reference = "bd13e36902d0be89aa3111bab5f7fb5c1dbb6646" +resolved_reference = "b4af50d67f83a893420496ad1382186df8e91688" [[package]] name = "dill" diff --git a/install/rocm/poetry.lock b/install/rocm/poetry.lock index fa36c91f..f448796d 100644 --- a/install/rocm/poetry.lock +++ b/install/rocm/poetry.lock @@ -549,19 +549,17 @@ triton = ["triton (==2.1.0)"] [[package]] name = "diffusers" -version = "0.31.0" +version = "0.32.0.dev0" description = "State-of-the-art diffusion in PyTorch and JAX." 
optional = false python-versions = ">=3.8.0" -files = [ - {file = "diffusers-0.31.0-py3-none-any.whl", hash = "sha256:cbc498ae63f4abfc7c3a07649cdcbee229ef2f9a9a1f0d19c9bbaf22f8d30c1f"}, - {file = "diffusers-0.31.0.tar.gz", hash = "sha256:b1d01a73e45d43a0630c299173915dddd69fc50f2ae8f2ab5de4fd245eaed72f"}, -] +files = [] +develop = false [package.dependencies] filelock = "*" huggingface-hub = ">=0.23.2" -importlib-metadata = "*" +importlib_metadata = "*" numpy = "*" Pillow = "*" regex = "!=2019.12.17" @@ -569,14 +567,20 @@ requests = "*" safetensors = ">=0.3.1" [package.extras] -dev = ["GitPython (<3.1.19)", "Jinja2", "accelerate (>=0.31.0)", "compel (==0.1.8)", "datasets", "flax (>=0.4.1)", "hf-doc-builder (>=0.3.0)", "invisible-watermark (>=0.2.0)", "isort (>=5.5.4)", "jax (>=0.4.1)", "jaxlib (>=0.4.1)", "k-diffusion (>=0.0.12)", "librosa", "parameterized", "peft (>=0.6.0)", "protobuf (>=3.20.3,<4)", "pytest", "pytest-timeout", "pytest-xdist", "requests-mock (==1.10.0)", "ruff (==0.1.5)", "safetensors (>=0.3.1)", "scipy", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "torch (>=1.4,<2.5.0)", "torchvision", "transformers (>=4.41.2)", "urllib3 (<=2.0.0)"] +dev = ["GitPython (<3.1.19)", "Jinja2", "Jinja2", "accelerate (>=0.31.0)", "accelerate (>=0.31.0)", "compel (==0.1.8)", "datasets", "datasets", "flax (>=0.4.1)", "hf-doc-builder (>=0.3.0)", "hf-doc-builder (>=0.3.0)", "invisible-watermark (>=0.2.0)", "isort (>=5.5.4)", "jax (>=0.4.1)", "jaxlib (>=0.4.1)", "k-diffusion (>=0.0.12)", "librosa", "parameterized", "peft (>=0.6.0)", "protobuf (>=3.20.3,<4)", "pytest", "pytest-timeout", "pytest-xdist", "requests-mock (==1.10.0)", "ruff (==0.1.5)", "safetensors (>=0.3.1)", "scipy", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "torch (>=1.4)", "torchvision", "transformers (>=4.41.2)", "urllib3 (<=2.0.0)"] docs = ["hf-doc-builder (>=0.3.0)"] flax = ["flax (>=0.4.1)", "jax (>=0.4.1)", "jaxlib (>=0.4.1)"] quality = ["hf-doc-builder (>=0.3.0)", "isort (>=5.5.4)", "ruff (==0.1.5)", "urllib3 (<=2.0.0)"] test = ["GitPython (<3.1.19)", "Jinja2", "compel (==0.1.8)", "datasets", "invisible-watermark (>=0.2.0)", "k-diffusion (>=0.0.12)", "librosa", "parameterized", "pytest", "pytest-timeout", "pytest-xdist", "requests-mock (==1.10.0)", "safetensors (>=0.3.1)", "scipy", "sentencepiece (>=0.1.91,!=0.1.92)", "torchvision", "transformers (>=4.41.2)"] -torch = ["accelerate (>=0.31.0)", "torch (>=1.4,<2.5.0)"] +torch = ["accelerate (>=0.31.0)", "torch (>=1.4)"] training = ["Jinja2", "accelerate (>=0.31.0)", "datasets", "peft (>=0.6.0)", "protobuf (>=3.20.3,<4)", "tensorboard"] +[package.source] +type = "git" +url = "https://github.com/lawrence-cj/diffusers" +reference = "Sana" +resolved_reference = "b4af50d67f83a893420496ad1382186df8e91688" + [[package]] name = "dill" version = "0.3.8" @@ -3948,4 +3952,4 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.13" -content-hash = "515537f9815456e8d5dc224c21f81b61cdf574a3cdeb74d75aa2a13ca1d9edcd" +content-hash = "9dc595f310d5a8ee64f6d6421d0d3070bd8440e54e7b5b1630b5eab7b7f3672e" diff --git a/install/rocm/pyproject.toml b/install/rocm/pyproject.toml index e92b38af..d7707973 100644 --- a/install/rocm/pyproject.toml +++ b/install/rocm/pyproject.toml @@ -22,7 +22,7 @@ colorama = "^0.4.6" compel = "^2" datasets = "^3.0.0" deepspeed = "^0.15.1" -diffusers = "^0.31.0" +diffusers = {git = "https://github.com/lawrence-cj/diffusers", rev = "Sana"} iterutils = "^0.1.6" numpy = 
"1.26" open-clip-torch = "^2.26.1" diff --git a/poetry.lock b/poetry.lock index 7b4dc9ab..7a266c75 100644 --- a/poetry.lock +++ b/poetry.lock @@ -720,7 +720,7 @@ training = ["Jinja2", "accelerate (>=0.31.0)", "datasets", "peft (>=0.6.0)", "pr type = "git" url = "https://github.com/lawrence-cj/diffusers" reference = "Sana" -resolved_reference = "bd13e36902d0be89aa3111bab5f7fb5c1dbb6646" +resolved_reference = "d3312ccec73ff753338792a3e8dc8fc39168ce49" [[package]] name = "dill" From e13575113d483ef66f838e66aae4169db0527cc3 Mon Sep 17 00:00:00 2001 From: bghira Date: Fri, 13 Dec 2024 06:36:12 -0600 Subject: [PATCH 2/4] use terminusresearch source for sana 1.6b weights by default so that we use bf16 without variant shenanagins --- configure.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configure.py b/configure.py index e11f610c..61c96875 100644 --- a/configure.py +++ b/configure.py @@ -36,7 +36,7 @@ "terminus": "ptx0/terminus-xl-velocity-v2", "sd3": "stabilityai/stable-diffusion-3.5-large", "legacy": "stabilityai/stable-diffusion-2-1-base", - "sana": "Efficient-Large-Model/Sana_1600M_1024px_diffusers", + "sana": "terminusresearch/sana-1.6b-1024px", } default_cfg = { From e3791b924481701887f72d9df9fa41b260d0d343 Mon Sep 17 00:00:00 2001 From: bghira Date: Fri, 13 Dec 2024 06:52:29 -0600 Subject: [PATCH 3/4] sana: final pipeline export for full tune --- helpers/training/trainer.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/helpers/training/trainer.py b/helpers/training/trainer.py index 8131ee5c..577ccd82 100644 --- a/helpers/training/trainer.py +++ b/helpers/training/trainer.py @@ -3338,6 +3338,27 @@ def train(self): scheduler=None, ) + elif self.config.model_family == "sana": + from diffusers import SanaPipeline + + self.pipeline = SanaPipeline.from_pretrained( + self.config.pretrained_model_name_or_path, + text_encoder=self.text_encoder_1 + or ( + self.text_encoder_cls_1.from_pretrained( + self.config.pretrained_model_name_or_path, + subfolder="text_encoder", + revision=self.config.revision, + variant=self.config.variant, + ) + if self.config.save_text_encoder + else None + ), + tokenizer=self.tokenizer_1, + vae=self.vae, + transformer=self.transformer, + ) + else: sdxl_pipeline_cls = StableDiffusionXLPipeline if self.config.model_family == "kolors": From 82aacec453244034285b40bb120f67e49fc0a694 Mon Sep 17 00:00:00 2001 From: bghira Date: Fri, 13 Dec 2024 07:01:53 -0600 Subject: [PATCH 4/4] sana: update quickstart to mention bf16 weights --- documentation/quickstart/SANA.md | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/documentation/quickstart/SANA.md b/documentation/quickstart/SANA.md index 964f4d76..2dd355b9 100644 --- a/documentation/quickstart/SANA.md +++ b/documentation/quickstart/SANA.md @@ -12,16 +12,19 @@ Sana is very lightweight and might not even need full gradient checkpointing ena Sana is a strange architecture relative to other models that are trainable by SimpleTuner; -- It requires FP16 training, unlike other models, this **will not work** with BF16 -- It will not be happy with model quantisation due to the need to run in FP16; most quantisation methods require the use of BF16 - - NF4 looks like it might work, but hasn't been fully tested -- SageAttention does not work with Sana due to the shapes inside the model +- Initially, unlike other models, Sana required fp16 training and would crash out with bf16 + - Model authors at NVIDIA were gracious enough to follow-up with bf16-compatible 
weights for fine-tuning
+- Quantisation might be more sensitive for this model family due to the bf16/fp16 issues described above
+- SageAttention does not work with Sana (yet) because its head_dim shape is currently unsupported
 - The loss value when training Sana is very high, and it might need a much lower learning rate than other models (eg. `1e-5` or thereabouts)
+- Training might hit NaN values, and it's not yet clear why this happens
 
 Gradient checkpointing can free VRAM, but slows down training. A chart of test results from a 4090 with 5800X3D:
 
 ![image](https://github.com/user-attachments/assets/310bf099-a077-4378-acf4-f60b4b82fdc4)
 
+SimpleTuner's Sana modeling code allows specifying `--gradient_checkpointing_interval` to checkpoint only every _n_ blocks and attain the results seen in the chart above.
+
 ### Prerequisites
 
 Make sure that you have python installed; SimpleTuner does well with 3.10 or 3.11. **Python 3.12 should not be used**.
@@ -117,8 +120,9 @@ There, you will possibly need to modify the following variables:
 
 - `validation_num_inference_steps` - Use somewhere around 50 for the best quality, though you can accept less if you're happy with the results.
 - `use_ema` - setting this to `true` will greatly help obtain a more smoothed result alongside your main trained checkpoint.
-- `optimizer` - Since Sana requires fp16, some optimisers like `adamw_bf16` will not work with it. You can use `optimi-lion`, `optimi-stableadamw` or others you are familiar with instead.
-- `mixed_precision` - This gets overridden to `no` for you anyway, since we rely on fp16 training.
+- `optimizer` - You can use any optimiser you are familiar with, but we will use `optimi-adamw` for this example.
+- `mixed_precision` - It's recommended to set this to `bf16` for the most efficient training configuration, or `no` (which will consume more memory and run more slowly).
+  - A value of `fp16` is not recommended here, but may be required for certain Sana finetunes; enabling it introduces other issues.
 - `gradient_checkpointing` - Disabling this will go the fastest, but limits your batch sizes. It is required to enable this to get the lowest VRAM usage.
 - `gradient_checkpointing_interval` - If `gradient_checkpointing` feels like overkill on your GPU, you could set this to a value of 2 or higher to only checkpoint every _n_ blocks. A value of 2 would checkpoint half of the blocks, and 3 would be one-third.
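
To make the interval semantics concrete, below is a minimal, hypothetical sketch of the idea behind `--gradient_checkpointing_interval`: only every _n_-th block is wrapped in `torch.utils.checkpoint`, trading extra recomputation in the backward pass for lower VRAM usage. The function and block names are invented for illustration and do not mirror SimpleTuner's or diffusers' internals; the patch above wires the actual behaviour through `SanaTransformer2DModel.set_gradient_checkpointing_interval()`, which `helpers/training/diffusion_model.py` calls with `--gradient_checkpointing_interval`.

```python
import torch
import torch.nn as nn
import torch.utils.checkpoint


def run_blocks_with_interval_checkpointing(blocks, hidden_states, interval=2):
    """Hypothetical helper: checkpoint only every `interval`-th block."""
    for index, block in enumerate(blocks):
        if interval and index % interval == 0:
            # Recompute this block's activations during backward to save VRAM.
            hidden_states = torch.utils.checkpoint.checkpoint(
                block, hidden_states, use_reentrant=False
            )
        else:
            hidden_states = block(hidden_states)
    return hidden_states


if __name__ == "__main__":
    blocks = nn.ModuleList([nn.Linear(16, 16) for _ in range(6)])
    x = torch.randn(2, 16, requires_grad=True)
    # interval=2 checkpoints half of the blocks (indices 0, 2, 4), matching the
    # "a value of 2 would checkpoint half of the blocks" description above.
    out = run_blocks_with_interval_checkpointing(blocks, x, interval=2)
    out.sum().backward()
```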