Skip to content

Commit

Permalink
Merge pull request #157 from hello-fri-end/VAE
Browse files Browse the repository at this point in the history
Samples generating code for Convolutional VAE example
  • Loading branch information
kartikdutt18 authored Jun 30, 2021
2 parents 3754e8e + c8adf45 commit e8f7733
Show file tree
Hide file tree
Showing 9 changed files with 300 additions and 16 deletions.
77 changes: 77 additions & 0 deletions mnist_vae_cnn/generate_images.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
"""
@file generate_images.py
@author Atharva Khandait
Generates jpg files from csv.
mlpack is free software; you may redistribute it and/or modify it under the
terms of the 3-clause BSD license. You should have received a copy of the
3-clause BSD license along with mlpack. If not, see
http://www.opensource.org/licenses/BSD-3-Clause for more information.
"""

from PIL import Image
import numpy as np

def ImagesFromCSV(filename,
                  imgShape = (28, 28),
                  destination = 'samples',
                  saveIndividual = False):
    """
    Convert a csv file of flattened images into jpg files.

    Every column of the csv is treated as one flattened image. All images are
    laid out side by side and written to `destination`/allSamples.jpg.

    @param filename Path of the csv file to read.
    @param imgShape Shape every column is reshaped to before being saved.
    @param destination Directory the jpg files are written to; it is created
        if it does not exist yet.
    @param saveIndividual If True, every image is additionally saved as
        `destination`/sample<i>.jpg.
    @return The combined image as a 2-d uint8 numpy array.
    """
    # Local import: only needed to create the output directory.
    import os

    # Import the data into a numpy matrix.
    samples = np.genfromtxt(filename, delimiter = ',', dtype = np.uint8)

    # A file holding a single sample is parsed as a 1-d array; promote it to a
    # one-column matrix so that the column indexing below keeps working.
    if samples.ndim == 1:
        samples = samples[:, np.newaxis]

    # Saving fails if the directory is missing, so create it up front.
    if destination:
        os.makedirs(destination, exist_ok = True)

    # Reshape every column into an image.
    images = [np.reshape(samples[:, i], imgShape)
              for i in range(samples.shape[1])]

    if saveIndividual:
        for i, image in enumerate(images):
            Image.fromarray(image, 'L').save(
                destination + '/sample' + str(i) + '.jpg')

    # Concatenate once instead of growing the combined image inside a loop.
    allSamples = np.concatenate(images, axis = 1)
    Image.fromarray(allSamples, 'L').save(destination + '/allSamples' + '.jpg')

    print ('Samples saved in ' + destination + '/.')

    return allSamples

# Posterior samples: a single combined image in samples_posterior/.
ImagesFromCSV('./samples_csv_files/samples_posterior.csv',
              destination = 'samples_posterior')

# Prior samples, one csv file per varied latent variable. The combined image
# of every file becomes one row of allLatent.jpg.
latentSize = 10
allLatent = None
for i in range(latentSize):
    latentRow = ImagesFromCSV(
        './samples_csv_files/samples_prior_latent' + str(i) + '.csv',
        destination = 'samples_prior')
    if allLatent is None:
        allLatent = latentRow
    else:
        allLatent = np.concatenate((allLatent, latentRow), axis = 0)

saved = Image.fromarray(allLatent, 'L')
saved.save('./samples_prior/allLatent.jpg')

# Prior samples with two latent variables varied on a 2d grid, one csv file
# per grid row. The rows are stacked into 2dLatent.jpg.
nofSamples = 20
allLatent = None
for i in range(nofSamples):
    # NOTE(review): only the first file is written to 'latent', all others go
    # to 'samples_prior' -- kept as is, confirm whether this is intended.
    rowDestination = 'latent' if i == 0 else 'samples_prior'
    gridRow = ImagesFromCSV(
        './samples_csv_files/samples_prior_latent_2d' + str(i) + '.csv',
        destination = rowDestination)
    if allLatent is None:
        allLatent = gridRow
    else:
        allLatent = np.concatenate((allLatent, gridRow), axis = 0)

saved = Image.fromarray(allLatent, 'L')
saved.save('./samples_prior/2dLatent.jpg')
Binary file added mnist_vae_cnn/latent/allSamples.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
50 changes: 36 additions & 14 deletions mnist_vae_cnn/mnist_vae_cnn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,10 @@ int main()
constexpr int batchSize = 64;
// The step size of the optimizer.
constexpr double stepSize = 0.001;
// The maximum number of possible iteration
constexpr int maxIteration = 0;
// Number of epochs/ cycle
constexpr int epochs = 1;
// Number of cycles
constexpr int cycles = 10;
// Whether to load a model to train.
constexpr bool loadModel = false;
// Whether to save the trained model.
Expand Down Expand Up @@ -84,6 +86,13 @@ int main()

arma::mat train, validation;
data::Split(fullData, validation, train, trainRatio);

// Loss is calculated on train_test data after each cycle.
arma::mat trainTest, dump;
data::Split(train, dump, trainTest, 0.045);

// No of iterations of the optimizer.
int iterPerCycle = (epochs * train.n_cols);

/**
* Model architecture:
Expand All @@ -107,7 +116,7 @@ int main()
* size 5x5, stride = 1, padding = 0) ---> 14x14x16
* 14x14x16 ------------- Leaky ReLU ------------> 14x14x16
* 14x14x16 ---- transposed conv (1 filter of
* size 15x15, stride = 1, padding = 1) -> 28x28x1
* size 15x15, stride = 0, padding = 1) -> 28x28x1
*/

// Creating the VAE model.
Expand Down Expand Up @@ -197,24 +206,37 @@ int main()
0.9, // Exponential decay rate for the first moment estimates.
0.999, // Exponential decay rate for the weighted infinity norm estimates.
1e-8, // Value used to initialise the mean squared gradient parameter.
maxIteration, // Max number of iterations.
iterPerCycle, // Max number of iterations.
1e-8, // Tolerance.
true);

const clock_t beginTime = clock();
// Cycles for monitoring the progress.
for (int i = 0; i < cycles; i++)
{
// Train neural network. If this is the first iteration, weights are
// random, using current values as starting point otherwise.
vaeModel.Train(train,
train,
optimizer,
ens::PrintLoss(),
ens::ProgressBar(),
ens::Report());

// Don't reset optimizer's parameters between cycles.
optimizer.ResetPolicy() = false;

std::cout << "Loss after cycle " << i << " -> " <<
MeanTestLoss<MeanSModel>(vaeModel, trainTest, batchSize) << std::endl;
}

// Train neural network. If this is the first iteration, weights are
// random, using current values as starting point otherwise.
vaeModel.Train(train,
train,
optimizer,
ens::PrintLoss(),
ens::ProgressBar(),
// Stop the training using Early Stop at min loss.
ens::EarlyStopAtMinLoss());
std::cout << "Time taken to train -> " << float(clock() - beginTime) /
CLOCKS_PER_SEC << " seconds" << std::endl;

// Save the model if specified.
if (saveModel)
{
data::Save("vae/saved_models/vaeCNN.bin", "vaeCNN", vaeModel);
data::Save("./saved_models/vaeCNN.bin", "vaeCNN", vaeModel);
std::cout << "Model saved in vae/saved_models." << std::endl;
}
}
Binary file added mnist_vae_cnn/samples_posterior/allSamples.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added mnist_vae_cnn/samples_prior/2dLatent.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added mnist_vae_cnn/samples_prior/allLatent.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added mnist_vae_cnn/samples_prior/allSamples.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
185 changes: 185 additions & 0 deletions mnist_vae_cnn/vae_generate.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
/**
* @file vae_generate.cpp
* @author Atharva Khandait
*
* Generate MNIST using trained VAE model.
*
* mlpack is free software; you may redistribute it and/or modify it under the
* terms of the 3-clause BSD license. You should have received a copy of the
* 3-clause BSD license along with mlpack. If not, see
* http://www.opensource.org/licenses/BSD-3-Clause for more information.
*/
#include <mlpack/core.hpp>
#include <mlpack/core/data/split_data.hpp>
#include <mlpack/core/data/save.hpp>

#include <mlpack/methods/ann/ffn.hpp>
#include <mlpack/methods/ann/layer/layer.hpp>
#include <mlpack/methods/ann/init_rules/he_init.hpp>
#include <mlpack/methods/ann/loss_functions/reconstruction_loss.hpp>
#include <mlpack/methods/ann/loss_functions/mean_squared_error.hpp>
#include <mlpack/methods/ann/dists/bernoulli_distribution.hpp>

#include "vae_utils.hpp"

using namespace mlpack;
using namespace mlpack::ann;

// Convenience alias: a feed-forward network parameterized with a
// reconstruction loss over a Bernoulli distribution and He initialization.
using ReconModel = FFN<
    ReconstructionLoss<arma::mat, arma::mat, BernoulliDistribution<arma::mat>>,
    HeInitialization>;

/**
 * Generate samples from a trained VAE model and store them as csv files in
 * ./samples_csv_files/:
 *
 *  - samples_prior.csv: decoder output for random latent vectors,
 *  - samples_prior_latent<i>.csv: latent variable i swept while the others
 *    are held at one random draw,
 *  - samples_prior_latent_2d<i>.csv: two latent variables swept on a grid,
 *  - samples_posterior.csv: full network output for loaded data points
 *    (only if loadData is set).
 */
int main()
{
  // Whether to load training data.
  constexpr bool loadData = true;
  // The number of samples to generate.
  constexpr size_t nofSamples = 20;
  // Whether modelled on binary data.
  constexpr bool isBinary = false;
  // The latent size of the VAE model.
  constexpr size_t latentSize = 20;

  // Layer indices used for the partial forward passes below.
  constexpr size_t encoderIndex = 1;
  constexpr size_t decoderIndex = 3;
  // The binary model gets one extra (sigmoid) layer appended after loading.
  constexpr size_t lastLayerIndex =
      decoderIndex + static_cast<size_t>(isBinary);

  // Spacing between consecutive values of a swept latent variable.
  constexpr double latentStep = 3.0 / nofSamples;

  arma::mat fullData, train, validation;

  if (loadData)
  {
    data::Load("../data/mnist_train.csv", fullData, true, false);
    // Get rid of the header (drops the first column).
    fullData =
        fullData.submat(0, 1, fullData.n_rows - 1, fullData.n_cols - 1);
    fullData /= 255.0;
    // Get rid of the labels (drops the first row).
    fullData =
        fullData.submat(1, 0, fullData.n_rows - 1, fullData.n_cols - 1);

    if (isBinary)
    {
      // Binarize: every entry becomes 1 with probability equal to its value.
      fullData = arma::conv_to<arma::mat>::from(arma::randu<arma::mat>
          (fullData.n_rows, fullData.n_cols) <= fullData);
    }
    else
    {
      // Rescale from [0, 1] to [-1, 1].
      fullData = (fullData - 0.5) * 2;
    }

    data::Split(fullData, validation, train, 0.8);
  }

  arma::arma_rng::set_seed_random();

  // It doesn't matter what type of network we initialize, as we only need to
  // forward pass through it and not initialize weights or take loss.
  FFN<> vaeModel;

  // Load the trained model.
  if (isBinary)
  {
    data::Load("./saved_models/vaeBinaryMS.xml", "vaeBinaryMS", vaeModel);
    vaeModel.Add<SigmoidLayer<> >();
  }
  else
  {
    // Use the object name the training program saves the model under.
    data::Load("./saved_models/vaeCNN.bin", "vaeCNN", vaeModel);
  }

  arma::mat gaussianSamples, outputDists, samples;

  /*
   * Sampling from the prior.
   */
  gaussianSamples = arma::randn<arma::mat>(latentSize, nofSamples);

  // Forward pass only through the decoder (and Sigmoid layer in case of
  // binary).
  vaeModel.Forward(gaussianSamples,
                   outputDists,
                   decoderIndex,
                   lastLayerIndex);

  GetSample(outputDists, samples, isBinary);
  // Save the prior samples as csv.
  data::Save("./samples_csv_files/samples_prior.csv", samples, false, false);

  /*
   * Sampling from the prior by varying all latent variables.
   */
  arma::mat gaussianVaried;

  for (size_t i = 0; i < latentSize; i++)
  {
    // Every column starts out as the same random latent vector.
    gaussianSamples = arma::randn<arma::mat>(latentSize, 1);
    gaussianVaried = arma::zeros(latentSize, nofSamples);
    gaussianVaried.each_col() = gaussianSamples;

    // Sweep latent variable i from -1.5 upwards in steps of latentStep (the
    // last value is 1.5 - latentStep).
    for (size_t j = 0; j < nofSamples; j++)
    {
      gaussianVaried.col(j)(i) = -1.5 + j * latentStep;
    }

    // Forward pass only through the decoder (and Sigmoid layer in case of
    // binary).
    vaeModel.Forward(gaussianVaried,
                     outputDists,
                     decoderIndex,
                     lastLayerIndex);

    GetSample(outputDists, samples, isBinary);
    // Save the prior samples as csv.
    data::Save(
        "./samples_csv_files/samples_prior_latent" + std::to_string(i) + ".csv",
        samples,
        false,
        false);
  }

  /*
   * Sampling from the prior by varying two latent variables in 2d.
   */
  const size_t latent1 = 3; // Latent variable to be varied vertically.
  const size_t latent2 = 4; // Latent variable to be varied horizontally.

  for (size_t i = 0; i < nofSamples; i++)
  {
    // All latent variables other than latent1 and latent2 stay at zero.
    gaussianVaried = arma::zeros(latentSize, nofSamples);

    for (size_t j = 0; j < nofSamples; j++)
    {
      // Set the vertical variable to a constant value for the outer loop.
      gaussianVaried.col(j)(latent1) = 1.5 - i * latentStep;
      // Vary the horizontal variable from -1.5 upwards in steps of
      // latentStep.
      gaussianVaried.col(j)(latent2) = -1.5 + j * latentStep;
    }

    // Forward pass only through the decoder (and Sigmoid layer in case of
    // binary).
    vaeModel.Forward(gaussianVaried,
                     outputDists,
                     decoderIndex,
                     lastLayerIndex);

    GetSample(outputDists, samples, isBinary);
    // Save the prior samples as csv.
    data::Save("./samples_csv_files/samples_prior_latent_2d" + std::to_string(i)
        + ".csv", samples, false, false);
  }

  /*
   * Sampling from the posterior.
   */
  if (loadData)
  {
    // Forward pass through the entire network for the first nofSamples
    // columns of the validation matrix.
    vaeModel.Forward(validation.cols(0, nofSamples - 1),
                     outputDists,
                     encoderIndex,
                     lastLayerIndex);

    GetSample(outputDists, samples, isBinary);
    // Save the posterior samples as csv.
    data::Save(
        "./samples_csv_files/samples_posterior.csv",
        samples,
        false,
        false);
  }
}
4 changes: 2 additions & 2 deletions mnist_vae_cnn/vae_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ using namespace mlpack::ann;
// Calculates mean loss over batches.
template<typename NetworkType = FFN<MeanSquaredError<>, HeInitialization>,
typename DataType = arma::mat>
double MeanTestLoss(NetworkType model, DataType testSet, size_t batchSize)
double MeanTestLoss(NetworkType& model, DataType& testSet, size_t batchSize)
{
double loss = 0;
size_t nofPoints = testSet.n_cols;
Expand Down Expand Up @@ -49,7 +49,7 @@ double MeanTestLoss(NetworkType model, DataType testSet, size_t batchSize)
// Sample from the output distribution and post-process the outputs(because
// we pre-processed it before passing it to the model).
template<typename DataType = arma::mat>
void GetSample(DataType input, DataType& samples, bool isBinary)
void GetSample(DataType &input, DataType& samples, bool isBinary)
{
if (isBinary)
{
Expand Down

0 comments on commit e8f7733

Please sign in to comment.