diff --git a/stable_diffusion/README.md b/stable_diffusion/README.md
index 605bbeaa5..d7b983538 100644
--- a/stable_diffusion/README.md
+++ b/stable_diffusion/README.md
@@ -1,5 +1,7 @@
Text to image: Stable Diffusion (SD)
+![Stable Diffusion overview](imgs/overview.png)
+
Welcome to the reference implementation for the MLPerf text to image
benchmark, utilizing the Stable Diffusion (SD) model.
Our repository prioritizes transparency, reproducibility, reliability,
@@ -31,6 +33,7 @@ understand the fundamentals of the Stable Diffusion model.
- [FID](#fid)
- [CLIP](#clip)
- [Reference runs](#reference-runs)
+- [Rules](#rules)
- [BibTeX](#bibtex)
# Getting started
@@ -96,18 +99,14 @@ The benchmark employs two datasets:
1. Validation: a subset of [coco-2014 validation](https://cocodataset.org/#download)
### Laion 400m
-**TODO(ahmadki): This README presumes that the training dataset is Laion-400m. However, the final dataset choice will be decided at the time of the RCP submission.**
-
The benchmark uses a CC-BY licensed subset of the Laion-400m dataset.
The LAION datasets comprise lists of URLs for the original images, paired with the ALT text linked to those images. Because downloading millions of images from the internet is not a deterministic process, and to ensure the replicability of the benchmark results, submitters are asked to download the subset from the MLCommons storage. The dataset is provided in two formats:
-**TODO(ahmadki): The scripts will be added once the dataset is uploaded to the MLCommons storage.**
-
-1. Preprocessed latents (recommended):`scripts/datasets/download_laion400m-ccby-latents.sh --output-dir /datasets/laion-400m/ccby_latents_512x512`
-2. Raw images: `scripts/datasets/download_laion400m-ccby-images.sh --output-dir /datasets/laion-400m/ccby_images`
+1. Preprocessed moments (recommended): `scripts/datasets/laion400m-filtered-download-moments.sh --output-dir /datasets/laion-400m/webdataset-moments-filtered`
+2. Raw images: `scripts/datasets/laion400m-filtered-download-images.sh --output-dir /datasets/laion-400m/webdataset-filtered`
-While the benchmark code is compatible with both formats, we recommend using the preprocessed latents to save on computational resources.
+While the benchmark code is compatible with both formats, we recommend using the preprocessed moments to save on computational resources.
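+
+As a rough illustration of the moments format, the shards can be iterated directly with the [webdataset](https://github.com/webdataset/webdataset) library. This is a minimal sketch, not the benchmark dataloader (`ldm.data.webdatasets.build_dataloader`); the shard range matches the download script above:
+
+```python
+import webdataset as wds
+
+# Each sample stores the VAE moments under the "npy" key and the caption under "txt".
+dataset = (
+    wds.WebDataset("/datasets/laion-400m/webdataset-moments-filtered/{00000..00831}.tar")
+    .decode()
+    .to_tuple("npy", "txt")
+)
+moments, caption = next(iter(dataset))
+print(moments.shape, caption)
+```
+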
For additional information about Laion 400m, the CC-BY subset, and the scripts used for downloading, filtering, and preprocessing the images, refer to the laion-400m section [here](#the-laion400m-subset).
@@ -117,7 +116,8 @@ The COCO-2014-validation dataset consists of 40,504 images and 202,654 annotatio
To ensure reproducibility, we ask the submitters to download the relevant files from the MLCommons storage:
```bash
-scripts/datasets/download_coco-2014.sh --output-dir /datasets/coco2014/val2014-sd
+scripts/datasets/coco2014-validation-download-prompts.sh --output-dir /datasets/coco2014
+scripts/datasets/coco2014-validation-download-stats.sh --output-dir /datasets/coco2014
```
While the benchmark code can compute the validation statistics from raw images, we recommend using the precomputed statistics (`val2014_30k_stats.npz`) to save on computational resources.
@@ -134,7 +134,7 @@ The benchmark utilizes several network architectures for both the training and v
```bash
scripts/checkpoints/download_sd.sh --output-dir /checkpoints/sd
```
-2. **Inception**: The Inception network is employed during validation to compute the Frechet Inception Distance (FID) score. The necessary weights can be acquired with the following command:
+2. **Inception**: The Inception network is employed during validation to compute the Fréchet Inception Distance (FID) score. The necessary weights can be downloaded with the following command:
```bash
scripts/checkpoints/download_inception.sh --output-dir /checkpoints/inception
```
@@ -158,19 +158,19 @@ To initiate a single-node training run, execute the following command from withi
--gpus-per-node 8 \
--checkpoint /checkpoints/sd/512-base-ema.ckpt \
--results-dir /results \
- --config configs/train_512_latents.yaml
+ --config configs/train_01x08x08.yaml
```
-If you prefer to train using raw images, consider utilizing the `configs/train_512.yaml` configuration file.
+If you prefer to train on raw images, use the `configs/train_32x08x02_raw_images.yaml` configuration file instead.
### Multi-node (with SLURM)
Given the extended duration it typically takes to train the Stable Diffusion model, it's often beneficial to employ multiple nodes for expedited training. For this purpose, we provide rudimentary Slurm scripts to submit multi-node training batch jobs. Use the following command to submit a batch job:
```bash
scripts/slurm/sbatch.sh \
- --num-nodes 8 \
+ --num-nodes 32 \
--gpus-per-node 8 \
--checkpoint /checkpoints/sd/512-base-ema.ckpt \
- --config configs/train_512_latents.yaml \
- --results-dir configs/train_512_latents.yaml \
+    --config configs/train_32x08x02.yaml \
+    --results-dir /results \
--container mlperf/stable_diffusion
```
@@ -178,21 +178,19 @@ Given the substantial variability among Slurm clusters, users are encouraged to
In any case, the dataset and checkpoints are expected to be available to all the nodes.
-
# Benchmark details
## The datasets
-**TODO(ahmadki): Please note that Laion-400m is being used as a placeholder; the final decision regarding the training dataset has not yet been made.**
### Laion 400m
[Laion-400m](https://laion.ai/blog/laion-400-open-dataset/) is a rich dataset of 400 million image-text pairs, crafted by the Laion project. The benchmark uses a relatively small subset of this dataset, approximately 6.1M images, all under a CC-BY license.
-To establish a fair benchmark and assure reproducibility of results, we request that submitters download either the preprocessed latents or raw images from the MLCommons storage using the scripts provided [here](#laion-400m). These images and latents were generated by following these steps:
+To establish a fair benchmark and assure reproducibility of results, we request that submitters download either the preprocessed moments or raw images from the MLCommons storage using the scripts provided [here](#laion-400m). These images and moments were generated by following these steps:
1. Download the metadata: `scripts/datasets/laion400m-download-metadata.sh --output-dir /datasets/laion-400m/metadata`
2. Filter the metadata based on LICENSE information: `scripts/datasets/laion400m-filter-metadata.sh --input-metadata-dir /datasets/laion-400m/metadata --output-metadata-dir /datasets/laion-400m/metadata-filtered`
3. Download the filtered subset: `scripts/datasets/laion400m-download-dataset --metadata-dir /datasets/laion-400m/metadata-filtered --output-dir /datasets/laion-400m/webdataset-filtered`
-4. Preprocess the images to latents `scripts/datasets/laion400m-convert-images-to-latents.sh --input-folder /datasets/laion-400m/webdataset-filtered --output-dir /datasets/laion-400m/webdataset-latents-filtered`
+4. Preprocess the images into moments: `scripts/datasets/laion400m-convert-images-to-moments.sh --input-folder /datasets/laion-400m/webdataset-filtered --output-dir /datasets/laion-400m/webdataset-moments-filtered` (see the sketch below)
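+
+For intuition, step 4 pushes every image through the frozen VAE encoder and stores the resulting Gaussian moments (concatenated mean and log-variance channels) instead of a sampled latent. A minimal sketch of the per-image operation, assuming a loaded `AutoencoderKL` instance `model` and a 512x512 image tensor `img` scaled to `[-1, 1]`:
+
+```python
+import numpy as np
+import torch
+
+with torch.no_grad():
+    # encoder + quant_conv: (1, 3, 512, 512) -> (1, 8, 64, 64),
+    # i.e. 4 mean and 4 log-variance channels
+    moments = model.moments(img.unsqueeze(0))
+np.save("sample.npy", moments.cpu().numpy())
+```
+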
### COCO 2014
@@ -210,12 +208,15 @@ We achieved this subset by following these steps:
Stable Diffusion v2 is a latent diffusion model which combines an autoencoder with a diffusion model that is trained in the latent space of the autoencoder. During training:
* Images are encoded through an encoder, which turns images into latent representations. The autoencoder uses a relative downsampling factor of 8 and maps images of shape 512 x 512 x 3 to latents of shape 64 x 64 x 4
-* Text prompts are encoded through the OpenCLIP-ViT/H text-encoder.
+* Text prompts are encoded through the OpenCLIP-ViT/H text encoder; the output embedding vectors have a length of 1024.
* The output of the text encoder is fed into the UNet backbone of the latent diffusion model via cross-attention.
* The loss is a reconstruction objective between the noise that was added to the latent and the prediction made by the UNet. We also use the so-called v-objective, see https://arxiv.org/abs/2202.00512.
The UNet backbone in our model serves as the sole trainable component, initialized from random weights. Conversely, the weights of both the image and text encoders are loaded from a pre-existing checkpoint and kept static throughout the training procedure.
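+
+For concreteness, the v-objective regression target can be sketched as follows. This is illustrative only; `alphas_cumprod` stands for the cumulative product of the noise schedule, and the actual loss lives in `ldm/models/diffusion/ddpm.py`:
+
+```python
+import torch
+
+def v_target(x0, noise, t, alphas_cumprod):
+    """v = sqrt(a_t) * eps - sqrt(1 - a_t) * x0 (Salimans & Ho, 2022)."""
+    a_t = alphas_cumprod[t].view(-1, 1, 1, 1)
+    return a_t.sqrt() * noise - (1.0 - a_t).sqrt() * x0
+
+# The training loss is the MSE between the UNet prediction and this target:
+# loss = F.mse_loss(unet(noisy_latent, t, text_emb), v_target(x0, noise, t, acp))
+```
+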
+Although our benchmark aims to adhere to the original Stable Diffusion v2 implementation as closely as possible, it's important to note some key deviations:
+1. The group norm layers in the UNet in our code use 16 groups instead of the 32 used in the original implementation; see the snippet below. This adjustment can be found in our code at this [link](https://github.com/ahmadki/training/blob/master/stable_diffusion/ldm/modules/diffusionmodules/util.py#L209).
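+
+A minimal, hypothetical illustration of the change (the channel count is arbitrary):
+
+```python
+import torch.nn as nn
+
+channels = 320
+norm_original = nn.GroupNorm(num_groups=32, num_channels=channels)   # SD v2
+norm_benchmark = nn.GroupNorm(num_groups=16, num_channels=channels)  # this benchmark
+```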
+
### UNet
TODO(ahmadki): give an overview
### VAE
@@ -224,12 +225,72 @@ TODO(ahmadki): give an overview
TODO(ahmadki): give an overview
## Validation metrics
### FID
-TODO(ahmadki): give an overview
+FID is a measure of similarity between two datasets of images. It was shown to correlate well with human judgement of visual quality and is most often used to evaluate the quality of samples from generative models. FID is calculated by computing the [Fréchet distance](https://en.wikipedia.org/wiki/Fr%C3%A9chet_distance) between two Gaussians fitted to feature representations of the Inception network. A lower FID implies better image quality.
+
+Further insights and an independent evaluation of the FID score can be found in [Are GANs Created Equal? A Large-Scale Study](https://arxiv.org/abs/1711.10337).
+
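+A hedged sketch of the computation, assuming the precomputed COCO statistics follow the pytorch-fid convention of storing `mu` and `sigma` in the `.npz` file:
+
+```python
+import numpy as np
+from scipy import linalg
+
+def frechet_distance(mu1, sigma1, mu2, sigma2):
+    # FID = ||mu1 - mu2||^2 + Tr(s1 + s2 - 2 * sqrt(s1 @ s2))
+    diff = mu1 - mu2
+    covmean = linalg.sqrtm(sigma1 @ sigma2).real
+    return diff @ diff + np.trace(sigma1 + sigma2 - 2.0 * covmean)
+
+ref = np.load("/datasets/coco2014/val2014_30k_stats.npz")
+# mu_gen / sigma_gen would be fitted to Inception features of the generated images:
+# fid = frechet_distance(mu_gen, sigma_gen, ref["mu"], ref["sigma"])
+```
+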
### CLIP
-TODO(ahmadki): give an overview
+CLIP is a reference-free metric that evaluates the correlation between a caption and the actual content of the image it describes; it has been found to be highly correlated with human judgement. A higher CLIP score implies that the caption matches the image more closely.
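+
+As a rough sketch of how such a score can be obtained with [open_clip](https://github.com/mlfoundations/open_clip) (the benchmark configs use ViT-H-14 with the laion2b_s32b_b79k weights; details such as logit scaling may differ from the benchmark's implementation):
+
+```python
+import torch
+import open_clip
+
+model, _, preprocess = open_clip.create_model_and_transforms(
+    "ViT-H-14", pretrained="laion2b_s32b_b79k")
+tokenizer = open_clip.get_tokenizer("ViT-H-14")
+
+@torch.no_grad()
+def clip_score(pil_image, caption):
+    img = model.encode_image(preprocess(pil_image).unsqueeze(0))
+    txt = model.encode_text(tokenizer([caption]))
+    # cosine similarity between the normalized image and text embeddings
+    img = img / img.norm(dim=-1, keepdim=True)
+    txt = txt / txt.norm(dim=-1, keepdim=True)
+    return (img * txt).sum().item()
+```
+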
# Reference runs
-TODO(ahmadki): with RCPs
+The benchmark is expected to have the following convergence profile. Config names follow a `<nodes>x<GPUs per node>x<batch size per GPU>` pattern; for example, `train_32x08x02.yaml` runs on 32 nodes with 8 GPUs each and 2 images per GPU, for a global batch size of 512 (hence 512,000 images per 1000 iterations):
+
+Using `configs/train_32x08x02.yaml`:
+
+| | | Run 1 | | Run 2 | | Run 3 | | Run 4 | | Run 5 | | Run 6 | | Run 7 | | Run 8 | | Run 9 | | Run 10 | | Run 11 | | Run 12 | | Run 13 | | Run 14 | |
+|---------------|-----------|------------------|------------------|------------------|-------------------|------------------|------------------|------------------|------------------|------------------|-----------------|------------------|-------------------|------------------|-------------------|------------------|-------------------|------------------|-------------------|------------------|-------------------|------------------|------------------|------------------|-------------------|------------------|-------------------|------------------|------------------|
+| # Iterations | # Images | FID | CLIP | FID | CLIP | FID | CLIP | FID | CLIP | FID | CLIP | FID | CLIP | FID | CLIP | FID | CLIP | FID | CLIP | FID | CLIP | FID | CLIP | FID | CLIP | FID | CLIP | FID | CLIP |
+| 1000 | 512000 | 240.307753187222 | 0.07037353515625 | 294.573046200695 | 0.042327880859375 | 241.683981119201 | 0.0513916015625 | 244.3093813921 | 0.051513671875 | 223.315586246326 | 0.0750732421875 | 226.574970212965 | 0.057525634765625 | 251.340068446753 | 0.052337646484375 | 248.620622416789 | 0.059051513671875 | 239.375032509416 | 0.052093505859375 | 235.714879858368 | 0.060089111328125 | 216.554181677026 | 0.07000732421875 | 254.65308599743 | 0.047882080078125 | 246.05928942625 | 0.050201416015625 | 245.287937367888 | 0.06011962890625 |
+| 2000 | 1024000 | 160.470078270219 | 0.095458984375 | 151.759951633367 | 0.11041259765625 | 153.183910033727 | 0.10162353515625 | 179.805348652356 | 0.08294677734375 | 149.675345136843 | 0.11279296875 | 187.266315858632 | 0.0858154296875 | 151.388620255932 | 0.11322021484375 | 146.280853966833 | 0.10528564453125 | 152.773085803242 | 0.1033935546875 | 166.303450203846 | 0.0897216796875 | 157.262003859026 | 0.10089111328125 | 163.671906693434 | 0.1072998046875 | 150.59871834863 | 0.10504150390625 | 168.392019999909 | 0.10430908203125 |
+| 3000 | 1536000 | 116.104976258075 | 0.1297607421875 | 104.82406006546 | 0.1458740234375 | 113.28653970832 | 0.13720703125 | 108.181491001522 | 0.1358642578125 | 120.569476481641 | 0.1312255859375 | 121.507471335783 | 0.1385498046875 | 112.246994199521 | 0.134765625 | 118.687092400973 | 0.1241455078125 | 104.768844637884 | 0.1370849609375 | 109.305393022732 | 0.1436767578125 | 108.186119966125 | 0.1424560546875 | 114.106200291533 | 0.130615234375 | 110.985516643455 | 0.1348876953125 | 128.967376625536 | 0.11767578125 |
+| 4000 | 2048000 | 98.6194527639151 | 0.1490478515625 | 93.3023418189487 | 0.1614990234375 | 98.470530998531 | 0.147705078125 | 104.506561457867 | 0.1497802734375 | 117.53438644301 | 0.1346435546875 | 93.9551013202581 | 0.156494140625 | 96.7522546886262 | 0.14990234375 | 93.6238405768778 | 0.14892578125 | 99.2660537095259 | 0.14697265625 | 92.4074720353253 | 0.15283203125 | 102.79009009653 | 0.15576171875 | 106.295253230265 | 0.1375732421875 | 106.89872594564 | 0.150390625 | 101.269312343302 | 0.1561279296875 |
+| 5000 | 2560000 | 77.9144665349304 | 0.1700439453125 | 84.3912640535151 | 0.1715087890625 | 74.3932278711953 | 0.1763916015625 | 81.8694462285154 | 0.1590576171875 | 90.8364642085304 | 0.1585693359375 | 85.7692721578059 | 0.1630859375 | 84.7254078568678 | 0.1646728515625 | 82.5745049916586 | 0.1634521484375 | 88.3845314314294 | 0.1591796875 | 76.2159562712626 | 0.1741943359375 | 86.4943102965646 | 0.170654296875 | 89.1185837064392 | 0.1600341796875 | 85.7690462394573 | 0.1629638671875 | 83.403400754612 | 0.162841796875 |
+| 6000 | 3072000 | 78.277784436582 | 0.175537109375 | 70.7521662926858 | 0.1824951171875 | 69.5446726528529 | 0.1822509765625 | 78.5287888562515 | 0.172119140625 | 73.100411596327 | 0.1802978515625 | 75.7605175664866 | 0.173583984375 | 71.9332734632299 | 0.177734375 | 78.7097765331263 | 0.1685791015625 | 74.2826761370574 | 0.173583984375 | 76.546361057528 | 0.1759033203125 | 77.3434226893306 | 0.1763916015625 | 78.8033732806207 | 0.1737060546875 | 73.3409144967632 | 0.177978515625 | 76.4804577865017 | 0.1829833984375 |
+| 7000 | 3584000 | 61.9517434630711 | 0.1998291015625 | 65.8841791315944 | 0.1937255859375 | 59.758534754348 | 0.195068359375 | 68.2849480866907 | 0.1873779296875 | 67.5589398554567 | 0.1978759765625 | 66.3181053517182 | 0.1839599609375 | 63.4926142518813 | 0.190673828125 | 61.9170226262262 | 0.1905517578125 | 69.5521157432934 | 0.184814453125 | 71.102816810754 | 0.1768798828125 | 68.2832394942013 | 0.19091796875 | 65.4174348723552 | 0.1873779296875 | 63.2371374279834 | 0.19287109375 | 66.3072686144305 | 0.1904296875 |
+| 8000 | 4096000 | 57.6910232633396 | 0.19580078125 | 63.9510625822693 | 0.192626953125 | 62.0458245490626 | 0.1903076171875 | 63.3404859303251 | 0.1915283203125 | 68.1056529551962 | 0.1993408203125 | 66.3588662355703 | 0.19384765625 | 61.5250891904713 | 0.1998291015625 | 61.9636640921626 | 0.1943359375 | 55.759511792272 | 0.196533203125 | 59.9278142839585 | 0.2008056640625 | 59.5022074715584 | 0.20458984375 | 62.5511721230049 | 0.1959228515625 | 56.9497804269414 | 0.195556640625 | 61.5091127213477 | 0.190185546875 |
+| 9000 | 4608000 | 64.6372463854434 | 0.1925048828125 | 63.3690279269714 | 0.1973876953125 | 58.1806763833258 | 0.203369140625 | 62.998614892463 | 0.19189453125 | 62.1201758106838 | 0.2103271484375 | 64.2889779711819 | 0.190673828125 | 63.9357612486334 | 0.1953125 | 54.3882152248573 | 0.2060546875 | 56.0454761991163 | 0.1982421875 | 71.4737059559803 | 0.1845703125 | 60.5208865396143 | 0.1966552734375 | 59.6999421353521 | 0.1986083984375 | 55.3276841950671 | 0.19970703125 | 56.4976270377332 | 0.2032470703125 |
+| 10000 | 5120000 | 59.762565549819 | 0.2086181640625 | 55.4643195068765 | 0.21044921875 | 53.6647322693999 | 0.2054443359375 | 59.3463334480937 | 0.1990966796875 | 54.4330239809199 | 0.2135009765625 | 55.0087139920183 | 0.203857421875 | 57.1277597552926 | 0.204833984375 | 49.239847599091 | 0.2100830078125 | 53.05835413967 | 0.2088623046875 | 55.1455458758192 | 0.2081298828125 | 53.4885730532123 | 0.207763671875 | 58.6086353300705 | 0.2054443359375 | 54.4873456830018 | 0.197265625 | 57.2894338303807 | 0.2022705078125 |
+
+Using `configs/train_32x08x04.yaml`:
+
+| | | Run 1 | | Run 2 | | Run 3 | | Run 4 | | Run 5 | | Run 6 | | Run 7 | | Run 8 | | Run 9 | | Run 10 | | Run 11 | | Run 12 | | Run 13 | |
+|---------------|-----------|------------------|------------------|------------------|-------------------|------------------|------------------|------------------|-------------------|------------------|------------------|------------------|-------------------|------------------|------------------|------------------|-------------------|------------------|-------------------|------------------|--------------------|------------------|-------------------|------------------|------------------|------------------|-------------------|
+| # Iterations | # Images | FID | CLIP | FID | CLIP | FID | CLIP | FID | CLIP | FID | CLIP | FID | CLIP | FID | CLIP | FID | CLIP | FID | CLIP | FID | CLIP | FID | CLIP | FID | CLIP | FID | CLIP |
+| 500 | 512000 | 340.118601986399 | 0.0345458984375 | 345.832024994334 | 0.04046630859375 | 338.866903071458 | 0.05609130859375 | 426.617864143804 | 0.034149169921875 | 334.872061080559 | 0.0396728515625 | 439.960494194387 | 0.041351318359375 | 376.763327218143 | 0.0313720703125 | 342.790895799001 | 0.036529541015625 | 547.18671253061 | 0.039031982421875 | 351.895078815387 | 0.0277252197265625 | 343.04031513556 | 0.035980224609375 | 328.479876923655 | 0.0419921875 | 321.099746318961 | 0.0372314453125 |
+| 1000 | 1024000 | 223.961567606484 | 0.07330322265625 | 232.917300402437 | 0.057159423828125 | 226.63936355015 | 0.06903076171875 | 230.062675246233 | 0.06732177734375 | 193.949092294986 | 0.084716796875 | 226.70889007593 | 0.07720947265625 | 214.400308061983 | 0.081298828125 | 190.005066809593 | 0.12176513671875 | 205.746223579365 | 0.0953369140625 | 224.200208510141 | 0.068115234375 | 211.447498723363 | 0.0802001953125 | 238.545257722353 | 0.057373046875 | 235.478069634173 | 0.056060791015625 |
+| 1500 | 1536000 | 151.10737628716 | 0.11016845703125 | 150.101113881509 | 0.11566162109375 | 134.800112693793 | 0.12103271484375 | 158.111621138263 | 0.1094970703125 | 166.836301701497 | 0.10260009765625 | 149.410719271839 | 0.12164306640625 | 158.375092356384 | 0.10687255859375 | 147.525295062812 | 0.11199951171875 | 127.031851421568 | 0.114990234375 | 158.962418201722 | 0.11505126953125 | 144.255068287832 | 0.113525390625 | 164.992543548194 | 0.10479736328125 | 141.811754285011 | 0.123779296875 |
+| 2000 | 2048000 | 124.49736091336 | 0.12841796875 | 117.097835152435 | 0.1282958984375 | 108.370974667191 | 0.152587890625 | 133.81907104026 | 0.1280517578125 | 121.229882811864 | 0.1253662109375 | 110.412332150214 | 0.1422119140625 | 100.932808149808 | 0.1419677734375 | 119.578608307018 | 0.133544921875 | 118.33252832813 | 0.133544921875 | 112.354203644191 | 0.134765625 | 117.135931831381 | 0.14013671875 | 124.541325095207 | 0.1229248046875 | 96.3941054845615 | 0.14990234375 |
+| 2500 | 2560000 | 96.9733204361302 | 0.1611328125 | 98.295106421797 | 0.1546630859375 | 89.1922536875995 | 0.1680908203125 | 96.6641498310134 | 0.1571044921875 | 93.8952999670028 | 0.1611328125 | 102.492835331058 | 0.1624755859375 | 89.9220320598469 | 0.1568603515625 | 90.985351002816 | 0.157958984375 | 88.848829608497 | 0.1697998046875 | 94.0665058388316 | 0.15673828125 | 87.9920662633317 | 0.1614990234375 | 96.7143357056873 | 0.148193359375 | 88.7038772396644 | 0.1605224609375 |
+| 3000 | 3072000 | 93.3597904033127 | 0.1673583984375 | 88.0347940287123 | 0.1611328125 | 78.7300792641628 | 0.17431640625 | 73.0465712965473 | 0.1798095703125 | 86.9431468184751 | 0.1639404296875 | 87.7669614240029 | 0.1722412109375 | 88.3675919682403 | 0.171630859375 | 82.6417540012699 | 0.1787109375 | 86.7623871245797 | 0.17138671875 | 79.4699853861112 | 0.170654296875 | 72.7502962971125 | 0.177001953125 | 83.9019958654259 | 0.169677734375 | 85.3945042513396 | 0.1646728515625 |
+| 3500 | 3584000 | 85.1014052323429 | 0.1712646484375 | 79.6465945046047 | 0.17724609375 | 73.179691682848 | 0.1766357421875 | 72.2186363144172 | 0.17529296875 | 82.2117505009101 | 0.16357421875 | 77.0382476200993 | 0.1776123046875 | 76.2638780498772 | 0.1785888671875 | 82.374727338282 | 0.174560546875 | 73.9219581307538 | 0.181640625 | 75.2863090474383 | 0.1729736328125 | 76.0410092939919 | 0.1761474609375 | 73.0823340750989 | 0.1776123046875 | 68.0820682944884 | 0.1822509765625 |
+| 4000 | 4096000 | 67.0934321304885 | 0.18408203125 | 73.6632478577287 | 0.179931640625 | 65.7745415726648 | 0.1893310546875 | 57.5382044168439 | 0.202392578125 | 70.3029895611878 | 0.178955078125 | 63.5395270557031 | 0.1856689453125 | 65.7229965733364 | 0.189697265625 | 63.5817674444506 | 0.194580078125 | 68.9505671827645 | 0.19384765625 | 65.7320230343007 | 0.187255859375 | 60.5971122060028 | 0.2027587890625 | 64.2448592538921 | 0.1851806640625 | 69.7957092926321 | 0.194580078125 |
+| 4500 | 4608000 | 67.7145070185802 | 0.1905517578125 | 61.2478434208809 | 0.1939697265625 | 60.257862121379 | 0.1939697265625 | 56.0916643007543 | 0.195556640625 | 63.112236374844 | 0.1904296875 | 61.7597270678272 | 0.1923828125 | 68.7368374504532 | 0.185791015625 | 61.8152633671387 | 0.197265625 | 60.5004491149336 | 0.2005615234375 | 62.5574345222566 | 0.1959228515625 | 62.923525204544 | 0.1995849609375 | 62.0852679830577 | 0.194580078125 | 61.2746769203445 | 0.199951171875 |
+| 5000 | 5120000 | 62.8701763107206 | 0.194580078125 | 58.8702801515778 | 0.197265625 | 55.8112946371526 | 0.201904296875 | 58.6735053693167 | 0.1932373046875 | 65.8271859495193 | 0.1854248046875 | 56.2070200654653 | 0.1990966796875 | 60.6945204351451 | 0.1988525390625 | 56.4405121502396 | 0.2056884765625 | 61.6967046304992 | 0.19482421875 | 59.4338652228432 | 0.1962890625 | 55.3596165721984 | 0.2091064453125 | 67.6589663911387 | 0.1868896484375 | 53.2199604152237 | 0.2021484375 |
+| 5500 | 5632000 | 61.4585472059858 | 0.1983642578125 | 58.2543886841241 | 0.1995849609375 | 50.8558883425227 | 0.2071533203125 | 51.5331037617053 | 0.2041015625 | 57.3083074689686 | 0.19482421875 | 52.3925368121613 | 0.2099609375 | 57.2835125038846 | 0.200439453125 | 53.6868530762218 | 0.208984375 | 52.2227233111886 | 0.206787109375 | 61.3110736465047 | 0.1966552734375 | 51.6719495239864 | 0.2080078125 | 65.9607437239914 | 0.1888427734375 | 54.6698202033624 | 0.206787109375 |
+
+Using `configs/train_32x08x08.yaml`:
+
+| | | Run 1 | | Run 2 | | Run 3 | | Run 4 | | Run 5 | | Run 6 | | Run 7 | | Run 8 | | Run 9 | | Run 10 | | Run 11 | | Run 12 | | Run 13 | |
+|---------------|----------|------------------|-------------------|------------------|-------------------|------------------|-------------------|------------------|-------------------|------------------|-------------------|------------------|-------------------|------------------|-------------------|------------------|-------------------|------------------|-------------------|------------------|-------------------|------------------|------------------|------------------|-------------------|------------------|--------------------|
+| # Iterations | # Images | FID | CLIP | FID | CLIP | FID | CLIP | FID | CLIP | FID | CLIP | FID | CLIP | FID | CLIP | FID | CLIP | FID | CLIP | FID | CLIP | FID | CLIP | FID | CLIP | FID | CLIP |
+| 250 | 512000 | 391.39655901198 | 0.048370361328125 | 383.955095156789 | 0.045623779296875 | 375.572664708181 | 0.040283203125 | 372.072899672844 | 0.043426513671875 | 372.901651361397 | 0.034698486328125 | 353.118691841468 | 0.04925537109375 | 342.393139376256 | 0.054412841796875 | 311.78591733597 | 0.058197021484375 | 354.174052398155 | 0.051666259765625 | 374.390968238394 | 0.044097900390625 | 458.052244334111 | 0.03900146484375 | 513.581603241846 | 0.041412353515625 | 428.950254618589 | 0.032989501953125 |
+| 500 | 1024000 | 279.186983006922 | 0.045806884765625 | 251.32372064638 | 0.0732421875 | 380.401122648586 | 0.029754638671875 | 265.254012443236 | 0.044219970703125 | 305.165371161357 | 0.043701171875 | 278.005709943431 | 0.041412353515625 | 326.05919184743 | 0.03851318359375 | 285.249900410478 | 0.051910400390625 | 268.047528775676 | 0.06549072265625 | 310.318767239424 | 0.03607177734375 | 268.197820895841 | 0.05615234375 | 276.718579547372 | 0.05157470703125 | 307.082416305895 | 0.0285797119140625 |
+| 750 | 1536000 | 233.095664289828 | 0.08306884765625 | 255.574797209144 | 0.061859130859375 | 206.772130562528 | 0.0755615234375 | 221.581035890856 | 0.08935546875 | 245.255653191484 | 0.06085205078125 | 234.983598562262 | 0.058563232421875 | 250.570216717967 | 0.060821533203125 | 259.234382045129 | 0.07977294921875 | 249.22752060747 | 0.05804443359375 | 239.86518048089 | 0.06610107421875 | 219.41761236929 | 0.07098388671875 | 293.547429078577 | 0.05621337890625 | 240.056633721395 | 0.061737060546875 |
+| 1000 | 2048000 | 183.214691596374 | 0.09649658203125 | 175.121451173839 | 0.0919189453125 | 201.125958260278 | 0.07904052734375 | 194.480361994534 | 0.10003662109375 | 187.902157253089 | 0.095947265625 | 190.636908095649 | 0.093505859375 | 173.836128053516 | 0.11016845703125 | 198.839618312897 | 0.09100341796875 | 191.57493488944 | 0.09283447265625 | 204.476347827797 | 0.08734130859375 | 190.85390278202 | 0.11077880859375 | 177.090880284724 | 0.099853515625 | 172.779613417181 | 0.0966796875 |
+| 1250 | 2560000 | 162.777532690807 | 0.1082763671875 | 147.760423274479 | 0.1129150390625 | 140.85537974249 | 0.12017822265625 | 137.166923984101 | 0.132568359375 | 151.973048193423 | 0.10992431640625 | 148.929026349028 | 0.12249755859375 | 155.895933010689 | 0.12353515625 | 129.288711675567 | 0.137451171875 | 132.886776251458 | 0.125 | 145.39008142017 | 0.12457275390625 | 137.668940540157 | 0.1220703125 | 156.968006561967 | 0.11407470703125 | 142.581989214077 | 0.1314697265625 |
+| 1500 | 3072000 | 118.629682612961 | 0.1375732421875 | 106.535941444941 | 0.14697265625 | 114.290550919822 | 0.138916015625 | 108.378854902796 | 0.149169921875 | 98.9089744692479 | 0.1510009765625 | 103.611036790713 | 0.1544189453125 | 131.286037142359 | 0.12109375 | 103.39877184704 | 0.144775390625 | 117.09422144333 | 0.1375732421875 | 105.877016093857 | 0.1439208984375 | 118.102154582684 | 0.1405029296875 | 103.335599019344 | 0.157470703125 | 113.055096586663 | 0.1434326171875 |
+| 1750 | 3584000 | 89.4331758025722 | 0.162841796875 | 109.546171022477 | 0.1474609375 | 93.1588864410366 | 0.1578369140625 | 83.5085170500488 | 0.1685791015625 | 94.4197067704322 | 0.1588134765625 | 92.6207403376717 | 0.162841796875 | 114.245383656395 | 0.1517333984375 | 87.25296374262 | 0.1685791015625 | 100.343798309662 | 0.1514892578125 | 87.2525657582866 | 0.1697998046875 | 95.2678809672573 | 0.15966796875 | 101.659175216742 | 0.161376953125 | 94.0385862436482 | 0.15966796875 |
+| 2000 | 4096000 | 84.1395342891986 | 0.1728515625 | 86.4941692827467 | 0.171875 | 74.1707863081292 | 0.173583984375 | 87.1637452390254 | 0.1680908203125 | 78.2313975218325 | 0.17431640625 | 83.3830756736107 | 0.172607421875 | 109.383279454652 | 0.1553955078125 | 82.4171094681747 | 0.174560546875 | 87.7106252823979 | 0.1656494140625 | 82.7069497909496 | 0.1707763671875 | 76.3407924798623 | 0.183837890625 | 83.0637053735624 | 0.175048828125 | 94.0479392718985 | 0.1612548828125 |
+| 2250 | 4608000 | 70.582001057074 | 0.18505859375 | 74.8823190242059 | 0.181396484375 | 71.1209497479827 | 0.1898193359375 | 70.0543828502346 | 0.184326171875 | 71.8003769855694 | 0.185302734375 | 64.2308369345424 | 0.1934814453125 | 76.4148499308404 | 0.174560546875 | 73.2825963775967 | 0.1800537109375 | 74.3825432407552 | 0.17919921875 | 75.9861941189723 | 0.1817626953125 | 74.1085235586791 | 0.179443359375 | 74.5693447018001 | 0.175537109375 | 77.284322467554 | 0.17822265625 |
+| 2500 | 5120000 | 65.3838288644697 | 0.1983642578125 | 87.5473119370954 | 0.1688232421875 | 67.7603158381713 | 0.1884765625 | 62.3535779250927 | 0.1925048828125 | 60.8354202339956 | 0.2000732421875 | 78.4751488606406 | 0.1842041015625 | 68.3657367892087 | 0.185791015625 | 65.0020280359325 | 0.1942138671875 | 67.1935556265445 | 0.1871337890625 | 62.4738416640282 | 0.198974609375 | 64.389505383101 | 0.1939697265625 | 77.7334283478027 | 0.1776123046875 | 68.7218599839536 | 0.1864013671875 |
+| 2750 | 5632000 | 59.4060543096938 | 0.1959228515625 | 74.741528787536 | 0.182861328125 | 60.1765143149466 | 0.20068359375 | 58.8862731872101 | 0.197021484375 | 60.4208051048209 | 0.19775390625 | 56.8841221646688 | 0.1993408203125 | 70.8894837982713 | 0.1861572265625 | 71.2384122739707 | 0.1981201171875 | 65.8023543269634 | 0.1947021484375 | 59.4144150742882 | 0.1959228515625 | 58.3607462687932 | 0.204833984375 | 59.8219103754054 | 0.1953125 | 60.7754410616672 | 0.2003173828125 |
+| 3000 | 6144000 | 56.6601142266372 | 0.2022705078125 | 68.2707963466305 | 0.1937255859375 | 53.0671317239887 | 0.2049560546875 | 52.5991947034437 | 0.20751953125 | 56.1652745581564 | 0.2054443359375 | 61.7744352999075 | 0.20458984375 | 56.3335107682475 | 0.20654296875 | 64.8515262336457 | 0.193359375 | 67.671189122359 | 0.1947021484375 | 60.6994318650713 | 0.19970703125 | 63.2553342895905 | 0.198974609375 | 63.5162905880337 | 0.195068359375 | 56.9187600001671 | 0.20361328125 |
+
+
+
+# Rules
+The benchmark rules can be found [here](https://github.com/mlcommons/training_policies/blob/master/training_rules.adoc).
+
# BibTeX
```
@@ -278,4 +339,4 @@ TODO(ahmadki): with RCPs
doi = {10.5281/zenodo.5143773},
url = {https://doi.org/10.5281/zenodo.5143773}
}
-```
\ No newline at end of file
+```
diff --git a/stable_diffusion/configs/train_512_latents.yaml b/stable_diffusion/configs/train_01x08x08.yaml
similarity index 91%
rename from stable_diffusion/configs/train_512_latents.yaml
rename to stable_diffusion/configs/train_01x08x08.yaml
index 207e46ef7..73db364bf 100644
--- a/stable_diffusion/configs/train_512_latents.yaml
+++ b/stable_diffusion/configs/train_01x08x08.yaml
@@ -9,7 +9,7 @@ model:
log_every_t: 200
timesteps: 1000
first_stage_key: npy
- first_stage_type: latents
+ first_stage_type: moments
cond_stage_key: txt
image_size: 64
channels: 4
@@ -38,7 +38,7 @@ model:
enabled: True
inception_weights_url: https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth
cache_dir: /checkpoints/inception
- gt_path: /datasets/coco2014/val2014_512x512_30k_stats.npz
+ gt_path: /datasets/coco2014/val2014_30k_stats.npz
clip:
enabled: True
clip_version: "ViT-H-14"
@@ -47,7 +47,7 @@ model:
scheduler_config:
target: ldm.lr_scheduler.LambdaLinearScheduler
params:
- warm_up_steps: [ 10000 ]
+ warm_up_steps: [ 1000 ]
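+      # Linear warm-up from f_start to f_max over the first 1000 steps; with f_min = f_max = 1 the multiplier then stays constant.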
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
f_start: [ 1.e-6 ]
f_max: [ 1. ]
@@ -110,7 +110,7 @@ data:
train:
target: ldm.data.webdatasets.build_dataloader
params:
- urls: /datasets/laion-400m/webdataset-latents-filtered/{00000..00771}.tar
+ urls: /datasets/laion-400m/webdataset-moments-filtered/{00000..00831}.tar
batch_size: 8
shuffle: 1000
partial: False
@@ -138,7 +138,7 @@ lightning:
enable_progress_bar: False
max_epochs: -1
max_steps: 10000000000000
- val_check_interval: 1000 # TODO(ahmadki): final validation interval will be determined with RCPs
+ val_check_interval: 1000
enable_checkpointing: True
num_sanity_val_steps: 0
strategy:
@@ -150,4 +150,4 @@ lightning:
target: lightning.pytorch.callbacks.ModelCheckpoint
params:
save_top_k: -1
- every_n_train_steps: 2000
+ every_n_train_steps: 1000
diff --git a/stable_diffusion/configs/train_32x08x02.yaml b/stable_diffusion/configs/train_32x08x02.yaml
new file mode 100644
index 000000000..aaa4a4681
--- /dev/null
+++ b/stable_diffusion/configs/train_32x08x02.yaml
@@ -0,0 +1,153 @@
+model:
+ base_learning_rate: 1.25e-7
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
+ params:
+ parameterization: "v"
+ linear_start: 0.00085
+ linear_end: 0.0120
+ num_timesteps_cond: 1
+ log_every_t: 200
+ timesteps: 1000
+ first_stage_key: npy
+ first_stage_type: moments
+ cond_stage_key: txt
+ image_size: 64
+ channels: 4
+ cond_stage_trainable: false
+ conditioning_key: crossattn
+ monitor: steps
+ scale_factor: 0.18215
+ use_ema: False
+
+ load_vae: True
+ load_unet: False
+ load_encoder: True
+
+ validation_config:
+ sampler: "ddim" # plms, ddim, dpm
+ steps: 50
+ scale: 8.0
+ ddim_eta: 0.0
+ prompt_key: "caption"
+ image_fname_key: "image_id"
+
+ save_images:
+ enabled: False
+ base_output_dir: "/results/inference"
+ fid:
+ enabled: True
+ inception_weights_url: https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth
+ cache_dir: /checkpoints/inception
+ gt_path: /datasets/coco2014/val2014_30k_stats.npz
+ clip:
+ enabled: True
+ clip_version: "ViT-H-14"
+ cache_dir: /checkpoints/clip
+
+ scheduler_config:
+ target: ldm.lr_scheduler.LambdaLinearScheduler
+ params:
+ warm_up_steps: [ 1000 ]
+ cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
+ f_start: [ 1.e-6 ]
+ f_max: [ 1. ]
+ f_min: [ 1. ]
+
+ unet_config:
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
+ params:
+ use_checkpoint: False # gradient checkpointing
+ use_fp16: True
+ image_size: 32
+ in_channels: 4
+ out_channels: 4
+ model_channels: 320
+ attention_resolutions: [ 4, 2, 1 ]
+ num_res_blocks: 2
+ channel_mult: [ 1, 2, 4, 4 ]
+ num_head_channels: 64 # need to fix for flash-attn
+ use_spatial_transformer: True
+ use_linear_in_transformer: True
+ transformer_depth: 1
+ context_dim: 1024
+ legacy: False
+
+ first_stage_config:
+ target: ldm.models.autoencoder.AutoencoderKL
+ params:
+ embed_dim: 4
+ monitor: val/rec_loss
+ ddconfig:
+ double_z: true
+ z_channels: 4
+ resolution: 256
+ in_channels: 3
+ out_ch: 3
+ ch: 128
+ ch_mult:
+ - 1
+ - 2
+ - 4
+ - 4
+ num_res_blocks: 2
+ attn_resolutions: []
+ dropout: 0.0
+ lossconfig:
+ target: torch.nn.Identity
+
+ cond_stage_config:
+ target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
+ params:
+ arch: "ViT-H-14"
+ version: "laion2b_s32b_b79k"
+ freeze: True
+ layer: "penultimate"
+ cache_dir: /checkpoints/clip
+
+data:
+ target: ldm.data.composable_data_module.ComposableDataModule
+ params:
+ train:
+ target: ldm.data.webdatasets.build_dataloader
+ params:
+ urls: /datasets/laion-400m/webdataset-moments-filtered/{00000..00831}.tar
+ batch_size: 2
+ shuffle: 1000
+ partial: False
+ keep_only_keys: ["npy", "txt"]
+ num_workers: 4
+ persistent_workers: True
+
+ validation:
+ target: ldm.data.tsv.build_dataloader
+ params:
+ annotations_file: "/datasets/coco2014/val2014_30k.tsv"
+ keys: ["image_id", "id", "caption"]
+ batch_size: 8
+ shuffle: False
+ num_workers: 1
+
+lightning:
+ trainer:
+ accelerator: 'gpu'
+ num_nodes: 32
+ devices: 8
+ precision: 16
+ logger: False
+ log_every_n_steps: 5
+ enable_progress_bar: False
+ max_epochs: -1
+ max_steps: 10000000000000
+ val_check_interval: 1000
+ enable_checkpointing: True
+ num_sanity_val_steps: 0
+ strategy:
+ target: strategies.DDPStrategy
+ params:
+ find_unused_parameters: False
+
+ modelcheckpoint:
+ target: lightning.pytorch.callbacks.ModelCheckpoint
+ params:
+ save_top_k: -1
+ every_n_train_steps: 1000
diff --git a/stable_diffusion/configs/train_512.yaml b/stable_diffusion/configs/train_32x08x02_raw_images.yaml
similarity index 93%
rename from stable_diffusion/configs/train_512.yaml
rename to stable_diffusion/configs/train_32x08x02_raw_images.yaml
index 47bd44e61..582147325 100644
--- a/stable_diffusion/configs/train_512.yaml
+++ b/stable_diffusion/configs/train_32x08x02_raw_images.yaml
@@ -38,7 +38,7 @@ model:
enabled: True
inception_weights_url: https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth
cache_dir: /checkpoints/inception
- gt_path: /datasets/coco2014/val2014_512x512_30k_stats.npz
+ gt_path: /datasets/coco2014/val2014_30k_stats.npz
clip:
enabled: True
clip_version: "ViT-H-14"
@@ -47,7 +47,7 @@ model:
scheduler_config:
target: ldm.lr_scheduler.LambdaLinearScheduler
params:
- warm_up_steps: [ 10000 ]
+ warm_up_steps: [ 1000 ]
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
f_start: [ 1.e-6 ]
f_max: [ 1. ]
@@ -110,8 +110,8 @@ data:
train:
target: ldm.data.webdatasets.build_dataloader
params:
- urls: /datasets/laion-400m/webdataset-filtered/{00000..00771}.tar
- batch_size: 8
+ urls: /datasets/laion-400m/webdataset-filtered/{00000..00831}.tar
+ batch_size: 2
shuffle: 1000
partial: False
decode: pil
@@ -141,7 +141,7 @@ data:
lightning:
trainer:
accelerator: 'gpu'
- num_nodes: 1
+ num_nodes: 32
devices: 8
precision: 16
logger: False
@@ -149,7 +149,7 @@ lightning:
enable_progress_bar: False
max_epochs: -1
max_steps: 10000000000000
- val_check_interval: 1000 # TODO(ahmadki): final validation interval will be determined with RCPs
+ val_check_interval: 1000
enable_checkpointing: True
num_sanity_val_steps: 0
strategy:
@@ -161,4 +161,4 @@ lightning:
target: lightning.pytorch.callbacks.ModelCheckpoint
params:
save_top_k: -1
- every_n_train_steps: 2000
+ every_n_train_steps: 1000
diff --git a/stable_diffusion/configs/train_32x08x04.yaml b/stable_diffusion/configs/train_32x08x04.yaml
new file mode 100644
index 000000000..747aa3d50
--- /dev/null
+++ b/stable_diffusion/configs/train_32x08x04.yaml
@@ -0,0 +1,153 @@
+model:
+ base_learning_rate: 1.25e-7
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
+ params:
+ parameterization: "v"
+ linear_start: 0.00085
+ linear_end: 0.0120
+ num_timesteps_cond: 1
+ log_every_t: 200
+ timesteps: 1000
+ first_stage_key: npy
+ first_stage_type: moments
+ cond_stage_key: txt
+ image_size: 64
+ channels: 4
+ cond_stage_trainable: false
+ conditioning_key: crossattn
+ monitor: steps
+ scale_factor: 0.18215
+ use_ema: False
+
+ load_vae: True
+ load_unet: False
+ load_encoder: True
+
+ validation_config:
+ sampler: "ddim" # plms, ddim, dpm
+ steps: 50
+ scale: 8.0
+ ddim_eta: 0.0
+ prompt_key: "caption"
+ image_fname_key: "image_id"
+
+ save_images:
+ enabled: False
+ base_output_dir: "/results/inference"
+ fid:
+ enabled: True
+ inception_weights_url: https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth
+ cache_dir: /checkpoints/inception
+ gt_path: /datasets/coco2014/val2014_30k_stats.npz
+ clip:
+ enabled: True
+ clip_version: "ViT-H-14"
+ cache_dir: /checkpoints/clip
+
+ scheduler_config:
+ target: ldm.lr_scheduler.LambdaLinearScheduler
+ params:
+ warm_up_steps: [ 1000 ]
+ cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
+ f_start: [ 1.e-6 ]
+ f_max: [ 1. ]
+ f_min: [ 1. ]
+
+ unet_config:
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
+ params:
+ use_checkpoint: False # gradient checkpointing
+ use_fp16: True
+ image_size: 32
+ in_channels: 4
+ out_channels: 4
+ model_channels: 320
+ attention_resolutions: [ 4, 2, 1 ]
+ num_res_blocks: 2
+ channel_mult: [ 1, 2, 4, 4 ]
+ num_head_channels: 64 # need to fix for flash-attn
+ use_spatial_transformer: True
+ use_linear_in_transformer: True
+ transformer_depth: 1
+ context_dim: 1024
+ legacy: False
+
+ first_stage_config:
+ target: ldm.models.autoencoder.AutoencoderKL
+ params:
+ embed_dim: 4
+ monitor: val/rec_loss
+ ddconfig:
+ double_z: true
+ z_channels: 4
+ resolution: 256
+ in_channels: 3
+ out_ch: 3
+ ch: 128
+ ch_mult:
+ - 1
+ - 2
+ - 4
+ - 4
+ num_res_blocks: 2
+ attn_resolutions: []
+ dropout: 0.0
+ lossconfig:
+ target: torch.nn.Identity
+
+ cond_stage_config:
+ target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
+ params:
+ arch: "ViT-H-14"
+ version: "laion2b_s32b_b79k"
+ freeze: True
+ layer: "penultimate"
+ cache_dir: /checkpoints/clip
+
+data:
+ target: ldm.data.composable_data_module.ComposableDataModule
+ params:
+ train:
+ target: ldm.data.webdatasets.build_dataloader
+ params:
+ urls: /datasets/laion-400m/webdataset-moments-filtered/{00000..00831}.tar
+ batch_size: 4
+ shuffle: 1000
+ partial: False
+ keep_only_keys: ["npy", "txt"]
+ num_workers: 4
+ persistent_workers: True
+
+ validation:
+ target: ldm.data.tsv.build_dataloader
+ params:
+ annotations_file: "/datasets/coco2014/val2014_30k.tsv"
+ keys: ["image_id", "id", "caption"]
+ batch_size: 8
+ shuffle: False
+ num_workers: 1
+
+lightning:
+ trainer:
+ accelerator: 'gpu'
+ num_nodes: 32
+ devices: 8
+ precision: 16
+ logger: False
+ log_every_n_steps: 5
+ enable_progress_bar: False
+ max_epochs: -1
+ max_steps: 10000000000000
+ val_check_interval: 500
+ enable_checkpointing: True
+ num_sanity_val_steps: 0
+ strategy:
+ target: strategies.DDPStrategy
+ params:
+ find_unused_parameters: False
+
+ modelcheckpoint:
+ target: lightning.pytorch.callbacks.ModelCheckpoint
+ params:
+ save_top_k: -1
+ every_n_train_steps: 500
diff --git a/stable_diffusion/configs/train_32x08x08.yaml b/stable_diffusion/configs/train_32x08x08.yaml
new file mode 100644
index 000000000..166a5deed
--- /dev/null
+++ b/stable_diffusion/configs/train_32x08x08.yaml
@@ -0,0 +1,153 @@
+model:
+ base_learning_rate: 1.25e-7
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
+ params:
+ parameterization: "v"
+ linear_start: 0.00085
+ linear_end: 0.0120
+ num_timesteps_cond: 1
+ log_every_t: 200
+ timesteps: 1000
+ first_stage_key: npy
+ first_stage_type: moments
+ cond_stage_key: txt
+ image_size: 64
+ channels: 4
+ cond_stage_trainable: false
+ conditioning_key: crossattn
+ monitor: steps
+ scale_factor: 0.18215
+ use_ema: False
+
+ load_vae: True
+ load_unet: False
+ load_encoder: True
+
+ validation_config:
+ sampler: "ddim" # plms, ddim, dpm
+ steps: 50
+ scale: 8.0
+ ddim_eta: 0.0
+ prompt_key: "caption"
+ image_fname_key: "image_id"
+
+ save_images:
+ enabled: False
+ base_output_dir: "/results/inference"
+ fid:
+ enabled: True
+ inception_weights_url: https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth
+ cache_dir: /checkpoints/inception
+ gt_path: /datasets/coco2014/val2014_30k_stats.npz
+ clip:
+ enabled: True
+ clip_version: "ViT-H-14"
+ cache_dir: /checkpoints/clip
+
+ scheduler_config:
+ target: ldm.lr_scheduler.LambdaLinearScheduler
+ params:
+ warm_up_steps: [ 1000 ]
+ cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
+ f_start: [ 1.e-6 ]
+ f_max: [ 1. ]
+ f_min: [ 1. ]
+
+ unet_config:
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
+ params:
+ use_checkpoint: False # gradient checkpointing
+ use_fp16: True
+ image_size: 32
+ in_channels: 4
+ out_channels: 4
+ model_channels: 320
+ attention_resolutions: [ 4, 2, 1 ]
+ num_res_blocks: 2
+ channel_mult: [ 1, 2, 4, 4 ]
+ num_head_channels: 64 # need to fix for flash-attn
+ use_spatial_transformer: True
+ use_linear_in_transformer: True
+ transformer_depth: 1
+ context_dim: 1024
+ legacy: False
+
+ first_stage_config:
+ target: ldm.models.autoencoder.AutoencoderKL
+ params:
+ embed_dim: 4
+ monitor: val/rec_loss
+ ddconfig:
+ double_z: true
+ z_channels: 4
+ resolution: 256
+ in_channels: 3
+ out_ch: 3
+ ch: 128
+ ch_mult:
+ - 1
+ - 2
+ - 4
+ - 4
+ num_res_blocks: 2
+ attn_resolutions: []
+ dropout: 0.0
+ lossconfig:
+ target: torch.nn.Identity
+
+ cond_stage_config:
+ target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
+ params:
+ arch: "ViT-H-14"
+ version: "laion2b_s32b_b79k"
+ freeze: True
+ layer: "penultimate"
+ cache_dir: /checkpoints/clip
+
+data:
+ target: ldm.data.composable_data_module.ComposableDataModule
+ params:
+ train:
+ target: ldm.data.webdatasets.build_dataloader
+ params:
+ urls: /datasets/laion-400m/webdataset-moments-filtered/{00000..00831}.tar
+ batch_size: 8
+ shuffle: 1000
+ partial: False
+ keep_only_keys: ["npy", "txt"]
+ num_workers: 4
+ persistent_workers: True
+
+ validation:
+ target: ldm.data.tsv.build_dataloader
+ params:
+ annotations_file: "/datasets/coco2014/val2014_30k.tsv"
+ keys: ["image_id", "id", "caption"]
+ batch_size: 8
+ shuffle: False
+ num_workers: 1
+
+lightning:
+ trainer:
+ accelerator: 'gpu'
+ num_nodes: 32
+ devices: 8
+ precision: 16
+ logger: False
+ log_every_n_steps: 5
+ enable_progress_bar: False
+ max_epochs: -1
+ max_steps: 10000000000000
+ val_check_interval: 250
+ enable_checkpointing: True
+ num_sanity_val_steps: 0
+ strategy:
+ target: strategies.DDPStrategy
+ params:
+ find_unused_parameters: False
+
+ modelcheckpoint:
+ target: lightning.pytorch.callbacks.ModelCheckpoint
+ params:
+ save_top_k: -1
+ every_n_train_steps: 250
diff --git a/stable_diffusion/imgs/overview.png b/stable_diffusion/imgs/overview.png
new file mode 100644
index 000000000..d64012ee2
Binary files /dev/null and b/stable_diffusion/imgs/overview.png differ
diff --git a/stable_diffusion/ldm/data/webdatasets.py b/stable_diffusion/ldm/data/webdatasets.py
index 6bfdc5b56..bd1d7d77d 100644
--- a/stable_diffusion/ldm/data/webdatasets.py
+++ b/stable_diffusion/ldm/data/webdatasets.py
@@ -6,6 +6,11 @@
from ldm.util import instantiate_from_config
from ldm.data.utils import instantiate_transforms_from_config, identity, keys_filter
+from PIL import Image
+
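+# Disable PIL's decompression-bomb safeguard; some Laion images exceed the default pixel limit.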
+Image.MAX_IMAGE_PIXELS = None
+
+
def build_dataloader(
urls,
batch_size,
diff --git a/stable_diffusion/ldm/models/autoencoder.py b/stable_diffusion/ldm/models/autoencoder.py
index b1bd83778..4b65a12d9 100644
--- a/stable_diffusion/ldm/models/autoencoder.py
+++ b/stable_diffusion/ldm/models/autoencoder.py
@@ -83,9 +83,13 @@ def on_train_batch_end(self, *args, **kwargs):
if self.use_ema:
self.model_ema(self)
- def encode(self, x):
+ def moments(self, x):
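+        # Returns the concatenated mean and log-variance channels of the latent Gaussian.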
h = self.encoder(x)
moments = self.quant_conv(h)
+ return moments
+
+ def encode(self, x):
+ moments = self.moments(x)
posterior = DiagonalGaussianDistribution(moments)
return posterior
diff --git a/stable_diffusion/ldm/models/diffusion/ddpm.py b/stable_diffusion/ldm/models/diffusion/ddpm.py
index bc25cb25c..13da5acdb 100644
--- a/stable_diffusion/ldm/models/diffusion/ddpm.py
+++ b/stable_diffusion/ldm/models/diffusion/ddpm.py
@@ -192,11 +192,6 @@ def __init__(self,
self.validation_clip_scores = []
if validation_config is not None:
- self.validation_save_images = validation_config["save_images"]["enabled"]
- self.validation_run_fid = validation_config["fid"]["enabled"]
- self.validation_run_clip = validation_config["clip"]["enabled"]
-
- if self.validation_save_images or self.validation_run_fid or self.validation_run_clip:
if validation_config["sampler"] == "plms":
self.sampler = PLMSSampler(self)
elif validation_config["sampler"] == "dpm":
@@ -205,13 +200,17 @@ def __init__(self,
self.sampler = DDIMSampler(self)
else:
raise NotImplementedError(f"Sampler {self.sampler} not yet supported")
-
- self.validation_scale = validation_config["scale"]
self.validation_sampler_steps = validation_config["steps"]
+ self.validation_scale = validation_config["scale"]
self.validation_ddim_eta = validation_config["ddim_eta"]
self.prompt_key = validation_config["prompt_key"]
self.image_fname_key = validation_config["image_fname_key"]
+ self.validation_save_images = validation_config["save_images"]["enabled"]
+ self.validation_run_fid = validation_config["fid"]["enabled"]
+ self.validation_run_clip = validation_config["clip"]["enabled"]
+
+ if self.validation_save_images or self.validation_run_fid or self.validation_run_clip:
if self.validation_save_images:
self.validation_base_output_dir = validation_config["save_images"]["base_output_dir"]
@@ -1009,7 +1008,15 @@ def get_input(self,
if bs is not None:
x = x[:bs]
x = x.to(self.device)
- encoder_posterior = self.encode_first_stage(x) if self.first_stage_type == "images" else x
+
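+        # The first stage input is either raw images (encoded on the fly),
+        # precomputed latents, or precomputed moments: mean/log-variance pairs
+        # from the VAE encoder, from which a latent is sampled below.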
+ if self.first_stage_type == "images":
+ encoder_posterior = self.encode_first_stage(x)
+ elif self.first_stage_type == "latents":
+ encoder_posterior = x
+ elif self.first_stage_type == "moments":
+ x = torch.squeeze(x, dim=1)
+ encoder_posterior = DiagonalGaussianDistribution(x)
+
z = self.get_first_stage_encoding(encoder_posterior).detach()
if self.model.conditioning_key is not None and not self.force_null_conditioning:
@@ -1071,6 +1078,10 @@ def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
def encode_first_stage(self, x):
return self.first_stage_model.encode(x)
+ @torch.no_grad()
+ def moments_first_stage(self, x):
+ return self.first_stage_model.moments(x)
+
def shared_step(self, batch, **kwargs):
x, c = self.get_input(batch, self.first_stage_key)
loss = self(x, c)
diff --git a/stable_diffusion/main.py b/stable_diffusion/main.py
index 0bd290e58..d7e984f2e 100644
--- a/stable_diffusion/main.py
+++ b/stable_diffusion/main.py
@@ -136,14 +136,14 @@ def str2bool(v):
parser.add_argument(
"--fid_threshold",
type=int,
- default=None, # TODO(ahmadki): set after finzliaing RCPs
+ default=90,
help="halt training once this FID validation score or a smaller one is achieved."
"if used with --clip_threshold, both metrics need to reach their targets.",
)
parser.add_argument(
"--clip_threshold",
    type=float,
- default=None, # TODO(ahmadki): set after finzliaing RCPs
+ default=0.15,
help="halt training once this CLIP validation score or a higher one is achieved."
"if used with --fid_threshold, both metrics need to reach their targets.",
)
@@ -450,7 +450,7 @@ def on_train_epoch_end(self, trainer, pl_module):
mllogger.event(key=mllog_constants.SEED, value=opt.seed)
seed_everything(opt.seed)
- # Intinalize and save configuratioon using teh OmegaConf library.
+    # Initialize and save the configuration using the OmegaConf library.
try:
# init and save configs
configs = [OmegaConf.load(cfg) for cfg in opt.base]
diff --git a/stable_diffusion/requirements.txt b/stable_diffusion/requirements.txt
index edc5b8bd3..f4cc1a909 100644
--- a/stable_diffusion/requirements.txt
+++ b/stable_diffusion/requirements.txt
@@ -24,4 +24,4 @@ cloudpathlib==0.13.0
git+https://github.com/facebookresearch/xformers.git@5eb0dbf315d14b5f7b38ac2ff3d8379beca7df9b#egg=xformers
bitsandbytes==0.37.2
# TODO(ahmadki): use github.com:mlcommons/logging.git once the SD PR is merged
-git+https://github.com/ahmadki/logging.git@561b08ef9fa213478d7b78518108c8d5dbf6afc2
+git+https://github.com/mlcommons/logging.git@8405a08bbfc724f8888c419461c02d55a6ac960c
diff --git a/stable_diffusion/scripts/datasets/coco2014-validation-download-prompts.sh b/stable_diffusion/scripts/datasets/coco2014-validation-download-prompts.sh
new file mode 100755
index 000000000..9fd4b01e1
--- /dev/null
+++ b/stable_diffusion/scripts/datasets/coco2014-validation-download-prompts.sh
@@ -0,0 +1,15 @@
+#!/usr/bin/env bash
+
+: "${OUTPUT_DIR:=/datasets/coco2014}"
+
+while [ "$1" != "" ]; do
+ case $1 in
+ -o | --output-dir ) shift
+ OUTPUT_DIR=$1
+ ;;
+ esac
+ shift
+done
+
+mkdir -p ${OUTPUT_DIR}
+wget -O ${OUTPUT_DIR}/val2014_30k.tsv -c "https://cloud.mlcommons.org/index.php/s/training_stable_diffusion/download?path=/datasets/coco2014&files=val2014_30k.tsv"
diff --git a/stable_diffusion/scripts/datasets/coco2014-validation-download-stats.sh b/stable_diffusion/scripts/datasets/coco2014-validation-download-stats.sh
new file mode 100755
index 000000000..bb3222ee3
--- /dev/null
+++ b/stable_diffusion/scripts/datasets/coco2014-validation-download-stats.sh
@@ -0,0 +1,15 @@
+#!/usr/bin/env bash
+
+: "${OUTPUT_DIR:=/datasets/coco2014}"
+
+while [ "$1" != "" ]; do
+ case $1 in
+ -o | --output-dir ) shift
+ OUTPUT_DIR=$1
+ ;;
+ esac
+ shift
+done
+
+mkdir -p ${OUTPUT_DIR}
+wget -O ${OUTPUT_DIR}/val2014_30k_stats.npz -c "https://cloud.mlcommons.org/index.php/s/training_stable_diffusion/download?path=/datasets/coco2014&files=val2014_30k_stats.npz"
diff --git a/stable_diffusion/scripts/datasets/laion-aesthetic-download-dataset.sh b/stable_diffusion/scripts/datasets/laion-aesthetic-download-dataset.sh
deleted file mode 100755
index 9dc8334ea..000000000
--- a/stable_diffusion/scripts/datasets/laion-aesthetic-download-dataset.sh
+++ /dev/null
@@ -1,40 +0,0 @@
-#!/usr/bin/env bash
-
-: "${NPROCS:=16}"
-: "${NTHREADS:=64}"
-: "${METADATA_DIR:=/datasets/laion2B-en-aesthetic/metadata}"
-: "${OUTPUT_DIR:=/datasets/laion2B-en-aesthetic/webdataset}"
-
-while [ "$1" != "" ]; do
- case $1 in
- -j | --processes ) shift
- NPROCS=$1
- ;;
- -t | --threads ) shift
- NTHREADS=$1
- ;;
- -m | --metadata-dir ) shift
- METADATA_DIR=$1
- ;;
- -o | --output-dir ) shift
- OUTPUT_DIR=$1
- ;;
- esac
- shift
-done
-
-mkdir -p ${OUTPUT_DIR}
-
-img2dataset \
- --url_list ${METADATA_DIR} \
- --input_format "parquet" \
- --url_col "URL" \
- --caption_col "TEXT" \
- --output_format webdataset \
- --output_folder ${OUTPUT_DIR} \
- --processes_count ${NPROCS} \
- --thread_count ${NTHREADS} \
- --incremental_mode "incremental" \
- --resize_mode "no" \
- --save_additional_columns '["similarity","hash","punsafe","pwatermark","aesthetic"]' \
- --enable_wandb False
diff --git a/stable_diffusion/scripts/datasets/laion-aesthetic-download-metadata.sh b/stable_diffusion/scripts/datasets/laion-aesthetic-download-metadata.sh
deleted file mode 100755
index 68242e34e..000000000
--- a/stable_diffusion/scripts/datasets/laion-aesthetic-download-metadata.sh
+++ /dev/null
@@ -1,16 +0,0 @@
-#!/usr/bin/env bash
-
-: "${OUTPUT_DIR:=/datasets/laion2B-en-aesthetic/metadata}"
-
-while [ "$1" != "" ]; do
- case $1 in
- -o | --output-dir ) shift
- OUTPUT_DIR=$1
- ;;
- esac
- shift
-done
-
-mkdir -p ${OUTPUT_DIR}
-
-for i in {00000..00127}; do wget -N -P ${OUTPUT_DIR} https://huggingface.co/datasets/laion/laion2B-en-aesthetic/resolve/main/part-$i-9230b837-b1e0-4254-8b88-ed2976e9cee9-c000.snappy.parquet; done
diff --git a/stable_diffusion/scripts/datasets/laion400m-convert-images-to-latents.sh b/stable_diffusion/scripts/datasets/laion400m-convert-images-to-moments.sh
similarity index 100%
rename from stable_diffusion/scripts/datasets/laion400m-convert-images-to-latents.sh
rename to stable_diffusion/scripts/datasets/laion400m-convert-images-to-moments.sh
diff --git a/stable_diffusion/scripts/datasets/laion400m-filtered-download-images.sh b/stable_diffusion/scripts/datasets/laion400m-filtered-download-images.sh
new file mode 100755
index 000000000..335e0a56f
--- /dev/null
+++ b/stable_diffusion/scripts/datasets/laion400m-filtered-download-images.sh
@@ -0,0 +1,22 @@
+#!/usr/bin/env bash
+
+: "${OUTPUT_DIR:=/datasets/laion-400m/webdataset-filtered}"
+
+while [ "$1" != "" ]; do
+ case $1 in
+ -o | --output-dir ) shift
+ OUTPUT_DIR=$1
+ ;;
+ esac
+ shift
+done
+
+mkdir -p ${OUTPUT_DIR}
+cd ${OUTPUT_DIR}
+
+
+for i in {00000..00831}; do wget -O ${OUTPUT_DIR}/${i}.tar -c "https://cloud.mlcommons.org/index.php/s/training_stable_diffusion/download?path=/datasets/laion-400m/images-webdataset-filtered&files=${i}.tar"; done
+
+wget -O ${OUTPUT_DIR}/sha512sums.txt -c "https://cloud.mlcommons.org/index.php/s/training_stable_diffusion/download?path=/datasets/laion-400m/images-webdataset-filtered&files=sha512sums.txt"
+
+sha512sum --quiet -c sha512sums.txt
diff --git a/stable_diffusion/scripts/datasets/laion400m-filtered-download-moments.sh b/stable_diffusion/scripts/datasets/laion400m-filtered-download-moments.sh
new file mode 100755
index 000000000..ebedef98d
--- /dev/null
+++ b/stable_diffusion/scripts/datasets/laion400m-filtered-download-moments.sh
@@ -0,0 +1,22 @@
+#!/usr/bin/env bash
+
+: "${OUTPUT_DIR:=/datasets/laion-400m/webdataset-moments-filtered}"
+
+while [ "$1" != "" ]; do
+ case $1 in
+ -o | --output-dir ) shift
+ OUTPUT_DIR=$1
+ ;;
+ esac
+ shift
+done
+
+mkdir -p ${OUTPUT_DIR}
+cd ${OUTPUT_DIR}
+
+
+for i in {00000..00831}; do wget -O ${OUTPUT_DIR}/${i}.tar -c "https://cloud.mlcommons.org/index.php/s/training_stable_diffusion/download?path=/datasets/laion-400m/moments-webdataset-filtered&files=${i}.tar"; done
+
+wget -O ${OUTPUT_DIR}/sha512sums.txt -c "https://cloud.mlcommons.org/index.php/s/training_stable_diffusion/download?path=/datasets/laion-400m/moments-webdataset-filtered&files=sha512sums.txt"
+
+sha512sum --quiet -c sha512sums.txt
diff --git a/stable_diffusion/scripts/slurm/sbatch.sh b/stable_diffusion/scripts/slurm/sbatch.sh
index 842754cfb..397226a4d 100755
--- a/stable_diffusion/scripts/slurm/sbatch.sh
+++ b/stable_diffusion/scripts/slurm/sbatch.sh
@@ -56,10 +56,6 @@ WORKDIR_MNT=/workdir
# Job config
JOB_NAME=train_${CONFIG_NAME}_${SUFFIX}
-# Laion aesthetic
-LAION2B_EN_AESTHETIC=/datasets/laion2B-en-aesthetic
-LAION2B_EN_AESTHETIC_MOUNT=/datasets/laion2B-en-aesthetic
-
# Laion 400m
LAION_400M=/datasets/laion-400m
LAION_400M_MOUNT=/datasets/laion-400m
@@ -86,7 +82,7 @@ LOG_DIR="${BASE_LOG_DIR}"
mkdir -p ${LOG_DIR}
# Mounts
-MOUNTS="${PWD}:${WORKDIR_MNT},${LAION2B_EN_AESTHETIC}:${LAION2B_EN_AESTHETIC_MOUNT},${LAION_400M}:${LAION_400M_MOUNT},${COCO}:${COCO_MNT},${RESULTS_DIR}:${RESULTS_MNT},${CKPT_DIR}:${CKPT_MOUNT},${HF_HOME_DIR}:${HF_HOME_MOUNT}"
+MOUNTS="${PWD}:${WORKDIR_MNT},${LAION_400M}:${LAION_400M_MOUNT},${COCO}:${COCO_MNT},${RESULTS_DIR}:${RESULTS_MNT},${CKPT_DIR}:${CKPT_MOUNT},${HF_HOME_DIR}:${HF_HOME_MOUNT}"
sbatch \
--account=${ACCOUNT} \
diff --git a/stable_diffusion/webdataset_images2latents.py b/stable_diffusion/webdataset_images2latents.py
index 9a435ca21..b0f3adc87 100644
--- a/stable_diffusion/webdataset_images2latents.py
+++ b/stable_diffusion/webdataset_images2latents.py
@@ -15,6 +15,9 @@
from ldm.util import instantiate_from_config
+Image.MAX_IMAGE_PIXELS = None
+
+
image_extensions = [".jpg", ".jpeg", ".png", ".bmp", ".gif", ".tiff", ".ico"]
def is_image(filename):
@@ -24,8 +27,8 @@ def is_image(filename):
def process_image(input_image_path, output_tensor_name, image_transforms, model):
original_image = Image.open(input_image_path).convert('RGB')
transformed_img = image_transforms(original_image).float().unsqueeze(0).to(model.device)
- encoded_image = model.encode_first_stage(transformed_img).sample().squeeze(0)
- np.save(output_tensor_name, encoded_image.to("cpu").numpy())
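+    # Save the encoder moments (mean/logvar) instead of a sampled latent so the
+    # stored dataset stays deterministic; sampling happens later, at training time.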
+ moments = model.moments_first_stage(transformed_img)
+ np.save(output_tensor_name, moments.to("cpu").numpy())
def process_tar(input_tar, output_tar, image_transforms, model):