add gradient checkpointing option to docs #1185

Merged 1 commit on Dec 3, 2024
9 changes: 9 additions & 0 deletions OPTIONS.md
@@ -40,6 +40,15 @@ The script `configure.py` in the project root can be used via `python configure.
- **What**: Path to the pretrained T5 model or its identifier from https://huggingface.co/models.
- **Why**: When training PixArt, you might want to use a specific source for your T5 weights so that you can avoid downloading them multiple times when switching the base model you train from.

### `--gradient_checkpointing`

- **What**: During training, intermediate activations are discarded after the forward pass and recomputed block-by-block during the backward pass, reducing peak VRAM requirements at the cost of slower training.
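
A minimal sketch of the general technique in plain PyTorch, assuming a toy `TinyBlock` module (a stand-in for illustration, not part of SimpleTuner): the block's activations are recomputed during the backward pass instead of being stored.

```python
import torch
from torch.utils.checkpoint import checkpoint


class TinyBlock(torch.nn.Module):
    """Toy block used only to illustrate checkpointing."""

    def __init__(self, dim: int = 64):
        super().__init__()
        self.proj = torch.nn.Linear(dim, dim)

    def forward(self, x):
        return torch.nn.functional.gelu(self.proj(x))


block = TinyBlock()
x = torch.randn(4, 64, requires_grad=True)

# With checkpointing, the block's intermediate activations are not kept after
# the forward pass; they are recomputed during backward, trading extra compute
# for lower peak memory.
y = checkpoint(block, x, use_reentrant=False)
y.sum().backward()
```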

### `--gradient_checkpointing_interval`

- **What**: Checkpoint only every _n_th block, where _n_ is a value greater than zero. A value of 1 is effectively the same as simply leaving `--gradient_checkpointing` enabled, and a value of 2 will checkpoint every other block.
- **Note**: SDXL and Flux are currently the only models supporting this option. SDXL uses a hackish implementation.
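
As an illustration of what an interval amounts to, a hedged sketch in plain PyTorch (not SimpleTuner's actual implementation), assuming the model is a simple list of blocks:

```python
import torch
from torch.utils.checkpoint import checkpoint


def forward_blocks(blocks, x, checkpoint_interval: int = 2):
    """Checkpoint only every `checkpoint_interval`-th block; the rest run normally."""
    for i, block in enumerate(blocks):
        if checkpoint_interval and i % checkpoint_interval == 0:
            # Activations for this block are recomputed during backward.
            x = checkpoint(block, x, use_reentrant=False)
        else:
            # This block keeps its activations in memory as usual.
            x = block(x)
    return x


# With an interval of 1 every block is checkpointed (same as plain
# --gradient_checkpointing); an interval of 2 checkpoints every other block.
blocks = torch.nn.ModuleList(torch.nn.Linear(32, 32) for _ in range(6))
out = forward_blocks(blocks, torch.randn(2, 32), checkpoint_interval=2)
```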

### `--refiner_training`

- **What**: Enables training a custom mixture-of-experts model series. See [Mixture-of-Experts](/documentation/MIXTURE_OF_EXPERTS.md) for more information on these options.
5 changes: 5 additions & 0 deletions documentation/quickstart/FLUX.md
@@ -144,6 +144,8 @@ There, you will possibly need to modify the following variables:
- This option causes gradients to be accumulated over several steps before each optimizer update (see the sketch after this list). It increases the training runtime linearly, such that a value of 2 will make your training run half as quickly and take twice as long.
- `optimizer` - Beginners are recommended to stick with adamw_bf16, though optimi-lion and optimi-stableadamw are also good choices.
- `mixed_precision` - Beginners should keep this in `bf16`
- `gradient_checkpointing` - set this to true in practically every situation, on every device
- `gradient_checkpointing_interval` - on larger GPUs this can be set to 2 or higher so that only every _n_th block is checkpointed. A value of 2 would checkpoint half of the blocks, and 3 would checkpoint one-third.
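
For reference, a self-contained sketch of gradient accumulation in plain PyTorch; the tiny model, data, and `accum_steps` name are placeholders for illustration, not SimpleTuner's training loop.

```python
import torch

model = torch.nn.Linear(8, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
accum_steps = 2  # analogous to setting the accumulation option to 2

optimizer.zero_grad()
for step in range(8):
    x, y = torch.randn(4, 8), torch.randn(4, 1)                      # stand-in micro-batch
    loss = torch.nn.functional.mse_loss(model(x), y) / accum_steps   # scale so gradients average
    loss.backward()                                                   # gradients add up in .grad
    if (step + 1) % accum_steps == 0:
        optimizer.step()        # one optimizer update per `accum_steps` micro-batches
        optimizer.zero_grad()
```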

Multi-GPU users can reference [this document](/OPTIONS.md#environment-configuration-variables) for information on configuring the number of GPUs to use.

@@ -415,6 +417,9 @@ Currently, the lowest VRAM utilisation (9090M) can be attained with:
- PyTorch: 2.6 Nightly (Sept 29th build)
- Using `--quantize_via=cpu` to avoid an out-of-memory error during startup on <=16G cards.
- With `--attention_mechanism=sageattention` to further reduce VRAM by 0.1GB and improve training validation image generation speed.
- Be sure to enable `--gradient_checkpointing`, or nothing else you do will stop it from OOMing.

**NOTE**: Pre-caching of VAE embeds and text encoder outputs may use more memory and still OOM. If so, text encoder quantisation and VAE tiling can be enabled.

Speed was approximately 1.4 iterations per second on a 4090.
