This software project accompanies the research paper TRACT: Denoising Diffusion Models with Transitive Closure Time-Distillation
@article{berthelot2023tract,
title={TRACT: Denoising Diffusion Models with Transitive Closure Time-Distillation},
author={Berthelot, David and Autef, Arnaud and Lin, Jierui and Yap, Dian Ang and Zhai, Shuangfei and Hu, Siyuan and Zheng, Daniel and Talbott, Walter and Gu, Eric},
journal={arXiv preprint arXiv:2303.04248},
year={2023}
}
Git clone with --recurse-submodules
to initialize EDM submodule properly.
Setup environment variables:
export ML_DATA=~/Data/DDPM-Images
export PYTHONPATH=$PYTHONPATH:.
Then run
sudo apt install python3.8-dev python3.8-venv python3-dev -y
Set up a virtualenv
python3.8 -m venv ~/tract_venv
source ~/tract_venv/bin/activate
or via pyenv
pyenv install 3.8.0
pyenv virtualenv 3.8.0 tract_venv
pyenv local tract_venv
then upgrade pip
pip install --upgrade pip
Install pip pkgs
pip install -r requirements.txt -f https://download.pytorch.org/whl/torch_stable.html
Please follow README to setup datasets.
Please follow README to setup teacher checkpoints.
For running with NVIDIA's Elucidated model (EDM), ensure the edm/
submodule has been initialized properly.
Please follow the Real activation statistics
section in README in order to compute and save the real activation statistics to be used in FID evaluation.
The below commands will reproduce results from our paper when run on a cluster of 8 NVIDIA A100 or V100 GPUs.
Example: Run TC distillation on Cifar10 using distillation time schedule: 1024, 32, 1.
python tc_distill.py --dataset=cifar10 --time_schedule=1024,32,1 --fid_len=50000 --report_fid_len=8M --report_img_len=8M --train_len=96M
Example: Run TC distillation on Cifar10 using EDM teacher
python tc_distill_edm.py --dataset=cifar10 --time_schedule=40,1 --fid_len=50000 --report_fid_len=8M --report_img_len=8M --train_len=96M --batch=512
Getting help
python tc_distill.py --help
Tensorboard outputs are generated in a dir like e/DATASET/MODEL/EXP_NAME/tb/
. For example, you can start tensorboard to view metrics like
tensorboard --logdir e/cifar10/EluDDIM05TCMultiStepx0\(EluUNet\)/[email protected]_batch@[email protected][email protected][email protected]_lr_warmup@None_res@[email protected]_time_schedule@40,1_timesteps@40/tb/ --bind_all
Please follow README to run FID evaluation.