
Keras implementation of a Variational Auto Encoder with a Concrete Latent Distribution


SandorSeres/vae-concrete


Variational Auto Encoder with Concrete Latent Distribution

Keras implementation of a Variational Auto Encoder with a Concrete latent distribution. See "Auto-Encoding Variational Bayes" by Kingma and Welling, and "The Concrete Distribution: A Continuous Relaxation of Discrete Random Variables" by Maddison, Mnih and Teh, or "Categorical Reparameterization with Gumbel-Softmax" by Jang, Gu and Poole.
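As a rough illustration of the relaxation described in those papers, a sample from a Concrete (Gumbel-Softmax) distribution can be drawn by adding Gumbel noise to the log-probabilities and applying a temperature-scaled softmax. The NumPy sketch below is not taken from this repository; the helper name `sample_concrete` and its `temperature` parameter are illustrative assumptions.

```python
import numpy as np

def sample_concrete(logits, temperature, rng=None):
    """Draw one sample from a Concrete / Gumbel-Softmax distribution.

    Hypothetical helper for illustration; not part of vae_concrete.
    """
    rng = np.random.default_rng() if rng is None else rng
    logits = np.asarray(logits, dtype=float)
    # Gumbel(0, 1) noise via the inverse CDF: g = -log(-log(u)), u ~ Uniform(0, 1).
    # The lower bound keeps u strictly positive so log(u) is finite.
    u = rng.uniform(low=np.finfo(float).tiny, high=1.0, size=logits.shape)
    g = -np.log(-np.log(u))
    # Temperature-scaled softmax; subtracting the max keeps exp() numerically stable.
    y = (logits + g) / temperature
    y = np.exp(y - y.max())
    return y / y.sum()

# Usage: a relaxed sample from a 3-category distribution with probabilities (0.1, 0.2, 0.7).
rng = np.random.default_rng(0)
sample = sample_concrete(np.log([0.1, 0.2, 0.7]), temperature=0.5, rng=rng)
print(sample)  # a point on the probability simplex
```

The sample always lies on the probability simplex; as the temperature approaches 0 it approaches a one-hot vector, while larger temperatures give smoother samples that are easier to differentiate through.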

Examples

Samples from a regular VAE

VAE with concrete latent distribution. Each column of the image corresponds to one of the categories of the latent concrete distribution.

Usage

Traditional VAE with a 2-dimensional latent distribution

>>> from vae_concrete import VAE
>>> model = VAE(latent_cont_dim=2)
>>> model.fit(x_train, num_epochs=20)
>>> model.plot()

You should start seeing good results after ~5 epochs. The loss should approach ~140 upon convergence. Occasionally the optimization gets stuck in a poor local minimum, with the loss staying around ~205; in that case it is best to restart the optimization.
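The reported loss is presumably the negative evidence lower bound from Kingma and Welling, i.e. a reconstruction error plus a KL penalty; this README does not define it, so treat that as an assumption. The NumPy sketch below shows the standard pieces for a diagonal Gaussian latent: the reparameterized sample, the closed-form KL divergence to a standard normal prior, and a binary cross-entropy reconstruction term. All function names are hypothetical and not part of `vae_concrete`.

```python
import numpy as np

def reparameterize(mu, log_var, rng):
    # z = mu + sigma * eps with eps ~ N(0, I). In an autodiff framework this
    # form lets gradients flow through mu and log_var.
    mu = np.asarray(mu, dtype=float)
    eps = rng.standard_normal(mu.shape)
    return mu + np.exp(0.5 * np.asarray(log_var, dtype=float)) * eps

def gaussian_kl(mu, log_var):
    # Closed-form KL(N(mu, sigma^2) || N(0, I)), summed over latent dimensions.
    mu = np.asarray(mu, dtype=float)
    log_var = np.asarray(log_var, dtype=float)
    return -0.5 * np.sum(1.0 + log_var - mu**2 - np.exp(log_var), axis=-1)

def bernoulli_nll(x, x_hat, eps=1e-7):
    # Reconstruction term: binary cross-entropy summed over pixels
    # (assumes inputs and reconstructions lie in [0, 1]).
    x_hat = np.clip(x_hat, eps, 1 - eps)
    return -np.sum(x * np.log(x_hat) + (1 - x) * np.log(1 - x_hat), axis=-1)

# Under this assumption, the per-example loss is bernoulli_nll(x, x_hat) + gaussian_kl(mu, log_var).
```

A KL of 0 means the encoder output matches the prior exactly; the reconstruction term usually dominates the total on image data.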

VAE with 2 continuous variables and a 10-dimensional discrete distribution

>>> model = VAE(latent_cont_dim=2, latent_disc_dim=10)
>>> model.fit(x_train, num_epochs=10)
>>> model.plot()

It takes ~10 epochs to start seeing good results. The loss should go down to ~125.
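With a discrete latent variable, the KL term for the categorical part is commonly computed against a uniform prior over the K categories, which reduces to the negative entropy of q plus log K. Whether this repository uses exactly this form is an assumption; the hypothetical helper below only illustrates the formula.

```python
import numpy as np

def categorical_kl_to_uniform(q, eps=1e-12):
    """KL(q || Uniform(K)) = sum_k q_k log q_k + log K (hypothetical helper)."""
    q = np.asarray(q, dtype=float)
    k = q.shape[-1]
    # eps avoids log(0) for zero-probability categories (0 * log 0 is treated as 0).
    neg_entropy = np.sum(q * np.log(q + eps), axis=-1)
    return neg_entropy + np.log(k)

# Usage: a uniform posterior costs nothing, a one-hot posterior costs log K.
print(categorical_kl_to_uniform(np.full(10, 0.1)))
print(categorical_kl_to_uniform(np.eye(10)[0]))
```

For `latent_disc_dim=10` this penalty therefore lies between 0 (uniform q) and log 10 ≈ 2.3 nats (one-hot q).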

Dependencies

  • keras
  • tensorflow (only tested on tensorflow backend)
  • plotly

Acknowledgements

Code was inspired by the Keras VAE implementation (plotting functionality was also borrowed and modified from this example).
