
[SD] Finalized the benchmark #677

Merged · 11 commits · Aug 3, 2023
107 changes: 84 additions & 23 deletions stable_diffusion/README.md


@@ -9,7 +9,7 @@ model:
     log_every_t: 200
     timesteps: 1000
     first_stage_key: npy
-    first_stage_type: latents
+    first_stage_type: moments
     cond_stage_key: txt
     image_size: 64
     channels: 4
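Note: the `latents` → `moments` switch (together with the move to the `webdataset-moments-filtered` shards further down) means the preprocessed dataset now stores the VAE posterior's mean and log-variance rather than one pre-drawn latent, so training can draw a fresh latent sample each time an example is revisited. A minimal sketch of consuming such a "moments" tensor, modeled on ldm's `DiagonalGaussianDistribution` (the helper name and shapes are illustrative, not the repo's code):

```python
import torch

def sample_latent_from_moments(moments: torch.Tensor,
                               scale_factor: float = 0.18215) -> torch.Tensor:
    """moments: (B, 8, H, W) tensor with mean and logvar stacked on channels."""
    mean, logvar = torch.chunk(moments, 2, dim=1)  # (B, 4, H, W) each
    logvar = torch.clamp(logvar, -30.0, 20.0)      # numerical safety, as in ldm
    std = torch.exp(0.5 * logvar)
    z = mean + std * torch.randn_like(mean)        # reparameterized sample
    return scale_factor * z                        # scale_factor from this config
```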
@@ -38,7 +38,7 @@ model:
         enabled: True
         inception_weights_url: https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth
         cache_dir: /checkpoints/inception
-        gt_path: /datasets/coco2014/val2014_512x512_30k_stats.npz
+        gt_path: /datasets/coco2014/val2014_30k_stats.npz
       clip:
         enabled: True
         clip_version: "ViT-H-14"
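The FID reference file drops the `512x512` suffix, so the ground-truth statistics are no longer tied to a fixed 512×512 resize of the COCO-2014 validation images. Such an `.npz` holds the mean vector and covariance matrix of Inception features, and FID is the Fréchet distance between two feature Gaussians. A hedged sketch following pytorch-fid's conventions (the `mu`/`sigma` key names are that library's, assumed here):

```python
import numpy as np
from scipy import linalg

def fid_from_stats(mu1, sigma1, mu2, sigma2):
    """Frechet distance between two Gaussians given their moments."""
    diff = mu1 - mu2
    covmean, _ = linalg.sqrtm(sigma1 @ sigma2, disp=False)
    if np.iscomplexobj(covmean):
        covmean = covmean.real  # discard numerical noise from sqrtm
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2)
                 - 2 * np.trace(covmean))

ref = np.load("/datasets/coco2014/val2014_30k_stats.npz")
# fid = fid_from_stats(mu_gen, sigma_gen, ref["mu"], ref["sigma"])
```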
@@ -47,7 +47,7 @@ model:
     scheduler_config:
       target: ldm.lr_scheduler.LambdaLinearScheduler
       params:
-        warm_up_steps: [ 10000 ]
+        warm_up_steps: [ 1000 ]
         cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
         f_start: [ 1.e-6 ]
         f_max: [ 1. ]
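Warm-up shrinks from 10k to 1k steps. Since `f_max` and `f_min` are both 1.0, the `LambdaLinearScheduler` here is effectively a linear ramp from `f_start` up to a flat multiplier of 1.0 on `base_learning_rate`. A sketch of the schedule shape (not the repo's class):

```python
def lr_multiplier(step: int,
                  warm_up_steps: int = 1000,
                  f_start: float = 1e-6,
                  f_max: float = 1.0,
                  f_min: float = 1.0,
                  cycle_length: int = 10_000_000_000_000) -> float:
    """Multiplier applied to base_learning_rate at a given global step."""
    if step < warm_up_steps:
        return f_start + (f_max - f_start) * step / warm_up_steps
    progress = (step - warm_up_steps) / cycle_length  # ~0 for any realistic run
    return f_min + (f_max - f_min) * (1.0 - progress)
```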
@@ -110,7 +110,7 @@ data:
     train:
       target: ldm.data.webdatasets.build_dataloader
       params:
-        urls: /datasets/laion-400m/webdataset-latents-filtered/{00000..00771}.tar
+        urls: /datasets/laion-400m/webdataset-moments-filtered/{00000..00831}.tar
         batch_size: 8
         shuffle: 1000
         partial: False
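The training shards move from pre-sampled latents to moments, and the shard range grows from `{00000..00771}` to `{00000..00831}`. A hedged sketch of what `ldm.data.webdatasets.build_dataloader` plausibly wraps, using the webdataset library, which expands the brace pattern itself (the pipeline below is an assumption, not the repo's implementation):

```python
import webdataset as wds

urls = "/datasets/laion-400m/webdataset-moments-filtered/{00000..00831}.tar"
dataset = (
    wds.WebDataset(urls, shardshuffle=True)
    .shuffle(1000)            # sample-level shuffle buffer, as in the config
    .decode()                 # .npy / .txt entries decoded by file extension
    .to_tuple("npy", "txt")   # mirrors keep_only_keys: ["npy", "txt"]
)
loader = wds.WebLoader(dataset, batch_size=8, num_workers=4)
```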
@@ -138,7 +138,7 @@ lightning:
     enable_progress_bar: False
     max_epochs: -1
     max_steps: 10000000000000
-    val_check_interval: 1000 # TODO(ahmadki): final validation interval will be determined with RCPs
+    val_check_interval: 1000
     enable_checkpointing: True
     num_sanity_val_steps: 0
     strategy:
@@ -150,4 +150,4 @@ lightning:
     target: lightning.pytorch.callbacks.ModelCheckpoint
     params:
       save_top_k: -1
-      every_n_train_steps: 2000
+      every_n_train_steps: 1000
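Checkpointing tightens from every 2000 to every 1000 steps, matching `val_check_interval`, so every validation round lines up with a saved checkpoint. In Lightning terms, these keys configure roughly the following (a sketch assuming standard `lightning.pytorch` APIs):

```python
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint

ckpt = ModelCheckpoint(save_top_k=-1, every_n_train_steps=1000)  # keep every checkpoint
trainer = Trainer(
    accelerator="gpu",
    devices=8,
    precision=16,
    max_steps=10_000_000_000_000,
    val_check_interval=1000,   # validate every 1000 training steps
    callbacks=[ckpt],
)
```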
153 changes: 153 additions & 0 deletions stable_diffusion/configs/train_32x08x02.yaml
@@ -0,0 +1,153 @@
+model:
+  base_learning_rate: 1.25e-7
+  target: ldm.models.diffusion.ddpm.LatentDiffusion
+  params:
+    parameterization: "v"
+    linear_start: 0.00085
+    linear_end: 0.0120
+    num_timesteps_cond: 1
+    log_every_t: 200
+    timesteps: 1000
+    first_stage_key: npy
+    first_stage_type: moments
+    cond_stage_key: txt
+    image_size: 64
+    channels: 4
+    cond_stage_trainable: false
+    conditioning_key: crossattn
+    monitor: steps
+    scale_factor: 0.18215
+    use_ema: False
+
+    load_vae: True
+    load_unet: False
+    load_encoder: True
+
+    validation_config:
+      sampler: "ddim" # plms, ddim, dpm
+      steps: 50
+      scale: 8.0
+      ddim_eta: 0.0
+      prompt_key: "caption"
+      image_fname_key: "image_id"
+
+      save_images:
+        enabled: False
+        base_output_dir: "/results/inference"
+      fid:
+        enabled: True
+        inception_weights_url: https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth
+        cache_dir: /checkpoints/inception
+        gt_path: /datasets/coco2014/val2014_30k_stats.npz
+      clip:
+        enabled: True
+        clip_version: "ViT-H-14"
+        cache_dir: /checkpoints/clip
+
+    scheduler_config:
+      target: ldm.lr_scheduler.LambdaLinearScheduler
+      params:
+        warm_up_steps: [ 1000 ]
+        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
+        f_start: [ 1.e-6 ]
+        f_max: [ 1. ]
+        f_min: [ 1. ]
+
+    unet_config:
+      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
+      params:
+        use_checkpoint: False # gradient checkpointing
+        use_fp16: True
+        image_size: 32
+        in_channels: 4
+        out_channels: 4
+        model_channels: 320
+        attention_resolutions: [ 4, 2, 1 ]
+        num_res_blocks: 2
+        channel_mult: [ 1, 2, 4, 4 ]
+        num_head_channels: 64 # need to fix for flash-attn
+        use_spatial_transformer: True
+        use_linear_in_transformer: True
+        transformer_depth: 1
+        context_dim: 1024
+        legacy: False
+
+    first_stage_config:
+      target: ldm.models.autoencoder.AutoencoderKL
+      params:
+        embed_dim: 4
+        monitor: val/rec_loss
+        ddconfig:
+          double_z: true
+          z_channels: 4
+          resolution: 256
+          in_channels: 3
+          out_ch: 3
+          ch: 128
+          ch_mult:
+          - 1
+          - 2
+          - 4
+          - 4
+          num_res_blocks: 2
+          attn_resolutions: []
+          dropout: 0.0
+        lossconfig:
+          target: torch.nn.Identity
+
+    cond_stage_config:
+      target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
+      params:
+        arch: "ViT-H-14"
+        version: "laion2b_s32b_b79k"
+        freeze: True
+        layer: "penultimate"
+        cache_dir: /checkpoints/clip
+
+data:
+  target: ldm.data.composable_data_module.ComposableDataModule
+  params:
+    train:
+      target: ldm.data.webdatasets.build_dataloader
+      params:
+        urls: /datasets/laion-400m/webdataset-moments-filtered/{00000..00831}.tar
+        batch_size: 2
+        shuffle: 1000
+        partial: False
+        keep_only_keys: ["npy", "txt"]
+        num_workers: 4
+        persistent_workers: True
+
+    validation:
+      target: ldm.data.tsv.build_dataloader
+      params:
+        annotations_file: "/datasets/coco2014/val2014_30k.tsv"
+        keys: ["image_id", "id", "caption"]
+        batch_size: 8
+        shuffle: False
+        num_workers: 1
+
+lightning:
+  trainer:
+    accelerator: 'gpu'
+    num_nodes: 32
+    devices: 8
+    precision: 16
+    logger: False
+    log_every_n_steps: 5
+    enable_progress_bar: False
+    max_epochs: -1
+    max_steps: 10000000000000
+    val_check_interval: 1000
+    enable_checkpointing: True
+    num_sanity_val_steps: 0
+    strategy:
+      target: strategies.DDPStrategy
+      params:
+        find_unused_parameters: False
+
+  modelcheckpoint:
+    target: lightning.pytorch.callbacks.ModelCheckpoint
+    params:
+      save_top_k: -1
+      every_n_train_steps: 1000
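The filename encodes the launch topology: 32 nodes × 8 GPUs × 2 samples per GPU. A quick arithmetic check of the resulting global batch, plus the effective learning rate under ldm's usual convention of multiplying `base_learning_rate` by the global batch size (whether this benchmark applies that scaling is an assumption to verify against the repo):

```python
num_nodes, devices, local_batch = 32, 8, 2        # from train_32x08x02.yaml
global_batch = num_nodes * devices * local_batch  # 512
base_lr = 1.25e-7
effective_lr = base_lr * global_batch             # 6.4e-05 under ldm-style scaling
print(global_batch, effective_lr)
```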
@@ -38,7 +38,7 @@ model:
         enabled: True
         inception_weights_url: https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth
         cache_dir: /checkpoints/inception
-        gt_path: /datasets/coco2014/val2014_512x512_30k_stats.npz
+        gt_path: /datasets/coco2014/val2014_30k_stats.npz
       clip:
         enabled: True
         clip_version: "ViT-H-14"
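For the CLIP metric, `ViT-H-14` is the open_clip model name; a CLIP score is the cosine similarity between image and caption embeddings. A hedged sketch of the computation (standard open_clip calls; how this repo batches and aggregates is not shown here):

```python
import torch
import open_clip

model, _, preprocess = open_clip.create_model_and_transforms(
    "ViT-H-14", pretrained="laion2b_s32b_b79k", cache_dir="/checkpoints/clip")
tokenizer = open_clip.get_tokenizer("ViT-H-14")

@torch.no_grad()
def clip_score(pil_image, caption: str) -> float:
    image = preprocess(pil_image).unsqueeze(0)
    text = tokenizer([caption])
    img_f = model.encode_image(image)
    txt_f = model.encode_text(text)
    img_f = img_f / img_f.norm(dim=-1, keepdim=True)  # unit-normalize
    txt_f = txt_f / txt_f.norm(dim=-1, keepdim=True)
    return float((img_f * txt_f).sum())               # cosine similarity
```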
@@ -47,7 +47,7 @@ model:
     scheduler_config:
       target: ldm.lr_scheduler.LambdaLinearScheduler
       params:
-        warm_up_steps: [ 10000 ]
+        warm_up_steps: [ 1000 ]
         cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
         f_start: [ 1.e-6 ]
         f_max: [ 1. ]
@@ -110,8 +110,8 @@ data:
     train:
       target: ldm.data.webdatasets.build_dataloader
       params:
-        urls: /datasets/laion-400m/webdataset-filtered/{00000..00771}.tar
-        batch_size: 8
+        urls: /datasets/laion-400m/webdataset-filtered/{00000..00831}.tar
+        batch_size: 2
         shuffle: 1000
         partial: False
         decode: pil
@@ -141,15 +141,15 @@ data:
 lightning:
   trainer:
     accelerator: 'gpu'
-    num_nodes: 1
+    num_nodes: 32
     devices: 8
     precision: 16
     logger: False
     log_every_n_steps: 5
     enable_progress_bar: False
     max_epochs: -1
     max_steps: 10000000000000
-    val_check_interval: 1000 # TODO(ahmadki): final validation interval will be determined with RCPs
+    val_check_interval: 1000
     enable_checkpointing: True
     num_sanity_val_steps: 0
     strategy:
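Scaling from 1 node to 32 makes DDP overhead visible; `find_unused_parameters: False` skips the per-step scan for unreached parameters. Assuming `strategies.DDPStrategy` wraps Lightning's strategy of the same name, it amounts to:

```python
from lightning.pytorch.strategies import DDPStrategy

strategy = DDPStrategy(find_unused_parameters=False)  # no unused-parameter search each step
```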
@@ -161,4 +161,4 @@ lightning:
     target: lightning.pytorch.callbacks.ModelCheckpoint
     params:
       save_top_k: -1
-      every_n_train_steps: 2000
+      every_n_train_steps: 1000