Skip to content

Commit

Permalink
Merge branch 'mlpack:master' into gan-mnist
Browse files Browse the repository at this point in the history
  • Loading branch information
swaingotnochill authored Jul 4, 2021
2 parents e4f0597 + a640ca9 commit b354d5a
Show file tree
Hide file tree
Showing 16 changed files with 3,424 additions and 25 deletions.

Large diffs are not rendered by default.

Large diffs are not rendered by default.

77 changes: 77 additions & 0 deletions mnist_vae_cnn/generate_images.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
"""
@file generate_images.py
@author Atharva Khandait
Generates jpg files from csv.
mlpack is free software; you may redistribute it and/or modify it under the
terms of the 3-clause BSD license. You should have received a copy of the
3-clause BSD license along with mlpack. If not, see
http://www.opensource.org/licenses/BSD-3-Clause for more information.
"""

from PIL import Image
import numpy as np

def ImagesFromCSV(filename,
                  imgShape = (28, 28),
                  destination = 'samples',
                  saveIndividual = False):
    """
    Convert the samples stored in a CSV file into JPEG images.

    Every column of the CSV file is read as one flattened grayscale image. All
    images are placed side by side in a single strip that is written to
    '<destination>/allSamples.jpg'.

    @param filename Path of the comma separated file to read.
    @param imgShape Shape each column is reshaped to before saving.
    @param destination Directory the images are written to; it is created if
        it does not exist yet.
    @param saveIndividual If true, every image is additionally written to
        '<destination>/sample<i>.jpg'.
    @return The combined strip as a numpy array (images concatenated along
        axis 1).
    """
    import os

    # Import the data into a numpy matrix.
    samples = np.genfromtxt(filename, delimiter = ',', dtype = np.uint8)

    # genfromtxt returns a 1-D array when the file holds a single row or a
    # single column; treat that as one sample so the column indexing below
    # keeps working.
    if samples.ndim == 1:
        samples = samples[:, np.newaxis]

    # Saving fails if the target directory is missing, so create it up front.
    if not os.path.isdir(destination):
        os.makedirs(destination)

    # Reshape every column into an image; the tiles are gathered in a list and
    # concatenated once instead of growing the strip on every iteration.
    tiles = []
    for i in range(samples.shape[1]):
        tile = np.reshape(samples[:, i], imgShape)
        tiles.append(tile)

        if saveIndividual:
            Image.fromarray(tile, 'L').save(
                destination + '/sample' + str(i) + '.jpg')

    # All the images concatenated horizontally into a combined image.
    allSamples = np.concatenate(tiles, axis = 1)
    Image.fromarray(allSamples, 'L').save(destination + '/allSamples' + '.jpg')

    print ('Samples saved in ' + destination + '/.')

    return allSamples

# Render the posterior samples.
ImagesFromCSV('./samples_csv_files/samples_posterior.csv',
              destination = 'samples_posterior')

# Render the prior samples where a single latent variable is varied per file.
# Each file yields one horizontal strip; the strips are stacked vertically.
latentSize = 10
latentStrips = [
    ImagesFromCSV('./samples_csv_files/samples_prior_latent0.csv',
                  destination = 'samples_prior')
]

for i in range(1, latentSize):
    latentStrips.append(
        ImagesFromCSV(
            './samples_csv_files/samples_prior_latent' + str(i) + '.csv',
            destination = 'samples_prior'))

allLatent = np.concatenate(latentStrips, axis = 0)
saved = Image.fromarray(allLatent, 'L')
saved.save('./samples_prior/allLatent.jpg')

# Render the prior samples where two latent variables are varied on a 2d grid.
# NOTE(review): the first file is written to 'latent' while the remaining ones
# go to 'samples_prior' -- confirm whether this difference is intended.
nofSamples = 20
gridStrips = [
    ImagesFromCSV('./samples_csv_files/samples_prior_latent_2d0.csv',
                  destination = 'latent')
]

for i in range(1, nofSamples):
    gridStrips.append(
        ImagesFromCSV(
            './samples_csv_files/samples_prior_latent_2d' + str(i) + '.csv',
            destination = 'samples_prior'))

allLatent = np.concatenate(gridStrips, axis = 0)
saved = Image.fromarray(allLatent, 'L')
saved.save('./samples_prior/2dLatent.jpg')
Binary file added mnist_vae_cnn/latent/allSamples.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
73 changes: 50 additions & 23 deletions mnist_vae_cnn/mnist_vae_cnn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,10 @@ int main()
constexpr int batchSize = 64;
// The step size of the optimizer.
constexpr double stepSize = 0.001;
// The maximum number of possible iteration
constexpr int maxIteration = 0;
// Number of epochs/ cycle
constexpr int epochs = 1;
// Number of cycles
constexpr int cycles = 10;
// Whether to load a model to train.
constexpr bool loadModel = false;
// Whether to save the trained model.
Expand All @@ -60,9 +62,18 @@ int main()
// Entire dataset(without labels) is loaded from a CSV file.
// Each column represents a data point.
arma::mat fullData;
data::Load("../data/mnist_train.csv", fullData, true, false);
data::Load("../data/mnist_train.csv", fullData, true, true);

// Originally on Kaggle dataset CSV file has header, so it's necessary to
// get rid of this row, in Armadillo representation it's the first column.
fullData =
fullData.submat(0, 1, fullData.n_rows -1, fullData.n_cols -1);
fullData /= 255.0;

// Get rid of the labels
fullData =
fullData.submat(1, 0, fullData.n_rows - 1, fullData.n_cols -1);

if (isBinary)
{
fullData = arma::conv_to<arma::mat>::from(
Expand All @@ -75,10 +86,13 @@ int main()

arma::mat train, validation;
data::Split(fullData, validation, train, trainRatio);

// Loss is calculated on train_test data after each cycle.
arma::mat train_test, dump;
data::Split(train, dump, train_test, 0.045);
arma::mat trainTest, dump;
data::Split(train, dump, trainTest, 0.045);

// No of iterations of the optimizer.
int iterPerCycle = (epochs * train.n_cols);

/**
* Model architecture:
Expand All @@ -102,7 +116,7 @@ int main()
* size 5x5, stride = 1, padding = 0) ---> 14x14x16
* 14x14x16 ------------- Leaky ReLU ------------> 14x14x16
* 14x14x16 ---- transposed conv (1 filter of
* size 15x15, stride = 1, padding = 1) -> 28x28x1
 * size 15x15, stride = 1, padding = 0) -> 28x28x1
*/

// Creating the VAE model.
Expand Down Expand Up @@ -171,10 +185,13 @@ int main()
0, // Padding width.
0, // Padding height.
10, // Input width.
10); // Input height.
10, // Input height.
14, // Output width.
14); // Output height.

decoder->Add<LeakyReLU<>>();
decoder->Add<TransposedConvolution<>>(16, 1, 15, 15, 1, 1, 1, 1, 14, 14);
decoder->Add<TransposedConvolution<>>
(16, 1, 15, 15, 1, 1, 0, 0, 14, 14, 28, 28);

vaeModel.Add(decoder);
}
Expand All @@ -189,27 +206,37 @@ int main()
0.9, // Exponential decay rate for the first moment estimates.
0.999, // Exponential decay rate for the weighted infinity norm estimates.
1e-8, // Value used to initialise the mean squared gradient parameter.
maxIteration, // Max number of iterations.
iterPerCycle, // Max number of iterations.
1e-8, // Tolerance.
true);

const clock_t beginTime = clock();
// Cycles for monitoring the progress.
for (int i = 0; i < cycles; i++)
{
// Train neural network. If this is the first iteration, weights are
// random, using current values as starting point otherwise.
vaeModel.Train(train,
train,
optimizer,
ens::PrintLoss(),
ens::ProgressBar(),
ens::Report());

// Don't reset optimizer's parameters between cycles.
optimizer.ResetPolicy() = false;

std::cout << "Loss after cycle " << i << " -> " <<
MeanTestLoss<MeanSModel>(vaeModel, trainTest, batchSize) << std::endl;
}

std::cout << "Initial loss -> "
<< MeanTestLoss<MeanSModel>(vaeModel, train_test, 50) << std::endl;

// Train neural network. If this is the first iteration, weights are
// random, using current values as starting point otherwise.
vaeModel.Train(train,
train,
optimizer,
ens::PrintLoss(),
ens::ProgressBar(),
// Stop the training using Early Stop at min loss.
ens::EarlyStopAtMinLoss());
std::cout << "Time taken to train -> " << float(clock() - beginTime) /
CLOCKS_PER_SEC << " seconds" << std::endl;

// Save the model if specified.
if (saveModel)
{
data::Save("vae/saved_models/vaeCNN.bin", "vaeCNN", vaeModel);
data::Save("./saved_models/vaeCNN.bin", "vaeCNN", vaeModel);
std::cout << "Model saved in vae/saved_models." << std::endl;
}
}
Binary file added mnist_vae_cnn/samples_posterior/allSamples.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added mnist_vae_cnn/samples_prior/2dLatent.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added mnist_vae_cnn/samples_prior/allLatent.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added mnist_vae_cnn/samples_prior/allSamples.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
185 changes: 185 additions & 0 deletions mnist_vae_cnn/vae_generate.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
/**
* @file vae_generate.cpp
* @author Atharva Khandait
*
* Generate MNIST using trained VAE model.
*
* mlpack is free software; you may redistribute it and/or modify it under the
* terms of the 3-clause BSD license. You should have received a copy of the
* 3-clause BSD license along with mlpack. If not, see
* http://www.opensource.org/licenses/BSD-3-Clause for more information.
*/
#include <mlpack/core.hpp>
#include <mlpack/core/data/split_data.hpp>
#include <mlpack/core/data/save.hpp>

#include <mlpack/methods/ann/ffn.hpp>
#include <mlpack/methods/ann/layer/layer.hpp>
#include <mlpack/methods/ann/init_rules/he_init.hpp>
#include <mlpack/methods/ann/loss_functions/reconstruction_loss.hpp>
#include <mlpack/methods/ann/loss_functions/mean_squared_error.hpp>
#include <mlpack/methods/ann/dists/bernoulli_distribution.hpp>

#include "vae_utils.hpp"

using namespace mlpack;
using namespace mlpack::ann;

// Convenience alias: a feed-forward network that uses a reconstruction loss
// over a Bernoulli distribution together with He weight initialization.
using ReconModel = FFN<
    ReconstructionLoss<arma::mat,
                       arma::mat,
                       BernoulliDistribution<arma::mat>>,
    HeInitialization>;

/**
 * Load a trained VAE model and write generated samples as CSV files into
 * ./samples_csv_files/:
 *  - samples_prior.csv: decoder output for random latent vectors;
 *  - samples_prior_latent<i>.csv: latent variable i swept over a range while
 *    the remaining ones are kept at a random value;
 *  - samples_prior_latent_2d<i>.csv: two latent variables swept on a grid;
 *  - samples_posterior.csv: reconstructions of validation data points (only
 *    when loadData is set).
 */
int main()
{
  // Whether to load training data.
  constexpr bool loadData = true;
  // The number of samples to generate.
  constexpr size_t nofSamples = 20;
  // Whether modelled on binary data.
  constexpr bool isBinary = false;
  // The latent size of the VAE model.
  constexpr size_t latentSize = 20;

  // Indices of the layers the partial forward passes start and end at.
  constexpr size_t encoderIndex = 1;
  constexpr size_t decoderIndex = 3;
  // For binary data a Sigmoid layer is appended after the decoder, which
  // moves the last layer one position further.
  constexpr size_t lastLayerIndex = decoderIndex + (isBinary ? 1 : 0);

  // Swept latent variables take the values -1.5 + k * latentStep for
  // k = 0 .. nofSamples - 1, i.e. they cover [-1.5, 1.5).
  constexpr double latentStep = 3.0 / nofSamples;

  arma::mat fullData, train, validation;

  if (loadData)
  {
    // Load transposed (last argument), so that every column is a data point;
    // the header/label removal and the column selection below rely on that
    // layout, and it matches how the training program loads the file.
    data::Load("../data/mnist_train.csv", fullData, true, true);

    // Get rid of the header (first column) and of the labels (first row).
    fullData =
        fullData.submat(1, 1, fullData.n_rows - 1, fullData.n_cols - 1);
    fullData /= 255.0;

    if (isBinary)
    {
      fullData = arma::conv_to<arma::mat>::from(arma::randu<arma::mat>
          (fullData.n_rows, fullData.n_cols) <= fullData);
    }
    else
    {
      fullData = (fullData - 0.5) * 2;
    }

    data::Split(fullData, validation, train, 0.8);
  }

  arma::arma_rng::set_seed_random();

  // It doesn't matter what type of network we initialize, as we only need to
  // forward pass through it and not initialize weights or take loss.
  FFN<> vaeModel;

  // Load the trained model. The load is fatal on failure: continuing with a
  // network that was not loaded would only produce meaningless output.
  if (isBinary)
  {
    data::Load("./saved_models/vaeBinaryMS.xml", "vaeBinaryMS", vaeModel,
        true);
    vaeModel.Add<SigmoidLayer<> >();
  }
  else
  {
    // Use the same object name the training program saves the model under.
    data::Load("./saved_models/vaeCNN.bin", "vaeCNN", vaeModel, true);
  }

  arma::mat gaussianSamples, outputDists, samples;

  /*
   * Sampling from the prior.
   */
  gaussianSamples = arma::randn<arma::mat>(latentSize, nofSamples);

  // Forward pass only through the decoder (and Sigmoid layer in case of
  // binary).
  vaeModel.Forward(gaussianSamples,
                   outputDists,
                   decoderIndex,
                   lastLayerIndex);

  GetSample(outputDists, samples, isBinary);
  // Save the prior samples as csv.
  data::Save("./samples_csv_files/samples_prior.csv", samples, false, false);

  /*
   * Sampling from the prior by varying all latent variables.
   */
  arma::mat gaussianVaried;

  for (size_t i = 0; i < latentSize; i++)
  {
    // Every column starts from the same random latent vector...
    gaussianSamples = arma::randn<arma::mat>(latentSize, 1);
    gaussianVaried = arma::zeros(latentSize, nofSamples);
    gaussianVaried.each_col() = gaussianSamples;

    // ...and only latent variable i is swept across the columns.
    for (size_t j = 0; j < nofSamples; j++)
    {
      gaussianVaried(i, j) = -1.5 + j * latentStep;
    }

    // Forward pass only through the decoder (and Sigmoid layer in case of
    // binary).
    vaeModel.Forward(gaussianVaried,
                     outputDists,
                     decoderIndex,
                     lastLayerIndex);

    GetSample(outputDists, samples, isBinary);
    // Save the prior samples as csv.
    data::Save(
        "./samples_csv_files/samples_prior_latent" + std::to_string(i) + ".csv",
        samples,
        false,
        false);
  }

  /*
   * Sampling from the prior by varying two latent variables in 2d.
   */
  constexpr size_t latent1 = 3; // Latent variable to be varied vertically.
  constexpr size_t latent2 = 4; // Latent variable to be varied horizontally.
  static_assert(latent1 < latentSize && latent2 < latentSize,
      "the varied latent variables must exist in the latent space");

  for (size_t i = 0; i < nofSamples; i++)
  {
    gaussianVaried = arma::zeros(latentSize, nofSamples);

    for (size_t j = 0; j < nofSamples; j++)
    {
      // The vertical variable is constant within one file and decreases from
      // 1.5 from file to file.
      gaussianVaried(latent1, j) = 1.5 - i * latentStep;
      // The horizontal variable increases from -1.5 across the columns.
      gaussianVaried(latent2, j) = -1.5 + j * latentStep;
    }

    // Forward pass only through the decoder (and Sigmoid layer in case of
    // binary).
    vaeModel.Forward(gaussianVaried,
                     outputDists,
                     decoderIndex,
                     lastLayerIndex);

    GetSample(outputDists, samples, isBinary);
    // Save the prior samples as csv.
    data::Save("./samples_csv_files/samples_prior_latent_2d" + std::to_string(i)
        + ".csv", samples, false, false);
  }

  /*
   * Sampling from the posterior.
   */
  if (loadData)
  {
    // Use at most nofSamples data points, but never more than are available.
    const size_t available = validation.n_cols;
    const size_t nofPosterior =
        (available < nofSamples) ? available : nofSamples;

    if (nofPosterior > 0)
    {
      // Forward pass through the entire network given the input data points.
      vaeModel.Forward(validation.cols(0, nofPosterior - 1),
                       outputDists,
                       encoderIndex,
                       lastLayerIndex);

      GetSample(outputDists, samples, isBinary);
      // Save the posterior samples as csv.
      data::Save(
          "./samples_csv_files/samples_posterior.csv",
          samples,
          false,
          false);
    }
  }
}
Loading

0 comments on commit b354d5a

Please sign in to comment.