finetrainers 🧪

cogvideox-factory was renamed to finetrainers. If you're looking to train CogVideoX or Mochi with the legacy training scripts, please refer to this README instead. Everything in the training/ directory will eventually be moved to, and supported under, finetrainers.

FineTrainers is a work-in-progress library to support training of video models. The first priority is to support LoRA training for all models in Diffusers, followed eventually by other methods such as ControlNets, Control-LoRAs, distillation, etc.


News

  • 🔥 2024-12-20: Support for LoRA finetuning of Hunyuan Video added! We would like to thank @SHYuanBest for his work on a training script here.
  • 🔥 2024-12-18: Support for LoRA finetuning of LTX Video added!

Quickstart

Clone the repository and install the requirements with pip install -r requirements.txt, then install diffusers from source with pip install git+https://github.com/huggingface/diffusers.

Then download a dataset:

# install `huggingface_hub`
huggingface-cli download \
  --repo-type dataset Wild-Heart/Disney-VideoGeneration-Dataset \
  --local-dir video-dataset-disney
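Before launching training, it can help to confirm that the downloaded dataset is internally consistent. The sketch below is a hypothetical helper (not part of finetrainers). It assumes the layout implied by the CAPTION_COLUMN and VIDEO_COLUMN settings in the training scripts: a prompts.txt with one caption per line and a videos.txt with one video path per line (relative to the dataset root), where line i of each file describes the same sample.

```python
from __future__ import annotations

from pathlib import Path


def _read_nonblank_lines(path: Path) -> list[str]:
    # Ignore blank lines such as a trailing newline; assumes no caption is empty.
    return [line.strip() for line in path.read_text(encoding="utf-8").splitlines() if line.strip()]


def check_dataset(data_root: str, caption_file: str = "prompts.txt",
                  video_file: str = "videos.txt") -> list[tuple[str, Path]]:
    """Pair each caption with its video path and verify that every video exists.

    Hypothetical helper: assumes line i of `caption_file` is the caption for the
    video whose path (relative to `data_root`) is on line i of `video_file`.
    """
    root = Path(data_root)
    captions = _read_nonblank_lines(root / caption_file)
    videos = _read_nonblank_lines(root / video_file)
    if len(captions) != len(videos):
        raise ValueError(
            f"{caption_file} has {len(captions)} entries but {video_file} has {len(videos)}"
        )
    missing = [v for v in videos if not (root / v).is_file()]
    if missing:
        raise FileNotFoundError(f"{len(missing)} listed videos not found, e.g. {missing[0]}")
    return list(zip(captions, (root / v for v in videos)))
```

For example, check_dataset("video-dataset-disney") returns the caption/path pairs, or raises if the two files disagree or a listed video is missing.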

Then launch LoRA fine-tuning. For CogVideoX and Mochi, refer to this and this. For details on the available arguments for the training scripts, see args.md.

Note: PyTorch 2.5.1 or above is recommended for training. Earlier versions can produce completely black videos, OOM errors, or other issues, and are not tested.
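The version recommendation can be turned into a quick pre-flight check. This is a minimal sketch: the version parsing is deliberately simplified (it drops local suffixes such as +cu124 and keeps only the leading digits of each component, so a pre-release like 2.5.1rc1 is treated as 2.5.1), and 2.5.1 is the threshold recommended in this README, not a check that finetrainers itself performs.

```python
from __future__ import annotations

import importlib.util

MIN_TORCH = (2, 5, 1)  # minimum version recommended in this README


def parse_version(version: str) -> tuple[int, ...]:
    """Parse 'X.Y.Z' into (X, Y, Z), ignoring '+local' and non-digit suffixes."""
    core = version.split("+")[0]
    parts = []
    for piece in core.split(".")[:3]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        parts.append(int(digits) if digits else 0)
    while len(parts) < 3:
        parts.append(0)
    return tuple(parts)


def is_supported(version: str) -> bool:
    return parse_version(version) >= MIN_TORCH


if __name__ == "__main__":
    if importlib.util.find_spec("torch") is None:
        print("PyTorch is not installed")
    else:
        import torch

        status = "OK" if is_supported(torch.__version__) else "older than 2.5.1, upgrade recommended"
        print(f"torch {torch.__version__}: {status}")
```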

LTX Video

Training:

#!/bin/bash

# export TORCH_LOGS="+dynamo,recompiles,graph_breaks"
# export TORCHDYNAMO_VERBOSE=1
export WANDB_MODE="offline"
export NCCL_P2P_DISABLE=1
export TORCH_NCCL_ENABLE_MONITORING=0
export FINETRAINERS_LOG_LEVEL=DEBUG

GPU_IDS="0,1"

DATA_ROOT="/path/to/video-dataset-disney"
CAPTION_COLUMN="prompts.txt"
VIDEO_COLUMN="videos.txt"
OUTPUT_DIR="/path/to/output/directory/ltx-video/ltxv_disney"

ID_TOKEN="BW_STYLE"

# Model arguments
model_cmd="--model_name ltx_video \
  --pretrained_model_name_or_path Lightricks/LTX-Video"

# Dataset arguments
dataset_cmd="--data_root $DATA_ROOT \
  --video_column $VIDEO_COLUMN \
  --caption_column $CAPTION_COLUMN \
  --id_token $ID_TOKEN \
  --video_resolution_buckets 49x512x768 \
  --caption_dropout_p 0.05"

# Dataloader arguments
dataloader_cmd="--dataloader_num_workers 0"

# Diffusion arguments
diffusion_cmd="--flow_resolution_shifting"

# Training arguments
training_cmd="--training_type lora \
  --seed 42 \
  --mixed_precision bf16 \
  --batch_size 1 \
  --train_steps 1200 \
  --rank 128 \
  --lora_alpha 128 \
  --target_modules to_q to_k to_v to_out.0 \
  --gradient_accumulation_steps 1 \
  --gradient_checkpointing \
  --checkpointing_steps 500 \
  --checkpointing_limit 2 \
  --enable_slicing \
  --enable_tiling"

# Optimizer arguments
optimizer_cmd="--optimizer adamw \
  --lr 3e-5 \
  --lr_scheduler constant_with_warmup \
  --lr_warmup_steps 100 \
  --lr_num_cycles 1 \
  --beta1 0.9 \
  --beta2 0.95 \
  --weight_decay 1e-4 \
  --epsilon 1e-8 \
  --max_grad_norm 1.0"

# Validation arguments
validation_cmd="--validation_prompts \"$ID_TOKEN A black and white animated scene unfolds with an anthropomorphic goat surrounded by musical notes and symbols, suggesting a playful environment. Mickey Mouse appears, leaning forward in curiosity as the goat remains still. The goat then engages with Mickey, who bends down to converse or react. The dynamics shift as Mickey grabs the goat, potentially in surprise or playfulness, amidst a minimalistic background. The scene captures the evolving relationship between the two characters in a whimsical, animated setting, emphasizing their interactions and emotions.@@@49x512x768:::$ID_TOKEN A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage@@@49x512x768\" \
  --num_validation_videos 1 \
  --validation_steps 100"

# Miscellaneous arguments
miscellaneous_cmd="--tracker_name finetrainers-ltxv \
  --output_dir $OUTPUT_DIR \
  --nccl_timeout 1800 \
  --report_to wandb"

cmd="accelerate launch --config_file accelerate_configs/uncompiled_2.yaml --gpu_ids $GPU_IDS train.py \
  $model_cmd \
  $dataset_cmd \
  $dataloader_cmd \
  $diffusion_cmd \
  $training_cmd \
  $optimizer_cmd \
  $validation_cmd \
  $miscellaneous_cmd"

echo "Running command: $cmd"
eval $cmd
echo -ne "-------------------- Finished executing script --------------------\n\n"
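The --validation_prompts value in the script above packs several prompts into a single string. Judging from the script, ":::" separates entries and each entry ends with "@@@" followed by a FxHxW (frames x height x width) resolution. The parser below is a hypothetical helper, not part of finetrainers, that makes this inferred format explicit; it can be handy for checking a long prompt list before launching a run.

```python
from __future__ import annotations


def parse_validation_prompts(spec: str) -> list[tuple[str, tuple[int, int, int]]]:
    """Split a --validation_prompts string into (prompt, (frames, height, width)) pairs.

    Assumed format: entries separated by ":::", each "<prompt>@@@<F>x<H>x<W>".
    """
    entries = []
    for entry in spec.split(":::"):
        # rpartition splits on the last "@@@", so the resolution is always the suffix.
        prompt, sep, resolution = entry.rpartition("@@@")
        if not sep:
            raise ValueError(f"missing '@@@<F>x<H>x<W>' suffix in entry: {entry[:40]!r}")
        frames, height, width = (int(x) for x in resolution.strip().split("x"))
        entries.append((prompt.strip(), (frames, height, width)))
    return entries
```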

Inference:

Assuming your LoRA is saved and pushed to the HF Hub, and named my-awesome-name/my-awesome-lora, we can now use the finetuned model for inference:

import torch
from diffusers import LTXPipeline
from diffusers.utils import export_to_video

pipe = LTXPipeline.from_pretrained(
    "Lightricks/LTX-Video", torch_dtype=torch.bfloat16
).to("cuda")
pipe.load_lora_weights("my-awesome-name/my-awesome-lora", adapter_name="ltxv-lora")
pipe.set_adapters(["ltxv-lora"], [0.75])

video = pipe("<my-awesome-prompt>").frames[0]
export_to_video(video, "output.mp4", fps=8)

Memory Usage

LoRA with rank 128, batch size 1, gradient checkpointing, optimizer adamw, 49x512x768 resolution, without precomputation:

Training configuration: {
    "trainable parameters": 117440512,
    "total samples": 69,
    "train epochs": 1,
    "train steps": 10,
    "batches per device": 1,
    "total batches observed per epoch": 69,
    "train batch size": 1,
    "gradient accumulation steps": 1
}

| stage | memory_allocated (GB) | max_memory_reserved (GB) |
|---|---|---|
| before training start | 13.486 | 13.879 |
| before validation start | 14.146 | 17.623 |
| after validation end | 14.146 | 17.623 |
| after epoch 1 | 14.146 | 17.623 |
| after training end | 4.461 | 17.623 |

Note: requires about 18 GB of VRAM without precomputation.
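The "trainable parameters" figure above can be reproduced by hand: a LoRA adapter of rank r on a linear layer with in_features inputs and out_features outputs adds r x (in_features + out_features) parameters (matrix A is r x in, matrix B is out x r). The sketch below assumes an LTX-Video transformer with 28 blocks, hidden size 2048, and two attention layers per block (self- and cross-attention) whose four targeted projections all map 2048 to 2048. These architecture values are not stated in this README; they are assumptions that happen to reproduce the reported 117,440,512.

```python
def lora_params(in_features: int, out_features: int, rank: int) -> int:
    """Parameters added by a LoRA adapter (A: rank x in, B: out x rank), no bias."""
    return rank * (in_features + out_features)


# Assumed LTX-Video transformer shape (not stated in this README).
NUM_BLOCKS = 28
HIDDEN = 2048
ATTENTIONS_PER_BLOCK = 2  # self-attention and cross-attention
TARGET_MODULES = ["to_q", "to_k", "to_v", "to_out.0"]  # from --target_modules
RANK = 128  # from --rank

per_attention = len(TARGET_MODULES) * lora_params(HIDDEN, HIDDEN, RANK)
total = NUM_BLOCKS * ATTENTIONS_PER_BLOCK * per_attention
print(f"{total:,} trainable LoRA parameters")
```

Because the count is linear in the rank, halving --rank to 64 would halve the number of trainable parameters under the same assumptions.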

LoRA with rank 128, batch size 1, gradient checkpointing, optimizer adamw, 49x512x768 resolution, with precomputation:

Training configuration: {
    "trainable parameters": 117440512,
    "total samples": 1,
    "train epochs": 10,
    "train steps": 10,
    "batches per device": 1,
    "total batches observed per epoch": 1,
    "train batch size": 1,
    "gradient accumulation steps": 1
}

| stage | memory_allocated (GB) | max_memory_reserved (GB) |
|---|---|---|
| after precomputing conditions | 8.88 | 8.920 |
| after precomputing latents | 9.684 | 11.613 |
| before training start | 3.809 | 10.010 |
| after epoch 1 | 4.26 | 10.916 |
| before validation start | 4.26 | 10.916 |
| after validation end | 13.924 | 17.262 |
| after training end | 4.26 | 14.314 |

Note: requires about 17.5 GB of VRAM with precomputation. If validation is not performed, the memory usage is reduced to 11 GB.

Hunyuan Video

Training:

#!/bin/bash

# export TORCH_LOGS="+dynamo,recompiles,graph_breaks"
# export TORCHDYNAMO_VERBOSE=1
export WANDB_MODE="offline"
export NCCL_P2P_DISABLE=1
export TORCH_NCCL_ENABLE_MONITORING=0
export FINETRAINERS_LOG_LEVEL=DEBUG

GPU_IDS="0,1,2,3,4,5,6,7"

DATA_ROOT="/path/to/dataset"
CAPTION_COLUMN="prompts.txt"
VIDEO_COLUMN="videos.txt"
OUTPUT_DIR="/path/to/models/hunyuan-video/hunyuan-video-loras/hunyuan-video_cakify_500_3e-5_constant_with_warmup"

ID_TOKEN="afkx"

# Model arguments
model_cmd="--model_name hunyuan_video \
  --pretrained_model_name_or_path hunyuanvideo-community/HunyuanVideo"

# Dataset arguments
dataset_cmd="--data_root $DATA_ROOT \
  --video_column $VIDEO_COLUMN \
  --caption_column $CAPTION_COLUMN \
  --id_token $ID_TOKEN \
  --video_resolution_buckets 17x512x768 49x512x768 61x512x768 \
  --caption_dropout_p 0.05"

# Dataloader arguments
dataloader_cmd="--dataloader_num_workers 0"

# Diffusion arguments
diffusion_cmd=""

# Training arguments
training_cmd="--training_type lora \
  --seed 42 \
  --mixed_precision bf16 \
  --batch_size 1 \
  --train_steps 500 \
  --rank 128 \
  --lora_alpha 128 \
  --target_modules to_q to_k to_v to_out.0 \
  --gradient_accumulation_steps 1 \
  --gradient_checkpointing \
  --checkpointing_steps 500 \
  --checkpointing_limit 2 \
  --enable_slicing \
  --enable_tiling"

# Optimizer arguments
optimizer_cmd="--optimizer adamw \
  --lr 2e-5 \
  --lr_scheduler constant_with_warmup \
  --lr_warmup_steps 100 \
  --lr_num_cycles 1 \
  --beta1 0.9 \
  --beta2 0.95 \
  --weight_decay 1e-4 \
  --epsilon 1e-8 \
  --max_grad_norm 1.0"

# Validation arguments
validation_cmd="--validation_prompts \"$ID_TOKEN A baker carefully cuts a green bell pepper cake on a white plate against a bright yellow background, followed by a strawberry cake with a similar slice of cake being cut before the interior of the bell pepper cake is revealed with the surrounding cake-to-object sequence.@@@49x512x768:::$ID_TOKEN A cake shaped like a Nutella container is carefully sliced, revealing a light interior, amidst a Nutella-themed setup, showcasing deliberate cutting and preserved details for an appetizing dessert presentation on a white base with accompanying jello and cutlery, highlighting culinary skills and creative cake designs.@@@49x512x768:::$ID_TOKEN A cake shaped like a Nutella container is carefully sliced, revealing a light interior, amidst a Nutella-themed setup, showcasing deliberate cutting and preserved details for an appetizing dessert presentation on a white base with accompanying jello and cutlery, highlighting culinary skills and creative cake designs.@@@61x512x768:::$ID_TOKEN A vibrant orange cake disguised as a Nike packaging box sits on a dark surface, meticulous in its detail and design, complete with a white swoosh and 'NIKE' logo. A person's hands, holding a knife, hover over the cake, ready to make a precise cut, amidst a simple and clean background.@@@61x512x768:::$ID_TOKEN A vibrant orange cake disguised as a Nike packaging box sits on a dark surface, meticulous in its detail and design, complete with a white swoosh and 'NIKE' logo. A person's hands, holding a knife, hover over the cake, ready to make a precise cut, amidst a simple and clean background.@@@97x512x768:::$ID_TOKEN A vibrant orange cake disguised as a Nike packaging box sits on a dark surface, meticulous in its detail and design, complete with a white swoosh and 'NIKE' logo. A person's hands, holding a knife, hover over the cake, ready to make a precise cut, amidst a simple and clean background.@@@129x512x768:::$ID_TOKEN A person with gloved hands carefully cuts a cake shaped like a Skittles bottle, beginning with a precise incision at the lid, followed by careful sequential cuts around the neck, eventually detaching the lid from the body, revealing the chocolate interior of the cake while showcasing the layered design's detail.@@@61x512x768:::$ID_TOKEN A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage@@@61x512x768\" \
  --num_validation_videos 1 \
  --validation_steps 100"

# Miscellaneous arguments
miscellaneous_cmd="--tracker_name finetrainers-hunyuan-video \
  --output_dir $OUTPUT_DIR \
  --nccl_timeout 1800 \
  --report_to wandb"

cmd="accelerate launch --config_file accelerate_configs/uncompiled_8.yaml --gpu_ids $GPU_IDS train.py \
  $model_cmd \
  $dataset_cmd \
  $dataloader_cmd \
  $diffusion_cmd \
  $training_cmd \
  $optimizer_cmd \
  $validation_cmd \
  $miscellaneous_cmd"

echo "Running command: $cmd"
eval $cmd
echo -ne "-------------------- Finished executing script --------------------\n\n"
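The Hunyuan script trains on several resolution buckets (17x512x768, 49x512x768, 61x512x768). The sketch below maps an FxHxW bucket to the latent shape the transformer would see. It assumes a causal video VAE that keeps the first frame and compresses the remaining F - 1 frames 4x in time and the height and width 8x; these compression factors are assumptions about HunyuanVideo's VAE, not values stated in this README, and the helper names are hypothetical.

```python
from __future__ import annotations

# Assumed HunyuanVideo VAE compression factors (not stated in this README).
TEMPORAL_COMPRESSION = 4
SPATIAL_COMPRESSION = 8


def parse_bucket(bucket: str) -> tuple[int, int, int]:
    """Parse an FxHxW bucket such as '49x512x768' into (frames, height, width)."""
    frames, height, width = (int(x) for x in bucket.split("x"))
    return frames, height, width


def latent_shape(bucket: str, t: int = TEMPORAL_COMPRESSION,
                 s: int = SPATIAL_COMPRESSION) -> tuple[int, int, int]:
    """Map an FxHxW bucket to latent (frames, height, width).

    With a causal VAE the first frame is kept and the other F - 1 frames are
    compressed by `t`, so F is expected to be of the form t*k + 1.
    """
    frames, height, width = parse_bucket(bucket)
    if (frames - 1) % t != 0:
        raise ValueError(f"{bucket}: frame count should be {t}*k + 1")
    if height % s or width % s:
        raise ValueError(f"{bucket}: height and width should be divisible by {s}")
    return (frames - 1) // t + 1, height // s, width // s


for b in "17x512x768 49x512x768 61x512x768".split():
    print(b, "->", latent_shape(b))
```

Under these assumptions the three training buckets share the same 64x96 spatial latent grid and differ only in latent length (5, 13 and 16 frames).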

Inference:

Assuming your LoRA is saved and pushed to the HF Hub, and named my-awesome-name/my-awesome-lora, we can now use the finetuned model for inference:

import torch
from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
from diffusers.utils import export_to_video

model_id = "hunyuanvideo-community/HunyuanVideo"
transformer = HunyuanVideoTransformer3DModel.from_pretrained(
    model_id, subfolder="transformer", torch_dtype=torch.bfloat16
)
pipe = HunyuanVideoPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.float16)
pipe.load_lora_weights("my-awesome-name/my-awesome-lora", adapter_name="hunyuanvideo-lora")
pipe.set_adapters(["hunyuanvideo-lora"], [0.6])
pipe.vae.enable_tiling()
pipe.to("cuda")

output = pipe(
    prompt="A cat walks on the grass, realistic",
    height=320,
    width=512,
    num_frames=61,
    num_inference_steps=30,
).frames[0]
export_to_video(output, "output.mp4", fps=15)

Memory Usage

LoRA with rank 128, batch size 1, gradient checkpointing, optimizer adamw, 49x512x768 resolution, without precomputation:

Training configuration: {
    "trainable parameters": 163577856,
    "total samples": 69,
    "train epochs": 1,
    "train steps": 10,
    "batches per device": 1,
    "total batches observed per epoch": 69,
    "train batch size": 1,
    "gradient accumulation steps": 1
}

| stage | memory_allocated (GB) | max_memory_reserved (GB) |
|---|---|---|
| before training start | 38.889 | 39.020 |
| before validation start | 39.747 | 56.266 |
| after validation end | 39.748 | 58.385 |
| after epoch 1 | 39.748 | 40.910 |
| after training end | 25.288 | 40.910 |

Note: requires about 59 GB of VRAM when validation is performed.
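The same r x (in_features + out_features) rule accounts for the reported 163,577,856 trainable parameters, but HunyuanVideo mixes block types. The breakdown below assumes 20 dual-stream blocks (to_q, to_k, to_v and to_out.0 targeted), 40 single-stream blocks whose attention has no to_out projection (so only to_q, to_k and to_v match), and 2 text token-refiner blocks (all four targeted), all with hidden size 3072. This layout is an assumption, not something stated in this README; it is shown because it reproduces the reported total exactly.

```python
def lora_params(in_features: int, out_features: int, rank: int) -> int:
    """Parameters added by a LoRA adapter (A: rank x in, B: out x rank), no bias."""
    return rank * (in_features + out_features)


RANK = 128  # from --rank
HIDDEN = 3072  # assumed hidden size; every targeted projection is 3072 -> 3072
per_projection = lora_params(HIDDEN, HIDDEN, RANK)

# Assumed layout: name -> (number of blocks, targeted projections per block)
modules = {
    "dual-stream blocks (to_q, to_k, to_v, to_out.0)": (20, 4),
    "single-stream blocks (to_q, to_k, to_v)": (40, 3),
    "token-refiner blocks (to_q, to_k, to_v, to_out.0)": (2, 4),
}

total = 0
for name, (blocks, projections) in modules.items():
    count = blocks * projections * per_projection
    total += count
    print(f"{name}: {count:,}")
print(f"total: {total:,}")
```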

LoRA with rank 128, batch size 1, gradient checkpointing, optimizer adamw, 49x512x768 resolution, with precomputation:

Training configuration: {
    "trainable parameters": 163577856,
    "total samples": 1,
    "train epochs": 10,
    "train steps": 10,
    "batches per device": 1,
    "total batches observed per epoch": 1,
    "train batch size": 1,
    "gradient accumulation steps": 1
}

| stage | memory_allocated (GB) | max_memory_reserved (GB) |
|---|---|---|
| after precomputing conditions | 14.232 | 14.461 |
| after precomputing latents | 14.717 | 17.244 |
| before training start | 24.195 | 26.039 |
| after epoch 1 | 24.83 | 42.387 |
| before validation start | 24.842 | 42.387 |
| after validation end | 39.558 | 46.947 |
| after training end | 24.842 | 41.039 |

Note: requires about 47 GB of VRAM with validation. If validation is not performed, the memory usage is reduced to about 42 GB.

If you would like to use a custom dataset, refer to the dataset preparation guide here.

Note

To lower memory requirements:

  • Use a DeepSpeed config to launch training (refer to accelerate_configs/deepspeed.yaml as an example).
  • Pass --precompute_conditions when launching training.
  • Pass --gradient_checkpointing when launching training.
  • Pass --use_8bit_bnb when launching training. Note that this is only applicable to Adam and AdamW optimizers.
  • Skip validation/testing. This saves a significant amount of memory, which helps when training on GPUs with less VRAM.
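The flag-based items in the list above can be applied to an existing argument list. The sketch below is a hypothetical helper (not part of finetrainers) that removes the validation arguments used in this README's scripts, assuming each takes exactly one value and that omitting them disables validation, then appends --precompute_conditions, --gradient_checkpointing and, for Adam/AdamW only, --use_8bit_bnb. The DeepSpeed option is set through the accelerate config file rather than train.py arguments, so it is not handled here.

```python
from __future__ import annotations

# Validation flags from this README's scripts; each is assumed to take exactly one value.
VALIDATION_FLAGS = {"--validation_prompts", "--num_validation_videos", "--validation_steps"}


def low_memory_argv(argv: list[str], optimizer: str = "adamw") -> list[str]:
    """Return a copy of `argv` without validation arguments and with memory-saving flags."""
    out: list[str] = []
    skip_next = False
    for arg in argv:
        if skip_next:
            skip_next = False  # drop the value that belongs to a validation flag
            continue
        if arg in VALIDATION_FLAGS:
            skip_next = True
            continue
        out.append(arg)
    for flag in ("--precompute_conditions", "--gradient_checkpointing"):
        if flag not in out:
            out.append(flag)
    # --use_8bit_bnb is only applicable to the Adam and AdamW optimizers.
    if optimizer.lower() in {"adam", "adamw"} and "--use_8bit_bnb" not in out:
        out.append("--use_8bit_bnb")
    return out
```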

Memory requirements

[Figure: CogVideoX LoRA finetuning memory requirements for THUDM/CogVideoX-2b and THUDM/CogVideoX-5b]
[Figure: CogVideoX full finetuning memory requirements for THUDM/CogVideoX-2b and THUDM/CogVideoX-5b]

Supported and verified memory optimizations for training include:

  • CPUOffloadOptimizer from torchao. You can read about its capabilities and limitations here. In short, it lets you keep trainable parameters and gradients on the CPU. As a result, the optimizer step runs on the CPU, which requires a fast CPU optimizer such as torch.optim.AdamW(fused=True), or applying torch.compile to the optimizer step. Additionally, it is recommended not to torch.compile your model for training. Gradient clipping and gradient accumulation are not supported yet either.
  • Low-bit optimizers from bitsandbytes. (TODO: test the torchao low-bit optimizers and make them work.)
  • DeepSpeed Zero2: Since we rely on accelerate, follow this guide to configure your accelerate installation to enable training with DeepSpeed Zero2 optimizations.

Important

The memory requirements are reported after running training/prepare_dataset.py, which converts the videos and captions to latents and embeddings. During training, we load the latents and embeddings directly and do not need the VAE or the T5 text encoder. However, if you perform validation/testing, these models must be loaded, which increases the required memory. Skipping validation/testing saves a significant amount of memory, which helps when training on GPUs with less VRAM.

If you choose to run validation/testing, you can save some memory on lower VRAM GPUs by specifying --enable_model_cpu_offload.

LoRA finetuning

Note

The memory requirements for image-to-video LoRA finetuning are similar to those of text-to-video on THUDM/CogVideoX-5b, so they haven't been reported separately.

To prepare test images for I2V finetuning, you can generate them on the fly by modifying the script, extract a frame from your training data with ffmpeg -i input.mp4 -frames:v 1 frame.png, or provide the URL of a valid, accessible image.
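The ffmpeg command above writes the first frame of the video. If a later frame is more representative, ffmpeg's -ss option placed before -i seeks to a timestamp (in seconds) before decoding. The sketch below wraps this in Python; the helper names are hypothetical, it adds -y so an existing output file is overwritten without prompting, and it only invokes ffmpeg if it is found on PATH.

```python
from __future__ import annotations

import shutil
import subprocess


def frame_extract_cmd(video: str, output: str, timestamp: float | None = None) -> list[str]:
    """Build an ffmpeg command that writes one frame of `video` to `output`.

    With timestamp=None this is the command above plus -y; otherwise -ss
    (placed before -i) seeks to `timestamp` seconds before decoding.
    """
    cmd = ["ffmpeg", "-y"]
    if timestamp is not None:
        cmd += ["-ss", f"{timestamp:.3f}"]
    cmd += ["-i", video, "-frames:v", "1", output]
    return cmd


def extract_frame(video: str, output: str, timestamp: float | None = None) -> None:
    """Run the command built above, failing loudly if ffmpeg is unavailable or errors."""
    if shutil.which("ffmpeg") is None:
        raise RuntimeError("ffmpeg not found on PATH")
    subprocess.run(frame_extract_cmd(video, output, timestamp), check=True)
```

For example, extract_frame("input.mp4", "frame.png", timestamp=2.0) would save the frame at roughly two seconds into the clip.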

AdamW

Note: Trying to run CogVideoX-5b without gradient checkpointing OOMs even on an A100 (80 GB), so the memory measurements have not been specified.

With train_batch_size = 1:

| model | lora rank | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
|---|---|---|---|---|---|---|
| THUDM/CogVideoX-2b | 16 | False | 12.945 | 43.764 | 46.918 | 24.234 |
| THUDM/CogVideoX-2b | 16 | True | 12.945 | 12.945 | 21.121 | 24.234 |
| THUDM/CogVideoX-2b | 64 | False | 13.035 | 44.314 | 47.469 | 24.469 |
| THUDM/CogVideoX-2b | 64 | True | 13.036 | 13.035 | 21.564 | 24.500 |
| THUDM/CogVideoX-2b | 256 | False | 13.095 | 45.826 | 48.990 | 25.543 |
| THUDM/CogVideoX-2b | 256 | True | 13.094 | 13.095 | 22.344 | 25.537 |
| THUDM/CogVideoX-5b | 16 | True | 19.742 | 19.742 | 28.746 | 38.123 |
| THUDM/CogVideoX-5b | 64 | True | 20.006 | 20.818 | 30.338 | 38.738 |
| THUDM/CogVideoX-5b | 256 | True | 20.771 | 22.119 | 31.939 | 41.537 |

With train_batch_size = 4:

| model | lora rank | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
|---|---|---|---|---|---|---|
| THUDM/CogVideoX-2b | 16 | True | 12.945 | 21.803 | 21.814 | 24.322 |
| THUDM/CogVideoX-2b | 64 | True | 13.035 | 22.254 | 22.254 | 24.572 |
| THUDM/CogVideoX-2b | 256 | True | 13.094 | 22.020 | 22.033 | 25.574 |
| THUDM/CogVideoX-5b | 16 | True | 19.742 | 46.492 | 46.492 | 38.197 |
| THUDM/CogVideoX-5b | 64 | True | 20.006 | 47.805 | 47.805 | 39.365 |
| THUDM/CogVideoX-5b | 256 | True | 20.771 | 47.268 | 47.332 | 41.008 |

AdamW (8-bit bitsandbytes)

Note: Trying to run CogVideoX-5b without gradient checkpointing OOMs even on an A100 (80 GB), so the memory measurements have not been specified.

With train_batch_size = 1:

| model | lora rank | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
|---|---|---|---|---|---|---|
| THUDM/CogVideoX-2b | 16 | False | 12.945 | 43.732 | 46.887 | 24.195 |
| THUDM/CogVideoX-2b | 16 | True | 12.945 | 12.945 | 21.430 | 24.195 |
| THUDM/CogVideoX-2b | 64 | False | 13.035 | 44.004 | 47.158 | 24.369 |
| THUDM/CogVideoX-2b | 64 | True | 13.035 | 13.035 | 21.297 | 24.357 |
| THUDM/CogVideoX-2b | 256 | False | 13.035 | 45.291 | 48.455 | 24.836 |
| THUDM/CogVideoX-2b | 256 | True | 13.035 | 13.035 | 21.625 | 24.869 |
| THUDM/CogVideoX-5b | 16 | True | 19.742 | 19.742 | 28.602 | 38.049 |
| THUDM/CogVideoX-5b | 64 | True | 20.006 | 20.818 | 29.359 | 38.520 |
| THUDM/CogVideoX-5b | 256 | True | 20.771 | 21.352 | 30.727 | 39.596 |

With train_batch_size = 4:

| model | lora rank | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
|---|---|---|---|---|---|---|
| THUDM/CogVideoX-2b | 16 | True | 12.945 | 21.734 | 21.775 | 24.281 |
| THUDM/CogVideoX-2b | 64 | True | 13.036 | 21.941 | 21.941 | 24.445 |
| THUDM/CogVideoX-2b | 256 | True | 13.094 | 22.020 | 22.266 | 24.943 |
| THUDM/CogVideoX-5b | 16 | True | 19.742 | 46.320 | 46.326 | 38.104 |
| THUDM/CogVideoX-5b | 64 | True | 20.006 | 46.820 | 46.820 | 38.588 |
| THUDM/CogVideoX-5b | 256 | True | 20.771 | 47.920 | 47.980 | 40.002 |

AdamW + CPUOffloadOptimizer (with gradient offloading)

Note: Trying to run CogVideoX-5b without gradient checkpointing OOMs even on an A100 (80 GB), so the memory measurements have not been specified.

With train_batch_size = 1:

| model | lora rank | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
|---|---|---|---|---|---|---|
| THUDM/CogVideoX-2b | 16 | False | 12.945 | 43.705 | 46.859 | 24.180 |
| THUDM/CogVideoX-2b | 16 | True | 12.945 | 12.945 | 21.395 | 24.180 |
| THUDM/CogVideoX-2b | 64 | False | 13.035 | 43.916 | 47.070 | 24.234 |
| THUDM/CogVideoX-2b | 64 | True | 13.035 | 13.035 | 20.887 | 24.266 |
| THUDM/CogVideoX-2b | 256 | False | 13.095 | 44.947 | 48.111 | 24.607 |
| THUDM/CogVideoX-2b | 256 | True | 13.095 | 13.095 | 21.391 | 24.635 |
| THUDM/CogVideoX-5b | 16 | True | 19.742 | 19.742 | 28.533 | 38.002 |
| THUDM/CogVideoX-5b | 64 | True | 20.006 | 20.006 | 29.107 | 38.785 |
| THUDM/CogVideoX-5b | 256 | True | 20.771 | 20.771 | 30.078 | 39.559 |

With train_batch_size = 4:

| model | lora rank | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
|---|---|---|---|---|---|---|
| THUDM/CogVideoX-2b | 16 | True | 12.945 | 21.709 | 21.762 | 24.254 |
| THUDM/CogVideoX-2b | 64 | True | 13.035 | 21.844 | 21.855 | 24.338 |
| THUDM/CogVideoX-2b | 256 | True | 13.094 | 22.020 | 22.031 | 24.709 |
| THUDM/CogVideoX-5b | 16 | True | 19.742 | 46.262 | 46.297 | 38.400 |
| THUDM/CogVideoX-5b | 64 | True | 20.006 | 46.561 | 46.574 | 38.840 |
| THUDM/CogVideoX-5b | 256 | True | 20.771 | 47.268 | 47.332 | 39.623 |

DeepSpeed (AdamW + CPU/Parameter offloading)

Note: Results are reported with gradient_checkpointing enabled, running on 2x A100 GPUs.

With train_batch_size = 1:

| model | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
|---|---|---|---|---|
| THUDM/CogVideoX-2b | 13.141 | 13.141 | 21.070 | 24.602 |
| THUDM/CogVideoX-5b | 20.170 | 20.170 | 28.662 | 38.957 |

With train_batch_size = 4:

| model | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
|---|---|---|---|---|
| THUDM/CogVideoX-2b | 13.141 | 19.854 | 20.836 | 24.709 |
| THUDM/CogVideoX-5b | 20.170 | 40.635 | 40.699 | 39.027 |

Full finetuning

Note

The memory requirements for image-to-video full finetuning are similar to those of text-to-video on THUDM/CogVideoX-5b, so they haven't been reported separately.

To prepare test images for I2V finetuning, you can generate them on the fly by modifying the script, extract a frame from your training data with ffmpeg -i input.mp4 -frames:v 1 frame.png, or provide the URL of a valid, accessible image.

Note

Trying to run full finetuning without gradient checkpointing OOMs even on an A100 (80 GB), so the memory measurements have not been specified.

AdamW

With train_batch_size = 1:

| model | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
|---|---|---|---|---|---|
| THUDM/CogVideoX-2b | True | 16.396 | 33.934 | 43.848 | 37.520 |
| THUDM/CogVideoX-5b | True | 30.061 | OOM | OOM | OOM |

With train_batch_size = 4:

| model | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
|---|---|---|---|---|---|
| THUDM/CogVideoX-2b | True | 16.396 | 38.281 | 48.341 | 37.544 |
| THUDM/CogVideoX-5b | True | 30.061 | OOM | OOM | OOM |

AdamW (8-bit bitsandbytes)

With train_batch_size = 1:

| model | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
|---|---|---|---|---|---|
| THUDM/CogVideoX-2b | True | 16.396 | 16.447 | 27.555 | 27.156 |
| THUDM/CogVideoX-5b | True | 30.061 | 52.826 | 58.570 | 49.541 |

With train_batch_size = 4:

| model | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
|---|---|---|---|---|---|
| THUDM/CogVideoX-2b | True | 16.396 | 27.930 | 27.990 | 27.326 |
| THUDM/CogVideoX-5b | True | 16.396 | 66.648 | 66.705 | 48.828 |

AdamW + CPUOffloadOptimizer (with gradient offloading)

With train_batch_size = 1:

| model | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
|---|---|---|---|---|---|
| THUDM/CogVideoX-2b | True | 16.396 | 16.396 | 26.100 | 23.832 |
| THUDM/CogVideoX-5b | True | 30.061 | 39.359 | 48.307 | 37.947 |

With train_batch_size = 4:

| model | gradient_checkpointing | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
|---|---|---|---|---|---|
| THUDM/CogVideoX-2b | True | 16.396 | 27.916 | 27.975 | 23.936 |
| THUDM/CogVideoX-5b | True | 30.061 | 66.607 | 66.668 | 38.061 |

DeepSpeed (AdamW + CPU/Parameter offloading)

Note: Results are reported with gradient_checkpointing enabled, running on 2x A100 GPUs.

With train_batch_size = 1:

| model | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
|---|---|---|---|---|
| THUDM/CogVideoX-2b | 13.111 | 13.111 | 20.328 | 23.867 |
| THUDM/CogVideoX-5b | 19.762 | 19.998 | 27.697 | 38.018 |

With train_batch_size = 4:

| model | memory_before_training | memory_before_validation | memory_after_validation | memory_after_testing |
|---|---|---|---|---|
| THUDM/CogVideoX-2b | 13.111 | 21.188 | 21.254 | 23.869 |
| THUDM/CogVideoX-5b | 19.762 | 43.465 | 43.531 | 38.082 |

Note

  • memory_after_validation is indicative of the peak memory required for training. This is because, in addition to the activations, parameters and gradients stored for training, the VAE and text encoder must also be loaded into memory, and some memory is spent performing inference. To reduce the total memory required for training, you can choose not to perform validation/testing as part of the training script.

  • memory_before_validation is the true indicator of the peak memory required for training if you choose not to perform validation/testing.
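To make the two notes concrete, the sketch below picks the relevant column depending on whether validation is enabled, then filters configurations against a VRAM budget. The three rows are copied from the AdamW, train_batch_size = 1 LoRA table (gradient checkpointing enabled); the helper names are hypothetical, and the values are assumed to be in GB.

```python
from __future__ import annotations

# Rows copied from the "AdamW, train_batch_size = 1" LoRA table (presumably GB).
ROWS = [
    {"model": "THUDM/CogVideoX-2b", "rank": 16,
     "memory_before_validation": 12.945, "memory_after_validation": 21.121},
    {"model": "THUDM/CogVideoX-2b", "rank": 256,
     "memory_before_validation": 13.095, "memory_after_validation": 22.344},
    {"model": "THUDM/CogVideoX-5b", "rank": 64,
     "memory_before_validation": 20.818, "memory_after_validation": 30.338},
]


def peak_memory(row: dict, validate: bool) -> float:
    """Peak training memory as described in the notes above."""
    key = "memory_after_validation" if validate else "memory_before_validation"
    return row[key]


def fits(rows: list[dict], vram_gb: float, validate: bool) -> list[str]:
    """Configurations whose peak memory stays within `vram_gb`."""
    return [f"{r['model']} (rank {r['rank']})" for r in rows
            if peak_memory(r, validate) <= vram_gb]
```

With a 24 GB budget, for instance, only the CogVideoX-2b rows fit when validation is enabled, while the CogVideoX-5b rank 64 row also fits once validation is skipped.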

Slaying OOMs with PyTorch