Add Flux and Flux Controlnet Support to Diffusion folder (#11794)
* VAE added and matched against the Flux checkpoint

Signed-off-by: mingyuanm <[email protected]>

* Flux model added.

Signed-off-by: mingyuanm <[email protected]>

* Copying FlowMatchEulerScheduler over

Signed-off-by: mingyuanm <[email protected]>
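
The scheduler copied above is a flow-matching Euler integrator: the model predicts a velocity field, and each step moves the sample linearly from the current sigma to the next one. A minimal sketch of the core update (illustrative only; the real scheduler also manages the sigma schedule and timestep bookkeeping):

```python
import torch

# One flow-matching Euler step: x moves along the predicted velocity v
# for the distance between consecutive noise levels.
def flow_match_euler_step(x: torch.Tensor, v: torch.Tensor,
                          sigma: float, sigma_next: float) -> torch.Tensor:
    return x + (sigma_next - sigma) * v
```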

* WIP: Start to test the pipeline forward pass

Signed-off-by: mingyuanm <[email protected]>

* VAE added and matched against the Flux checkpoint

Signed-off-by: mingyuanm <[email protected]>

* Inference pipeline runs with offloading function

Signed-off-by: mingyuanm <[email protected]>
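
The offloading function itself is not shown in this summary; a generic sketch of the pattern (keep large sub-models on CPU and move each to GPU only for its forward pass), assuming single-GPU inference:

```python
import torch

# Illustrative offloading helper, not the PR's actual implementation:
# run one sub-model on the GPU, then return its weights to host memory.
def run_offloaded(module: torch.nn.Module, *args, device: str = "cuda", **kwargs):
    module.to(device)
    with torch.no_grad():
        out = module(*args, **kwargs)
    module.to("cpu")
    torch.cuda.empty_cache()  # free the GPU memory the module just used
    return out
```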

* Start to test image generation

Signed-off-by: mingyuanm <[email protected]>

* VAE decoding has been verified; the denoising loop still needs checking.

Signed-off-by: mingyuanm <[email protected]>

* The inference pipeline is verified.

Signed-off-by: mingyuanm <[email protected]>

* Add arg parsers and refactoring

Signed-off-by: mingyuanm <[email protected]>

* Tested with multiple batch sizes and prompts.

Signed-off-by: mingyuanm <[email protected]>

* Add headers

Signed-off-by: mingyuanm <[email protected]>

* Apply isort and black reformatting

Signed-off-by: Victor49152 <[email protected]>

* Renaming

Signed-off-by: mingyuanm <[email protected]>

* Move scheduler to sampler folder

Signed-off-by: mingyuanm <[email protected]>

* Merging folders.

Signed-off-by: mingyuanm <[email protected]>

* Apply isort and black reformatting

Signed-off-by: Victor49152 <[email protected]>

* Tested after path changes.

Signed-off-by: mingyuanm <[email protected]>

* Apply isort and black reformatting

Signed-off-by: Victor49152 <[email protected]>

* Move MMDIT block to NeMo

Signed-off-by: mingyuanm <[email protected]>

* Apply isort and black reformatting

Signed-off-by: Victor49152 <[email protected]>

* Add joint attention and single attention to NeMo

Signed-off-by: mingyuanm <[email protected]>
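
For context, MMDiT-style joint attention lets the text and image streams attend over one concatenated token sequence and then splits the result back. A simplified sketch (single head, no rotary embeddings; shapes illustrative):

```python
import torch
import torch.nn.functional as F

def joint_attention(q_txt, k_txt, v_txt, q_img, k_img, v_img):
    # Concatenate the two streams along the sequence dimension: [B, T_txt + T_img, D]
    n_txt = q_txt.shape[1]
    q = torch.cat([q_txt, q_img], dim=1)
    k = torch.cat([k_txt, k_img], dim=1)
    v = torch.cat([v_txt, v_img], dim=1)
    out = F.scaled_dot_product_attention(q, k, v)
    # Split the joint output back into text and image streams.
    return out[:, :n_txt], out[:, n_txt:]
```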

* Apply isort and black reformatting

Signed-off-by: Victor49152 <[email protected]>

* Joint attention updated

Signed-off-by: mingyuanm <[email protected]>

* Apply isort and black reformatting

Signed-off-by: Victor49152 <[email protected]>

* Remove redundant imports

Signed-off-by: mingyuanm <[email protected]>

* Refactor to inherit megatron module

Signed-off-by: mingyuanm <[email protected]>
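
A minimal sketch of what "inherit megatron module" means here, assuming Megatron Core's MegatronModule base class (the class and layer names in this sketch are illustrative, not the PR's actual Flux classes):

```python
import torch
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.transformer_config import TransformerConfig

class FluxBlockSketch(MegatronModule):
    """Hypothetical block deriving from MegatronModule instead of plain
    nn.Module, so it carries the TransformerConfig that Megatron's
    parallelism utilities expect."""

    def __init__(self, config: TransformerConfig):
        super().__init__(config=config)
        self.proj = torch.nn.Linear(config.hidden_size, config.hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)
```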

* Adding mock data

Signed-off-by: mingyuanm <[email protected]>

* DDP training works

Signed-off-by: mingyuanm <[email protected]>

* Added Flux ControlNet training components (not yet tested)

Signed-off-by: mingyuanm <[email protected]>

* Flux training with DDP tested on 1 GPU

Signed-off-by: mingyuanm <[email protected]>

* Flux and ControlNet can now train in precached mode.

Signed-off-by: mingyuanm <[email protected]>
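
"Precached mode" here means the frozen encoders run once offline and training consumes saved tensors. A hedged sketch of what gets cached (the encode calls are hypothetical, not the PR's API):

```python
import torch

@torch.no_grad()
def precache_sample(image, prompt, vae, text_encoder):
    # Cache VAE latents and text embeddings so neither frozen encoder
    # needs to run (or even be loaded) during the training loop.
    latents = vae.encode(image)           # hypothetical encode API
    prompt_embeds = text_encoder(prompt)  # hypothetical encoder call
    return {"latents": latents.cpu(), "prompt_embeds": prompt_embeds.cpu()}
```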

* Custom FSDP path added to megatron parallel.

Signed-off-by: mingyuanm <[email protected]>

* Bug fix

Signed-off-by: mingyuanm <[email protected]>

* A hacky way to wrap frozen Flux into FSDP to reproduce the illegal-memory issue.

Signed-off-by: mingyuanm <[email protected]>

* Typo

Signed-off-by: mingyuanm <[email protected]>

* Bypass the no-grad issue when no single layers exist

Signed-off-by: mingyuanm <[email protected]>

* A hacky way to wrap frozen Flux into FSDP to reproduce the illegal-memory issue.

Signed-off-by: mingyuanm <[email protected]>

* Autocast the Flux model's dtype before FSDP wrapping

* fix RuntimeError: "Output 0 of SliceBackward0 is a view and is being modified inplace..."
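
For context on this error: the tensor being modified was a view produced by slicing, and autograd rejects in-place writes to such views in some situations; the standard fix is to materialize the slice (or switch to an out-of-place op). A small illustration of the pattern:

```python
import torch

x = torch.randn(4, 8, requires_grad=True)
h = x * 2.0

# h[:, :4] is a view (its grad_fn would be SliceBackward0); modifying it
# in place can trigger the RuntimeError quoted above.
# view = h[:, :4]; view += 1.0   # <- the problematic pattern
fixed = h[:, :4].clone()  # materialize the slice instead of mutating the view
fixed += 1.0              # safe: clone() broke the view relationship
fixed.sum().backward()
```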

* Add a wrapper to Flux ControlNet so its modules are all wrapped into FSDP automatically

Signed-off-by: mingyuanm <[email protected]>

* Get rid of concat op in flux single transformer

Signed-off-by: mingyuanm <[email protected]>
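
Dropping the concat works because a linear projection of a concatenation equals the sum of two smaller projections of the parts; a self-contained check of that identity (toy sizes, not the actual Flux dimensions):

```python
import torch
import torch.nn as nn

proj = nn.Linear(8, 4)
a, b = torch.randn(2, 5), torch.randn(2, 3)

# Reference: project the concatenated input.
ref = proj(torch.cat([a, b], dim=-1))

# Equivalent: split the weight column-wise and sum two matmuls,
# avoiding the concat (and its extra memory traffic) entirely.
W_a, W_b = proj.weight[:, :5], proj.weight[:, 5:]
out = a @ W_a.T + b @ W_b.T + proj.bias
assert torch.allclose(ref, out, atol=1e-6)
```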

* Get rid of concat op in flux single transformer

Signed-off-by: mingyuanm <[email protected]>

* Single-block attention.linear_proj.bias must not require grads after refactoring

Signed-off-by: mingyuanm <[email protected]>

* Use CPU initialization to avoid OOM

Signed-off-by: mingyuanm <[email protected]>

* Set up flux training script with tp

Signed-off-by: mingyuanm <[email protected]>

* SDXL FID image generation script updated.

Signed-off-by: mingyuanm <[email protected]>

* MCore self-attention API changed

Signed-off-by: mingyuanm <[email protected]>

* Add a dummy task encoder for raw image inputs

Signed-off-by: mingyuanm <[email protected]>

* Support loading crude datasets via the energon dataloader

Signed-off-by: mingyuanm <[email protected]>

* Default save_last to True

Signed-off-by: mingyuanm <[email protected]>

* Add controlnet inference pipeline

Signed-off-by: mingyuanm <[email protected]>

* Add controlnet inference script

Signed-off-by: mingyuanm <[email protected]>

* Image resize mode update

Signed-off-by: mingyuanm <[email protected]>

* Remove unnecessary bias to avoid sharding issue.

Signed-off-by: mingyuanm <[email protected]>

* Handle MCore custom fsdp checkpoint load (#11621)

* General handling of custom_fsdp checkpoint load

* Apply isort and black reformatting

Signed-off-by: shjwudp <[email protected]>

* Apply isort and black reformatting

Signed-off-by: artbataev <[email protected]>

---------

Signed-off-by: shjwudp <[email protected]>
Signed-off-by: artbataev <[email protected]>
Co-authored-by: shjwudp <[email protected]>
Co-authored-by: artbataev <[email protected]>

* Checkpoint naming

Signed-off-by: mingyuanm <[email protected]>

* Image logger WIP

Signed-off-by: mingyuanm <[email protected]>

* Image logger works fine

Signed-off-by: mingyuanm <[email protected]>

* Save hint and output to the image logger.

Signed-off-by: mingyuanm <[email protected]>

* Update flux controlnet training step

Signed-off-by: mingyuanm <[email protected]>

* Add model connector and attempt to load from dist ckpt (not yet working).

Signed-off-by: mingyuanm <[email protected]>

* Renaming and refactoring submodel configs for NeMo Run compatibility

Signed-off-by: mingyuanm <[email protected]>

* NeMo Run script works for the basic testing recipe

Signed-off-by: mingyuanm <[email protected]>

* Added tp2 training factory

Signed-off-by: mingyuanm <[email protected]>

* Added convergence recipe

Signed-off-by: mingyuanm <[email protected]>

* Added flux training scripts

Signed-off-by: mingyuanm <[email protected]>

* Inference script tested

Signed-off-by: mingyuanm <[email protected]>

* Controlnet inference script tested

Signed-off-by: mingyuanm <[email protected]>

* Move scripts to the correct folder and modify headers

Signed-off-by: mingyuanm <[email protected]>

* Apply isort and black reformatting

Signed-off-by: Victor49152 <[email protected]>

* Docstring updates

Signed-off-by: mingyuanm <[email protected]>

* Apply isort and black reformatting

Signed-off-by: Victor49152 <[email protected]>

* pylint correction

Signed-off-by: mingyuanm <[email protected]>

* Apply isort and black reformatting

Signed-off-by: Victor49152 <[email protected]>

* Add import guard since custom fsdp is not merged to mcore yet

Signed-off-by: mingyuanm <[email protected]>
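
The guard referred to above follows the usual optional-import pattern; a sketch (the module path is hypothetical, since custom FSDP had not yet landed in Megatron Core):

```python
try:
    # Hypothetical import path for the not-yet-merged custom FSDP.
    from megatron.core.distributed.custom_fsdp import FullyShardedDataParallel

    HAVE_CUSTOM_FSDP = True
except ImportError:
    HAVE_CUSTOM_FSDP = False
```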

* Add copyright headers and correct code check

Signed-off-by: mingyuanm <[email protected]>

* Apply isort and black reformatting

Signed-off-by: Victor49152 <[email protected]>

* Code Scan

Signed-off-by: mingyuanm <[email protected]>

* Minor fix

Signed-off-by: mingyuanm <[email protected]>

---------

Signed-off-by: mingyuanm <[email protected]>
Signed-off-by: Victor49152 <[email protected]>
Signed-off-by: shjwudp <[email protected]>
Signed-off-by: artbataev <[email protected]>
Co-authored-by: Victor49152 <[email protected]>
Co-authored-by: jianbinc <[email protected]>
Co-authored-by: shjwudp <[email protected]>
Co-authored-by: artbataev <[email protected]>
5 people authored Jan 21, 2025
1 parent 5b4d091 commit 066e4b4
Showing 31 changed files with 3,322 additions and 384 deletions.
@@ -2,25 +2,15 @@ name: stable-diffusion-train
 
 fid:
   classifier_free_guidance:
     - 1.5
     - 2
-    - 3
-    - 4
-    - 5
-    - 6
-    - 7
-    - 8
-  nnodes_per_cfg: 1
+  nnodes_per_cfg: 2
   ntasks_per_node: 8
   local_task_id: null
   num_images_to_eval: 30000
-  coco_captions_path: /coco2014/coco2014_val_sampled_30k/captions
-  coco_images_path: /coco2014/coco2014_val/images_256
+  coco_captions_path: /datasets/coco2014/coco2014_val_sampled_30k/captions
+  coco_images_path: /datasets/coco2014/coco2014_val/images_256
   save_path: output
 
-model:
-  restore_from_path:
-  is_legacy: False
-
 use_refiner: False
 use_fp16: False # use fp16 model weights
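
For reference, the classifier_free_guidance values above are guidance scales swept during FID evaluation. Each scale s blends the unconditional and conditional model outputs; a one-line statement of the standard rule (s = 1 recovers the purely conditional prediction):

```python
import torch

def cfg_combine(eps_uncond: torch.Tensor, eps_cond: torch.Tensor, s: float) -> torch.Tensor:
    # Classifier-free guidance: push the prediction away from the
    # unconditional branch by the guidance scale s.
    return eps_uncond + s * (eps_cond - eps_uncond)
```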
@@ -88,8 +78,128 @@ sampling:
   order: 4
 
 trainer:
-  devices: ${evaluation.fid.ntasks_per_node}
+  devices: ${fid.ntasks_per_node}
   num_nodes: 1
   accelerator: gpu
   precision: 32
   logger: False # logger provided by exp_manager
+
+
+model:
+  restore_from_path: null
+  is_legacy: False
+  scale_factor: 0.13025
+  disable_first_stage_autocast: True
+
+  fsdp: False
+  fsdp_set_buffer_dtype: null
+  fsdp_sharding_strategy: 'full'
+  use_cpu_initialization: True
+
+  optim:
+    name: fused_adam
+    lr: 1e-4
+    weight_decay: 0.0
+    betas:
+      - 0.9
+      - 0.999
+    sched:
+      name: WarmupHoldPolicy
+      warmup_steps: 10
+      hold_steps: 10000000000000 # Incredibly large value to hold the lr as constant
+
+  denoiser_config:
+    _target_: nemo.collections.multimodal.modules.stable_diffusion.diffusionmodules.denoiser.DiscreteDenoiser
+    num_idx: 1000
+
+    weighting_config:
+      _target_: nemo.collections.multimodal.modules.stable_diffusion.diffusionmodules.denoiser_weighting.EpsWeighting
+    scaling_config:
+      _target_: nemo.collections.multimodal.modules.stable_diffusion.diffusionmodules.denoiser_scaling.EpsScaling
+    discretization_config:
+      _target_: nemo.collections.multimodal.modules.stable_diffusion.diffusionmodules.discretizer.LegacyDDPMDiscretization
+
+  unet_config:
+    _target_: nemo.collections.multimodal.modules.stable_diffusion.diffusionmodules.openaimodel.UNetModel
+    from_pretrained:
+    from_NeMo: True
+    adm_in_channels: 2816
+    num_classes: sequential
+    use_checkpoint: False
+    in_channels: 4
+    out_channels: 4
+    model_channels: 320
+    attention_resolutions: [ 4, 2 ]
+    num_res_blocks: 2
+    channel_mult: [ 1, 2, 4 ]
+    num_head_channels: 64
+    use_spatial_transformer: True
+    use_linear_in_transformer: True
+    transformer_depth: [ 1, 2, 10 ] # note: the first is unused (due to attn_res starting at 2) 32, 16, 8 --> 64, 32, 16
+    context_dim: 2048
+    image_size: 64 # unused
+    # spatial_transformer_attn_type: softmax #note: only default softmax is supported now
+    legacy: False
+    use_flash_attention: False
+
+  first_stage_config:
+    # _target_: nemo.collections.multimodal.models.stable_diffusion.ldm.autoencoder.AutoencoderKLInferenceWrapper
+    _target_: nemo.collections.multimodal.models.text_to_image.stable_diffusion.ldm.autoencoder.AutoencoderKLInferenceWrapper
+    from_pretrained:
+    from_NeMo: False
+    embed_dim: 4
+    monitor: val/rec_loss
+    ddconfig:
+      attn_type: vanilla
+      double_z: true
+      z_channels: 4
+      resolution: 256
+      in_channels: 3
+      out_ch: 3
+      ch: 128
+      ch_mult: [ 1, 2, 4, 4 ]
+      num_res_blocks: 2
+      attn_resolutions: [ ]
+      dropout: 0.0
+    lossconfig:
+      target: torch.nn.Identity
+
+  conditioner_config:
+    _target_: nemo.collections.multimodal.modules.stable_diffusion.encoders.modules.GeneralConditioner
+    emb_models:
+      # crossattn cond
+      - is_trainable: False
+        input_key: txt
+        emb_model:
+          _target_: nemo.collections.multimodal.modules.stable_diffusion.encoders.modules.FrozenCLIPEmbedder
+          layer: hidden
+          layer_idx: 11
+      # crossattn and vector cond
+      - is_trainable: False
+        input_key: txt
+        emb_model:
+          _target_: nemo.collections.multimodal.modules.stable_diffusion.encoders.modules.FrozenOpenCLIPEmbedder2
+          arch: ViT-bigG-14
+          version: laion2b_s39b_b160k
+          freeze: True
+          layer: penultimate
+          always_return_pooled: True
+          legacy: False
+      # vector cond
+      - is_trainable: False
+        input_key: original_size_as_tuple
+        emb_model:
+          _target_: nemo.collections.multimodal.modules.stable_diffusion.encoders.modules.ConcatTimestepEmbedderND
+          outdim: 256 # multiplied by two
+      # vector cond
+      - is_trainable: False
+        input_key: crop_coords_top_left
+        emb_model:
+          _target_: nemo.collections.multimodal.modules.stable_diffusion.encoders.modules.ConcatTimestepEmbedderND
+          outdim: 256 # multiplied by two
+      # vector cond
+      - is_trainable: False
+        input_key: target_size_as_tuple
+        emb_model:
+          _target_: nemo.collections.multimodal.modules.stable_diffusion.encoders.modules.ConcatTimestepEmbedderND
+          outdim: 256 # multiplied by two
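
The devices change in this file works because OmegaConf interpolations resolve from the config root: once fid sits at the top level (rather than under an evaluation key, as the old ${evaluation.fid...} path suggests), the reference must be ${fid.ntasks_per_node}. A minimal check of the mechanism:

```python
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "fid": {"ntasks_per_node": 8},
        "trainer": {"devices": "${fid.ntasks_per_node}"},
    }
)
assert cfg.trainer.devices == 8  # interpolation resolved against the root
```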
@@ -70,7 +70,7 @@ model:
   scale_factor: 0.13025
   disable_first_stage_autocast: True
   is_legacy: False
-  restore_from_path: ""
+  restore_from_path: null
 
   fsdp: False
   fsdp_set_buffer_dtype: null
@@ -26,8 +26,9 @@
 from nemo.core.config import hydra_runner
 
 
-@hydra_runner(config_path='conf/stable_diffusion/conf', config_name='sd_xl_fid_images')
+@hydra_runner(config_path='conf', config_name='sd_xl_fid_images')
 def main(cfg):
+    # pylint: disable=C0116
     # Read configuration parameters
     nnodes_per_cfg = cfg.fid.nnodes_per_cfg
     ntasks_per_node = cfg.fid.ntasks_per_node