Skip to content

Commit

Permalink
add gradient checkpointing option to docs
Browse files Browse the repository at this point in the history
  • Loading branch information
bghira committed Dec 3, 2024
1 parent 96d477e commit 8f8e1cf
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 0 deletions.
9 changes: 9 additions & 0 deletions OPTIONS.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,15 @@ The script `configure.py` in the project root can be used via `python configure.
- **What**: Path to the pretrained T5 model or its identifier from https://huggingface.co/models.
- **Why**: When training PixArt, you might want to use a specific source for your T5 weights so that you can avoid downloading them multiple times when switching the base model you train from.

### `--gradient_checkpointing`

- **What**: During training, gradients will be calculated layerwise and accumulated to save on peak VRAM requirements at the cost of slower training.

### `--gradient_checkpointing_interval`

- **What**: Checkpoint only every _n_ blocks, where _n_ is a value greater than zero. A value of 1 is effectively the same as just leaving `--gradient_checkpointing` enabled, and a value of 2 will checkpoint every other block.
- **Note**: SDXL and Flux are currently the only models supporting this option. SDXL uses a hackish implementation.

### `--refiner_training`

- **What**: Enables training a custom mixture-of-experts model series. See [Mixture-of-Experts](/documentation/MIXTURE_OF_EXPERTS.md) for more information on these options.
Expand Down
5 changes: 5 additions & 0 deletions documentation/quickstart/FLUX.md
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,8 @@ There, you will possibly need to modify the following variables:
- This option causes update steps to be accumulated over several steps. This will increase the training runtime linearly, such that a value of 2 will make your training run half as quickly, and take twice as long.
- `optimizer` - Beginners are recommended to stick with adamw_bf16, though optimi-lion and optimi-stableadamw are also good choices.
- `mixed_precision` - Beginners should keep this in `bf16`
- `gradient_checkpointing` - set this to true in practically every situation on every device
- `gradient_checkpointing_interval` - this could be set to a value of 2 or higher on larger GPUs to only checkpoint every _n_ blocks. A value of 2 would checkpoint half of the blocks, and 3 would be one-third.

Multi-GPU users can reference [this document](/OPTIONS.md#environment-configuration-variables) for information on configuring the number of GPUs to use.

Expand Down Expand Up @@ -415,6 +417,9 @@ Currently, the lowest VRAM utilisation (9090M) can be attained with:
- PyTorch: 2.6 Nightly (Sept 29th build)
- Using `--quantize_via=cpu` to avoid outOfMemory error during startup on <=16G cards.
- With `--attention_mechanism=sageattention` to further reduce VRAM by 0.1GB and improve training validation image generation speed.
- Be sure to enable `--gradient_checkpointing` or nothing you do will stop it from OOMing

**NOTE**: Pre-caching of VAE embeds and text encoder outputs may use more memory and still OOM. If so, text encoder quantisation and VAE tiling can be enabled.

Speed was approximately 1.4 iterations per second on a 4090.

Expand Down

0 comments on commit 8f8e1cf

Please sign in to comment.