diff --git a/configure.py b/configure.py index b5c26d28..26620662 100644 --- a/configure.py +++ b/configure.py @@ -433,13 +433,29 @@ def configure_env(): env_contents["--attention_mechanism"] = "diffusers" use_sageattention = ( prompt_user( - "Would you like to use SageAttention? This is an experimental option that can greatly speed up training. (y/[n])", + "Would you like to use SageAttention for image validation generation? (y/[n])", "n", ).lower() == "y" ) if use_sageattention: env_contents["--attention_mechanism"] = "sageattention" + env_contents["--sageattention_usage"] = "inference" + use_sageattention_training = ( + prompt_user( + ( + "Would you like to use SageAttention to cover the forward and backward pass during training?" + " This has the undesirable consequence of leaving the attention layers untrained," + " as SageAttention lacks the capability to fully track gradients through quantisation." + " If you are not training the attention layers for some reason, this may not matter and" + " you can safely enable this. For all other use-cases, reconsideration and caution are warranted." + ), + "n", + ).lower() + == "y" + ) + if use_sageattention_training: + env_contents["--sageattention_usage"] = "both" # properly disable wandb/tensorboard/comet_ml etc by default report_to_str = "none" @@ -527,6 +543,18 @@ def configure_env(): ) ) env_contents["--gradient_checkpointing"] = "true" + gradient_checkpointing_interval = prompt_user( + "Would you like to configure a gradient checkpointing interval? A value larger than 1 will increase VRAM usage but speed up training by skipping checkpoint creation every Nth layer, and a zero will disable this feature.", + 0, + ) + try: + if int(gradient_checkpointing_interval) > 1: + env_contents["--gradient_checkpointing_interval"] = int( + gradient_checkpointing_interval + ) + except: + print("Could not parse gradient checkpointing interval. Not enabling.") + pass env_contents["--caption_dropout_probability"] = float( prompt_user( diff --git a/helpers/configuration/cmd_args.py b/helpers/configuration/cmd_args.py index c90923e5..457e1118 100644 --- a/helpers/configuration/cmd_args.py +++ b/helpers/configuration/cmd_args.py @@ -1054,6 +1054,15 @@ def get_argument_parser(): action="store_true", help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", ) + parser.add_argument( + "--gradient_checkpointing_interval", + default=None, + type=int, + help=( + "Some models (Flux, SDXL, SD1.x/2.x) can have their gradient checkpointing limited to every nth block." + " This can speed up training but will use more memory with larger intervals." + ), + ) parser.add_argument( "--learning_rate", type=float, @@ -1744,7 +1753,7 @@ def get_argument_parser(): parser.add_argument( "--enable_xformers_memory_efficient_attention", action="store_true", - help="Whether or not to use xformers. Deprecated and slated for future removal.", + help="Whether or not to use xformers. Deprecated and slated for future removal. Use --attention_mechanism.", ) parser.add_argument( "--set_grads_to_none", diff --git a/helpers/models/flux/transformer.py b/helpers/models/flux/transformer.py index 7a9c80f2..77097648 100644 --- a/helpers/models/flux/transformer.py +++ b/helpers/models/flux/transformer.py @@ -489,11 +489,16 @@ def __init__( ) self.gradient_checkpointing = False + # added for users to disable checkpointing every nth step + self.gradient_checkpointing_interval = None def _set_gradient_checkpointing(self, module, value=False): if hasattr(module, "gradient_checkpointing"): module.gradient_checkpointing = value + def set_gradient_checkpointing_interval(self, value: int): + self.gradient_checkpointing_interval = value + def forward( self, hidden_states: torch.Tensor, @@ -574,7 +579,14 @@ def forward( image_rotary_emb = self.pos_embed(ids) for index_block, block in enumerate(self.transformer_blocks): - if self.training and self.gradient_checkpointing: + if ( + self.training + and self.gradient_checkpointing + and ( + self.gradient_checkpointing_interval is None + or index_block % self.gradient_checkpointing_interval == 0 + ) + ): def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): @@ -614,7 +626,14 @@ def custom_forward(*inputs): hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) for index_block, block in enumerate(self.single_transformer_blocks): - if self.training and self.gradient_checkpointing: + if ( + self.training + and self.gradient_checkpointing + or ( + self.gradient_checkpointing_interval is not None + and index_block % self.gradient_checkpointing_interval == 0 + ) + ): def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): diff --git a/helpers/training/default_settings/safety_check.py b/helpers/training/default_settings/safety_check.py index 36f321c5..e6522a40 100644 --- a/helpers/training/default_settings/safety_check.py +++ b/helpers/training/default_settings/safety_check.py @@ -133,3 +133,21 @@ def safety_check(args, accelerator): f"{args.base_model_precision} is not supported with SageAttention. Please select from int8 or fp8, or, disable quantisation to use SageAttention." ) sys.exit(1) + + gradient_checkpointing_interval_supported_models = [ + "flux", + "sdxl", + ] + if args.gradient_checkpointing_interval is not None: + if ( + args.model_family.lower() + not in gradient_checkpointing_interval_supported_models + ): + logger.error( + f"Gradient checkpointing is not supported with {args.model_family} models. Please disable --gradient_checkpointing_interval by setting it to None, or remove it from your configuration. Currently supported models: {gradient_checkpointing_interval_supported_models}" + ) + sys.exit(1) + if args.gradient_checkpointing_interval == 0: + raise ValueError( + "Gradient checkpointing interval must be greater than 0. Please set it to a positive integer." + ) diff --git a/helpers/training/diffusion_model.py b/helpers/training/diffusion_model.py index 5ac0c207..290858bc 100644 --- a/helpers/training/diffusion_model.py +++ b/helpers/training/diffusion_model.py @@ -52,7 +52,9 @@ def load_diffusion_model(args, weight_dtype): elif ( args.model_family.lower() == "flux" and not args.flux_attention_masked_training ): - from diffusers.models import FluxTransformer2DModel + from helpers.models.flux.transformer import ( + FluxTransformer2DModelWithMasking as FluxTransformer2DModel, + ) import torch if torch.cuda.is_available(): @@ -92,6 +94,10 @@ def load_diffusion_model(args, weight_dtype): subfolder=determine_subfolder(args.pretrained_transformer_subfolder), **pretrained_load_args, ) + if args.gradient_checkpointing_interval is not None: + transformer.set_gradient_checkpointing_interval( + int(args.gradient_checkpointing_interval) + ) elif args.model_family.lower() == "flux" and args.flux_attention_masked_training: from helpers.models.flux.transformer import ( FluxTransformer2DModelWithMasking, @@ -103,6 +109,10 @@ def load_diffusion_model(args, weight_dtype): subfolder=determine_subfolder(args.pretrained_transformer_subfolder), **pretrained_load_args, ) + if args.gradient_checkpointing_interval is not None: + transformer.set_gradient_checkpointing_interval( + int(args.gradient_checkpointing_interval) + ) elif args.model_family == "pixart_sigma": from diffusers.models import PixArtTransformer2DModel @@ -145,5 +155,22 @@ def load_diffusion_model(args, weight_dtype): subfolder=determine_subfolder(args.pretrained_unet_subfolder), **pretrained_load_args, ) + if ( + args.gradient_checkpointing_interval is not None + and args.gradient_checkpointing_interval > 0 + ): + logger.warning( + "Using experimental gradient checkpointing monkeypatch for a checkpoint interval of {}".format( + args.gradient_checkpointing_interval + ) + ) + # monkey-patch the gradient checkpointing function for pytorch to run every nth call only. + # definitely one of the more awful things I've ever done while programming, but it's easier than + # modifying every one of the unet blocks' forward calls in Diffusers to make it work properly. + from helpers.training.gradient_checkpointing_interval import ( + set_checkpoint_interval, + ) + + set_checkpoint_interval(int(args.gradient_checkpointing_interval)) return unet, transformer diff --git a/helpers/training/gradient_checkpointing_interval.py b/helpers/training/gradient_checkpointing_interval.py new file mode 100644 index 00000000..18026044 --- /dev/null +++ b/helpers/training/gradient_checkpointing_interval.py @@ -0,0 +1,42 @@ +import torch +from torch.utils.checkpoint import checkpoint as original_checkpoint + + +# Global variables to keep track of the checkpointing state +_checkpoint_call_count = 0 +_checkpoint_interval = 4 # You can set this to any interval you prefer + + +def reset_checkpoint_counter(): + """Resets the checkpoint call counter. Call this at the beginning of the forward pass.""" + global _checkpoint_call_count + _checkpoint_call_count = 0 + + +def set_checkpoint_interval(n): + """Sets the interval at which checkpointing is skipped.""" + global _checkpoint_interval + _checkpoint_interval = n + + +def checkpoint_wrapper(function, *args, use_reentrant=True, **kwargs): + """Wrapper function for torch.utils.checkpoint.checkpoint.""" + global _checkpoint_call_count, _checkpoint_interval + _checkpoint_call_count += 1 + + if ( + _checkpoint_interval > 0 + and (_checkpoint_call_count % _checkpoint_interval) == 0 + ): + # Use the original checkpoint function + return original_checkpoint( + function, *args, use_reentrant=use_reentrant, **kwargs + ) + else: + # Skip checkpointing: execute the function directly + # Do not pass 'use_reentrant' to the function + return function(*args, **kwargs) + + +# Monkeypatch torch.utils.checkpoint.checkpoint +torch.utils.checkpoint.checkpoint = checkpoint_wrapper diff --git a/helpers/training/trainer.py b/helpers/training/trainer.py index dfe7593d..f6a15696 100644 --- a/helpers/training/trainer.py +++ b/helpers/training/trainer.py @@ -997,8 +997,11 @@ def init_post_load_freeze(self): self.transformer = apply_bitfit_freezing( unwrap_model(self.accelerator, self.transformer), self.config ) + self.enable_gradient_checkpointing() + def enable_gradient_checkpointing(self): if self.config.gradient_checkpointing: + logger.info("Enabling gradient checkpointing.") if self.unet is not None: unwrap_model( self.accelerator, self.unet @@ -1022,6 +1025,32 @@ def init_post_load_freeze(self): self.accelerator, self.text_encoder_2 ).gradient_checkpointing_enable() + def disable_gradient_checkpointing(self): + if self.config.gradient_checkpointing: + logger.info("Disabling gradient checkpointing.") + if self.unet is not None: + unwrap_model( + self.accelerator, self.unet + ).disable_gradient_checkpointing() + if self.transformer is not None and self.config.model_family != "smoldit": + unwrap_model( + self.accelerator, self.transformer + ).disable_gradient_checkpointing() + if self.config.controlnet: + unwrap_model( + self.accelerator, self.controlnet + ).disable_gradient_checkpointing() + if ( + hasattr(self.config, "train_text_encoder") + and self.config.train_text_encoder + ): + unwrap_model( + self.accelerator, self.text_encoder_1 + ).gradient_checkpointing_disable() + unwrap_model( + self.accelerator, self.text_encoder_2 + ).gradient_checkpointing_disable() + def _get_trainable_parameters(self): # Return just a list of the currently trainable parameters. if self.config.model_type == "lora": @@ -2188,8 +2217,10 @@ def train(self): # normal run-of-the-mill validation on startup. if self.validation is not None: self.enable_sageattention_inference() + self.disable_gradient_checkpointing() self.validation.run_validations(validation_type="base_model", step=0) self.disable_sageattention_inference() + self.enable_gradient_checkpointing() self.mark_optimizer_train() @@ -2913,11 +2944,13 @@ def train(self): if self.validation is not None: if self.validation.would_validate(): self.enable_sageattention_inference() + self.disable_gradient_checkpointing() self.validation.run_validations( validation_type="intermediary", step=step ) if self.validation.would_validate(): self.disable_sageattention_inference() + self.enable_gradient_checkpointing() self.mark_optimizer_train() if ( self.config.push_to_hub @@ -2967,6 +3000,7 @@ def train(self): self.mark_optimizer_eval() if self.validation is not None: self.enable_sageattention_inference() + self.disable_gradient_checkpointing() validation_images = self.validation.run_validations( validation_type="final", step=self.state["global_step"], diff --git a/tests/test_trainer.py b/tests/test_trainer.py index c54789d6..3a036762 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -140,6 +140,7 @@ def test_stats_memory_used_none( flux_schedule_shift=3, flux_schedule_auto_shift=False, validation_guidance_skip_layers=None, + gradient_checkpointing_interval=None, ), ) def test_misc_init(