
minoring/VAE


VAE for celebA and MNIST datasets

Training a Variational Autoencoder (VAE) on the celebA and MNIST datasets.

To train on celebA, download img_align_celeba.zip and list_eval_partition.txt, unzip the archive, and place the files under data/celebA/. Hyperparameters such as the learning rate, latent dimension, batch size, and learning-rate decay schedule can be set via command-line arguments (see parser_utils.py for the full list).

By default, the learning curve is saved to "{model}_{dataset}.csv" (e.g. cnn_celebA.csv). During training, reconstructions of the test set are saved in the results/{dataset}/{model} directory (e.g. results/celebA/cnn). Pretrained models are provided as "{model}_{dataset}.pt" (e.g. cnn_celebA.pt).
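The default output names can be summarized with a small helper. `artifact_paths` is hypothetical and not part of this repository; it only assembles the default names, assuming the `{model}_{dataset}` pattern seen in the examples (cnn_celebA.csv, cnn_celebA.pt).

```python
def artifact_paths(model: str, dataset: str) -> dict:
    """Hypothetical helper: default output locations for a model/dataset pair."""
    return {
        # Learning curve CSV, e.g. cnn_celebA.csv
        "learning_curve": f"{model}_{dataset}.csv",
        # Saved / pretrained weights, e.g. cnn_celebA.pt
        "checkpoint": f"{model}_{dataset}.pt",
        # Directory for test-set reconstructions, e.g. results/celebA/cnn
        "results_dir": f"results/{dataset}/{model}",
    }


print(artifact_paths("cnn", "celebA"))
```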

celebA Dataset

Training with CNN-based Model

python train.py --model cnn --dataset celebA --num-epochs 300 --latent-dim 128
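During training, a VAE samples a latent code from the distribution predicted by the encoder. The sketch below shows the reparameterization trick that conventional VAEs use for this step. It is a generic illustration in plain Python, not this repository's PyTorch code, and `mu` / `log_var` stand in for hypothetical encoder outputs.

```python
import math
import random


def reparameterize(mu, log_var, rng=random):
    """Sample z = mu + sigma * eps with eps ~ N(0, 1), element-wise.

    mu and log_var are plain lists standing in for the encoder outputs
    (hypothetical here). All randomness lives in eps, which is why, in an
    autodiff framework, gradients can flow back through mu and log_var.
    """
    z = []
    for m, lv in zip(mu, log_var):
        sigma = math.exp(0.5 * lv)   # log-variance -> standard deviation
        eps = rng.gauss(0.0, 1.0)    # noise from a standard normal
        z.append(m + sigma * eps)
    return z


# Toy 4-dimensional example (the command above uses --latent-dim 128).
print(reparameterize([0.0, 1.0, -1.0, 0.5], [0.0, 0.0, -2.0, 1.0], random.Random(0)))
```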

Reconstruction examples

Original images on top, reconstructions on the bottom.

Generate Samples

Generate samples by decoding latent vectors drawn from a random Gaussian distribution.

python generate_samples.py --model cnn --dataset celebA --saved-path cnn_celebA.pt --latent-dim 128
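A minimal sketch of the sampling step, assuming the standard normal prior N(0, I) of a conventional VAE. `sample_latents`, `generate`, and the `decoder` callable are hypothetical names, not functions from generate_samples.py.

```python
import random


def sample_latents(num_samples, latent_dim, seed=None):
    """Draw num_samples latent vectors z ~ N(0, I), each of length latent_dim."""
    rng = random.Random(seed)
    return [[rng.gauss(0.0, 1.0) for _ in range(latent_dim)]
            for _ in range(num_samples)]


def generate(decoder, num_samples, latent_dim, seed=None):
    """Decode sampled latent vectors.

    `decoder` is a hypothetical callable mapping one latent vector to one image.
    """
    return [decoder(z) for z in sample_latents(num_samples, latent_dim, seed)]


# Toy stand-in decoder so the sketch runs without a trained model.
images = generate(lambda z: sum(z) / len(z), num_samples=4, latent_dim=128, seed=0)
print(len(images))
```

The latent size has to match the --latent-dim used at training time, which is why the same value (128 here) is passed to both train.py and generate_samples.py.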

Plot Learning Curve

python plot_learning_curve.py --loss Loss --learning-curve-csv cnn_celebA.csv
python plot_learning_curve.py --loss Reconstruction_Loss --learning-curve-csv cnn_celebA.csv
python plot_learning_curve.py --loss KL_Loss --learning-curve-csv cnn_celebA.csv
[Figure: learning curves. Panels: Loss (Reconstruction + KL), Reconstruction Loss, KL Loss]
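The total Loss is the sum of the reconstruction term and the KL term. For a diagonal Gaussian posterior and a standard normal prior, the KL term has a closed form; the sketch below computes it alongside a squared-error reconstruction term. The squared-error choice and the function names are assumptions, since this README does not state which reconstruction loss train.py uses.

```python
import math


def kl_divergence(mu, log_var):
    """KL( N(mu, diag(exp(log_var))) || N(0, I) ), summed over latent dims."""
    return -0.5 * sum(1.0 + lv - m * m - math.exp(lv)
                      for m, lv in zip(mu, log_var))


def squared_error(x, x_hat):
    """Sum of squared errors: one common reconstruction loss (assumed here)."""
    return sum((a - b) ** 2 for a, b in zip(x, x_hat))


def vae_loss(x, x_hat, mu, log_var):
    """Return (total, reconstruction, KL), mirroring the three plotted curves."""
    recon = squared_error(x, x_hat)
    kl = kl_divergence(mu, log_var)
    return recon + kl, recon, kl


print(vae_loss([1.0, 0.0], [0.5, 0.0], [1.0], [0.0]))
```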

(Both the training and test losses are still decreasing, so training for more epochs may improve the results further.)

Training with FC-based Model

python train.py --model fc --dataset celebA --num-epochs 300 --latent-dim 128

Reconstruction examples

Original images on top, reconstructions on the bottom.

Generate Samples

Generate samples by decoding latent vectors drawn from a random Gaussian distribution.

python generate_samples.py --model fc --dataset celebA --saved-path fc_celebA.pt --latent-dim 128

Plot Learning Curve

python plot_learning_curve.py --loss Loss --learning-curve-csv fc_celebA.csv
python plot_learning_curve.py --loss Reconstruction_Loss --learning-curve-csv fc_celebA.csv
python plot_learning_curve.py --loss KL_Loss --learning-curve-csv fc_celebA.csv
[Figure: learning curves. Panels: Loss (Reconstruction + KL), Reconstruction Loss, KL Loss]
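To inspect a learning-curve CSV directly instead of through plot_learning_curve.py, a sketch like the following may help. It assumes the file has a header row whose column names match the --loss values above (Loss, Reconstruction_Loss, KL_Loss); the actual layout, for example separate training and test columns, is not documented here, so check the header first. `read_curve` is a hypothetical helper.

```python
import csv


def read_curve(csv_file, column):
    """Return one column of a learning-curve CSV as a list of floats.

    csv_file can be an open file or any iterable of text lines. The column
    names are an assumption; a KeyError lists the ones actually present.
    """
    reader = csv.DictReader(csv_file)
    if reader.fieldnames is None or column not in reader.fieldnames:
        raise KeyError(f"column {column!r} not found; available: {reader.fieldnames}")
    return [float(row[column]) for row in reader]
```

For example, open fc_celebA.csv with `open("fc_celebA.csv", newline="")` and pass the file object together with a column name such as `"KL_Loss"`.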

MNIST Dataset

Training with CNN-based Model

python train.py --model cnn --dataset mnist --num-epochs 1000 --latent-dim 16
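For MNIST, a common reconstruction loss treats each pixel intensity in [0, 1] as a Bernoulli probability and uses binary cross-entropy. This is a generic sketch: the README does not say whether train.py uses this or another reconstruction loss, and `binary_cross_entropy` is a hypothetical name.

```python
import math


def binary_cross_entropy(x, x_hat, eps=1e-7):
    """Summed BCE between target pixels x and predicted probabilities x_hat.

    Both are flat lists of values in [0, 1] (e.g. 28 * 28 = 784 MNIST pixels).
    Predictions are clamped to [eps, 1 - eps] so log() never receives 0.
    """
    total = 0.0
    for target, pred in zip(x, x_hat):
        p = min(max(pred, eps), 1.0 - eps)
        total -= target * math.log(p) + (1.0 - target) * math.log(1.0 - p)
    return total


print(binary_cross_entropy([1.0, 0.0, 1.0], [0.9, 0.2, 0.6]))
```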

Reconstruction examples

Original images on top, reconstructions on the bottom.

Generate Samples

Generate samples by decoding latent vectors drawn from a random Gaussian distribution.

python generate_samples.py --model cnn --dataset mnist --saved-path cnn_mnist.pt --latent-dim 16

Plot Learning Curve

python plot_learning_curve.py --loss Loss --learning-curve-csv cnn_mnist.csv
python plot_learning_curve.py --loss Reconstruction_Loss --learning-curve-csv cnn_mnist.csv
python plot_learning_curve.py --loss KL_Loss --learning-curve-csv cnn_mnist.csv
[Figure: learning curves. Panels: Loss (Reconstruction + KL), Reconstruction Loss, KL Loss]

Plot Latent Space

Uses t-SNE to map the 16-dimensional latent space to 2D.

python plot_latent_space.py --saved-path cnn_mnist.pt --latent-dim 16 --dataset mnist --model cnn
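For reference, the standard t-SNE formulation is given below. Here z_i are the 16-dimensional latent codes (which codes plot_latent_space.py embeds, e.g. posterior means, is not stated in this README), y_i are their 2D positions, N is the number of points, and each bandwidth σ_i is tuned so that the conditional distribution over neighbours of point i reaches a user-chosen perplexity. t-SNE then moves the y_i to minimize the cost C.

```latex
p_{j \mid i} = \frac{\exp\left(-\lVert z_i - z_j \rVert^2 / 2\sigma_i^2\right)}
                    {\sum_{k \neq i} \exp\left(-\lVert z_i - z_k \rVert^2 / 2\sigma_i^2\right)},
\qquad
p_{ij} = \frac{p_{j \mid i} + p_{i \mid j}}{2N}

q_{ij} = \frac{\left(1 + \lVert y_i - y_j \rVert^2\right)^{-1}}
              {\sum_{k \neq l} \left(1 + \lVert y_k - y_l \rVert^2\right)^{-1}},
\qquad
C = \mathrm{KL}(P \,\|\, Q) = \sum_{i \neq j} p_{ij} \log \frac{p_{ij}}{q_{ij}}
```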


About

Variational Autoencoder in PyTorch
