Merge pull request #1184 from bghira/feature/gradient-checkpointing-speedup

gradient checkpointing speed-up
bghira authored Dec 3, 2024
2 parents 9cd974e + 96d477e commit cd0644d
Showing 8 changed files with 183 additions and 5 deletions.
30 changes: 29 additions & 1 deletion configure.py
@@ -433,13 +433,29 @@ def configure_env():
    env_contents["--attention_mechanism"] = "diffusers"
    use_sageattention = (
        prompt_user(
-            "Would you like to use SageAttention? This is an experimental option that can greatly speed up training. (y/[n])",
            "Would you like to use SageAttention for image validation generation? (y/[n])",
            "n",
        ).lower()
        == "y"
    )
    if use_sageattention:
        env_contents["--attention_mechanism"] = "sageattention"
        env_contents["--sageattention_usage"] = "inference"
        use_sageattention_training = (
            prompt_user(
                (
                    "Would you like to use SageAttention to cover the forward and backward pass during training?"
                    " This has the undesirable consequence of leaving the attention layers untrained,"
                    " as SageAttention lacks the capability to fully track gradients through quantisation."
                    " If you are not training the attention layers for some reason, this may not matter and"
                    " you can safely enable this. For all other use-cases, reconsideration and caution are warranted."
                ),
                "n",
            ).lower()
            == "y"
        )
        if use_sageattention_training:
            env_contents["--sageattention_usage"] = "both"

    # properly disable wandb/tensorboard/comet_ml etc by default
    report_to_str = "none"
@@ -527,6 +543,18 @@ def configure_env():
)
)
env_contents["--gradient_checkpointing"] = "true"
gradient_checkpointing_interval = prompt_user(
"Would you like to configure a gradient checkpointing interval? A value larger than 1 will increase VRAM usage but speed up training by skipping checkpoint creation every Nth layer, and a zero will disable this feature.",
0,
)
try:
if int(gradient_checkpointing_interval) > 1:
env_contents["--gradient_checkpointing_interval"] = int(
gradient_checkpointing_interval
)
except:
print("Could not parse gradient checkpointing interval. Not enabling.")
pass

env_contents["--caption_dropout_probability"] = float(
prompt_user(
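For illustration only (not part of this diff): a minimal sketch of the entries the updated configure.py could end up writing for a user who enables SageAttention for validation and chooses a checkpointing interval of 3. The literal dict below is an assumption about the shape of env_contents; the keys and values themselves come from the prompts above.

# Hypothetical snapshot of the relevant configuration entries after the prompts above.
env_contents = {
    "--attention_mechanism": "sageattention",  # set when the SageAttention prompt is answered "y"
    "--sageattention_usage": "inference",      # becomes "both" only if training coverage is also enabled
    "--gradient_checkpointing": "true",
    "--gradient_checkpointing_interval": 3,    # only stored when the parsed value is greater than 1
}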
11 changes: 10 additions & 1 deletion helpers/configuration/cmd_args.py
@@ -1054,6 +1054,15 @@ def get_argument_parser():
action="store_true",
help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
)
parser.add_argument(
"--gradient_checkpointing_interval",
default=None,
type=int,
help=(
"Some models (Flux, SDXL, SD1.x/2.x) can have their gradient checkpointing limited to every nth block."
" This can speed up training but will use more memory with larger intervals."
),
)
parser.add_argument(
"--learning_rate",
type=float,
@@ -1744,7 +1753,7 @@ def get_argument_parser():
    parser.add_argument(
        "--enable_xformers_memory_efficient_attention",
        action="store_true",
-        help="Whether or not to use xformers. Deprecated and slated for future removal.",
        help="Whether or not to use xformers. Deprecated and slated for future removal. Use --attention_mechanism.",
    )
    parser.add_argument(
        "--set_grads_to_none",
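As a quick sanity check (not part of this diff), a stripped-down parser showing how the new flag parses from the command line; the two-argument parser below is an assumption, the real option lives in get_argument_parser():

import argparse

# Hypothetical minimal parser mirroring the new option.
parser = argparse.ArgumentParser()
parser.add_argument("--gradient_checkpointing", action="store_true")
parser.add_argument("--gradient_checkpointing_interval", default=None, type=int)

args = parser.parse_args(
    ["--gradient_checkpointing", "--gradient_checkpointing_interval", "3"]
)
assert args.gradient_checkpointing is True
assert args.gradient_checkpointing_interval == 3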
23 changes: 21 additions & 2 deletions helpers/models/flux/transformer.py
@@ -489,11 +489,16 @@ def __init__(
        )

        self.gradient_checkpointing = False
        # allows users to checkpoint only every nth block
        self.gradient_checkpointing_interval = None

    def _set_gradient_checkpointing(self, module, value=False):
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

    def set_gradient_checkpointing_interval(self, value: int):
        self.gradient_checkpointing_interval = value

    def forward(
        self,
        hidden_states: torch.Tensor,
@@ -574,7 +579,14 @@ def forward(
        image_rotary_emb = self.pos_embed(ids)

        for index_block, block in enumerate(self.transformer_blocks):
-            if self.training and self.gradient_checkpointing:
            if (
                self.training
                and self.gradient_checkpointing
                and (
                    self.gradient_checkpointing_interval is None
                    or index_block % self.gradient_checkpointing_interval == 0
                )
            ):

                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
@@ -614,7 +626,14 @@ def custom_forward(*inputs):
        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

        for index_block, block in enumerate(self.single_transformer_blocks):
-            if self.training and self.gradient_checkpointing:
            if (
                self.training
                and self.gradient_checkpointing
                and (
                    self.gradient_checkpointing_interval is None
                    or index_block % self.gradient_checkpointing_interval == 0
                )
            ):

                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
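To make the per-block decision concrete, here is a standalone sketch (not part of this diff) that reproduces the modulo test from the loops above; the block count of 19 is just an example:

# Hypothetical illustration: which transformer block indices get checkpointed for a given interval.
def blocks_to_checkpoint(num_blocks, interval):
    return [
        index_block
        for index_block in range(num_blocks)
        if interval is None or index_block % interval == 0
    ]

print(blocks_to_checkpoint(19, None))  # interval unset: every block is checkpointed
print(blocks_to_checkpoint(19, 4))     # [0, 4, 8, 12, 16] -> fewer checkpoints, less recomputation in backward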
18 changes: 18 additions & 0 deletions helpers/training/default_settings/safety_check.py
@@ -133,3 +133,21 @@ def safety_check(args, accelerator):
f"{args.base_model_precision} is not supported with SageAttention. Please select from int8 or fp8, or, disable quantisation to use SageAttention."
)
sys.exit(1)

    gradient_checkpointing_interval_supported_models = [
        "flux",
        "sdxl",
    ]
    if args.gradient_checkpointing_interval is not None:
        if (
            args.model_family.lower()
            not in gradient_checkpointing_interval_supported_models
        ):
            logger.error(
                f"Gradient checkpointing intervals are not supported with {args.model_family} models. Please disable --gradient_checkpointing_interval by removing it from your configuration. Currently supported models: {gradient_checkpointing_interval_supported_models}"
            )
            sys.exit(1)
        if args.gradient_checkpointing_interval == 0:
            raise ValueError(
                "Gradient checkpointing interval must be greater than 0. Please set it to a positive integer."
            )
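A minimal sketch (not part of this diff) of how the new guard behaves, using a stand-in namespace in place of the parsed arguments:

from types import SimpleNamespace

# Hypothetical stand-in for the parsed arguments; the real check runs inside safety_check().
args = SimpleNamespace(model_family="pixart_sigma", gradient_checkpointing_interval=2)

supported = ["flux", "sdxl"]
if args.gradient_checkpointing_interval is not None:
    if args.model_family.lower() not in supported:
        # safety_check() would log an error and sys.exit(1) here.
        print(f"--gradient_checkpointing_interval is not supported for {args.model_family}")
    elif args.gradient_checkpointing_interval == 0:
        raise ValueError("Gradient checkpointing interval must be greater than 0.")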
29 changes: 28 additions & 1 deletion helpers/training/diffusion_model.py
@@ -52,7 +52,9 @@ def load_diffusion_model(args, weight_dtype):
    elif (
        args.model_family.lower() == "flux" and not args.flux_attention_masked_training
    ):
-        from diffusers.models import FluxTransformer2DModel
        from helpers.models.flux.transformer import (
            FluxTransformer2DModelWithMasking as FluxTransformer2DModel,
        )
        import torch

        if torch.cuda.is_available():
@@ -92,6 +94,10 @@ def load_diffusion_model(args, weight_dtype):
            subfolder=determine_subfolder(args.pretrained_transformer_subfolder),
            **pretrained_load_args,
        )
        if args.gradient_checkpointing_interval is not None:
            transformer.set_gradient_checkpointing_interval(
                int(args.gradient_checkpointing_interval)
            )
    elif args.model_family.lower() == "flux" and args.flux_attention_masked_training:
        from helpers.models.flux.transformer import (
            FluxTransformer2DModelWithMasking,
@@ -103,6 +109,10 @@
            subfolder=determine_subfolder(args.pretrained_transformer_subfolder),
            **pretrained_load_args,
        )
        if args.gradient_checkpointing_interval is not None:
            transformer.set_gradient_checkpointing_interval(
                int(args.gradient_checkpointing_interval)
            )
    elif args.model_family == "pixart_sigma":
        from diffusers.models import PixArtTransformer2DModel

@@ -145,5 +155,22 @@ def load_diffusion_model(args, weight_dtype):
            subfolder=determine_subfolder(args.pretrained_unet_subfolder),
            **pretrained_load_args,
        )
        if (
            args.gradient_checkpointing_interval is not None
            and args.gradient_checkpointing_interval > 0
        ):
            logger.warning(
                "Using experimental gradient checkpointing monkeypatch for a checkpoint interval of {}".format(
                    args.gradient_checkpointing_interval
                )
            )
            # monkey-patch the gradient checkpointing function for pytorch to run every nth call only.
            # definitely one of the more awful things I've ever done while programming, but it's easier than
            # modifying every one of the unet blocks' forward calls in Diffusers to make it work properly.
            from helpers.training.gradient_checkpointing_interval import (
                set_checkpoint_interval,
            )

            set_checkpoint_interval(int(args.gradient_checkpointing_interval))

    return unet, transformer
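Summarising the two wiring paths above as a sketch (not part of this diff): Flux transformers receive the interval directly, while UNet-style models rely on the monkeypatch introduced below. The condensed branching is an assumption made for brevity.

# Hypothetical condensed view of how the interval reaches the model.
if args.gradient_checkpointing_interval is not None:
    if args.model_family.lower() == "flux":
        # Flux: the transformer applies the modulo test itself inside forward().
        transformer.set_gradient_checkpointing_interval(
            int(args.gradient_checkpointing_interval)
        )
    else:
        # UNet models: patch torch.utils.checkpoint.checkpoint globally instead.
        from helpers.training.gradient_checkpointing_interval import set_checkpoint_interval

        set_checkpoint_interval(int(args.gradient_checkpointing_interval))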
42 changes: 42 additions & 0 deletions helpers/training/gradient_checkpointing_interval.py
@@ -0,0 +1,42 @@
import torch
from torch.utils.checkpoint import checkpoint as original_checkpoint


# Global variables to keep track of the checkpointing state
_checkpoint_call_count = 0
_checkpoint_interval = 4  # You can set this to any interval you prefer


def reset_checkpoint_counter():
    """Resets the checkpoint call counter. Call this at the beginning of the forward pass."""
    global _checkpoint_call_count
    _checkpoint_call_count = 0


def set_checkpoint_interval(n):
    """Sets the interval at which checkpointing is applied; all other calls skip checkpointing."""
    global _checkpoint_interval
    _checkpoint_interval = n


def checkpoint_wrapper(function, *args, use_reentrant=True, **kwargs):
    """Wrapper function for torch.utils.checkpoint.checkpoint."""
    global _checkpoint_call_count, _checkpoint_interval
    _checkpoint_call_count += 1

    if (
        _checkpoint_interval > 0
        and (_checkpoint_call_count % _checkpoint_interval) == 0
    ):
        # Use the original checkpoint function
        return original_checkpoint(
            function, *args, use_reentrant=use_reentrant, **kwargs
        )
    else:
        # Skip checkpointing: execute the function directly
        # Do not pass 'use_reentrant' to the function
        return function(*args, **kwargs)


# Monkeypatch torch.utils.checkpoint.checkpoint
torch.utils.checkpoint.checkpoint = checkpoint_wrapper
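A usage sketch (not part of this diff) showing the patched behaviour, assuming the repository is importable; with an interval of 4, only every fourth checkpoint call actually checkpoints and the rest run the wrapped module directly:

import torch

# Importing the helper applies the monkeypatch to torch.utils.checkpoint.checkpoint.
from helpers.training.gradient_checkpointing_interval import (
    reset_checkpoint_counter,
    set_checkpoint_interval,
)

set_checkpoint_interval(4)

layers = torch.nn.ModuleList([torch.nn.Linear(8, 8) for _ in range(12)])
x = torch.randn(2, 8, requires_grad=True)

reset_checkpoint_counter()  # start each forward pass from a clean counter
for layer in layers:
    # Calls 4, 8 and 12 hit the real checkpoint; the other nine execute un-checkpointed.
    x = torch.utils.checkpoint.checkpoint(layer, x, use_reentrant=False)
x.sum().backward()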
34 changes: 34 additions & 0 deletions helpers/training/trainer.py
@@ -997,8 +997,11 @@ def init_post_load_freeze(self):
            self.transformer = apply_bitfit_freezing(
                unwrap_model(self.accelerator, self.transformer), self.config
            )
        self.enable_gradient_checkpointing()

    def enable_gradient_checkpointing(self):
        if self.config.gradient_checkpointing:
            logger.info("Enabling gradient checkpointing.")
            if self.unet is not None:
                unwrap_model(
                    self.accelerator, self.unet
@@ -1022,6 +1025,32 @@ def init_post_load_freeze(self):
                    self.accelerator, self.text_encoder_2
                ).gradient_checkpointing_enable()

    def disable_gradient_checkpointing(self):
        if self.config.gradient_checkpointing:
            logger.info("Disabling gradient checkpointing.")
            if self.unet is not None:
                unwrap_model(
                    self.accelerator, self.unet
                ).disable_gradient_checkpointing()
            if self.transformer is not None and self.config.model_family != "smoldit":
                unwrap_model(
                    self.accelerator, self.transformer
                ).disable_gradient_checkpointing()
            if self.config.controlnet:
                unwrap_model(
                    self.accelerator, self.controlnet
                ).disable_gradient_checkpointing()
            if (
                hasattr(self.config, "train_text_encoder")
                and self.config.train_text_encoder
            ):
                unwrap_model(
                    self.accelerator, self.text_encoder_1
                ).gradient_checkpointing_disable()
                unwrap_model(
                    self.accelerator, self.text_encoder_2
                ).gradient_checkpointing_disable()

    def _get_trainable_parameters(self):
        # Return just a list of the currently trainable parameters.
        if self.config.model_type == "lora":
@@ -2188,8 +2217,10 @@ def train(self):
        # normal run-of-the-mill validation on startup.
        if self.validation is not None:
            self.enable_sageattention_inference()
            self.disable_gradient_checkpointing()
            self.validation.run_validations(validation_type="base_model", step=0)
            self.disable_sageattention_inference()
            self.enable_gradient_checkpointing()

        self.mark_optimizer_train()

@@ -2913,11 +2944,13 @@ def train(self):
                if self.validation is not None:
                    if self.validation.would_validate():
                        self.enable_sageattention_inference()
                        self.disable_gradient_checkpointing()
                    self.validation.run_validations(
                        validation_type="intermediary", step=step
                    )
                    if self.validation.would_validate():
                        self.disable_sageattention_inference()
                        self.enable_gradient_checkpointing()
                self.mark_optimizer_train()
                if (
                    self.config.push_to_hub
@@ -2967,6 +3000,7 @@ def train(self):
        self.mark_optimizer_eval()
        if self.validation is not None:
            self.enable_sageattention_inference()
            self.disable_gradient_checkpointing()
            validation_images = self.validation.run_validations(
                validation_type="final",
                step=self.state["global_step"],
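Finally, a sketch (not part of this diff) of the validation-time pattern the trainer now follows; the helper function and the try/finally wrapping are illustrative additions, while the method names come from the hunks above:

# Hypothetical outline of the toggle sequence around a validation pass.
def run_validation_pass(trainer, step):
    trainer.mark_optimizer_eval()
    trainer.enable_sageattention_inference()   # faster attention for generation only
    trainer.disable_gradient_checkpointing()   # no backward pass during validation, so checkpointing only costs time
    try:
        trainer.validation.run_validations(validation_type="intermediary", step=step)
    finally:
        trainer.disable_sageattention_inference()
        trainer.enable_gradient_checkpointing()  # restore the memory-saving path for training
        trainer.mark_optimizer_train()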
1 change: 1 addition & 0 deletions tests/test_trainer.py
@@ -140,6 +140,7 @@ def test_stats_memory_used_none(
        flux_schedule_shift=3,
        flux_schedule_auto_shift=False,
        validation_guidance_skip_layers=None,
        gradient_checkpointing_interval=None,
    ),
)
def test_misc_init(
