Skip to content

Commit

Permalink
Merge pull request #156 from hello-fri-end/VAE
Browse files Browse the repository at this point in the history
Fixed segfault of convolutional vae
  • Loading branch information
kartikdutt18 authored Jun 11, 2021
2 parents a465236 + 7e3b620 commit 0139e80
Showing 1 changed file with 15 additions and 10 deletions.
25 changes: 15 additions & 10 deletions mnist_vae_cnn/mnist_vae_cnn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,18 @@ int main()
// Entire dataset(without labels) is loaded from a CSV file.
// Each column represents a data point.
arma::mat fullData;
data::Load("../data/mnist_train.csv", fullData, true, false);
data::Load("../data/mnist_train.csv", fullData, true, true);

// Originally on Kaggle dataset CSV file has header, so it's necessary to
// get rid of this row, in Armadillo representation it's the first column.
fullData =
fullData.submat(0, 1, fullData.n_rows -1, fullData.n_cols -1);
fullData /= 255.0;

// Get rid of the labels
fullData =
fullData.submat(1, 0, fullData.n_rows - 1, fullData.n_cols -1);

if (isBinary)
{
fullData = arma::conv_to<arma::mat>::from(
Expand All @@ -76,10 +85,6 @@ int main()
arma::mat train, validation;
data::Split(fullData, validation, train, trainRatio);

// Loss is calculated on train_test data after each cycle.
arma::mat train_test, dump;
data::Split(train, dump, train_test, 0.045);

/**
* Model architecture:
*
Expand Down Expand Up @@ -171,10 +176,13 @@ int main()
0, // Padding width.
0, // Padding height.
10, // Input width.
10); // Input height.
10, // Input height.
14, // Output width.
14); // Output height.

decoder->Add<LeakyReLU<>>();
decoder->Add<TransposedConvolution<>>(16, 1, 15, 15, 1, 1, 1, 1, 14, 14);
decoder->Add<TransposedConvolution<>>
(16, 1, 15, 15, 1, 1, 0, 0, 14, 14, 28, 28);

vaeModel.Add(decoder);
}
Expand All @@ -193,9 +201,6 @@ int main()
1e-8, // Tolerance.
true);

std::cout << "Initial loss -> "
<< MeanTestLoss<MeanSModel>(vaeModel, train_test, 50) << std::endl;

// Train neural network. If this is the first iteration, weights are
// random, using current values as starting point otherwise.
vaeModel.Train(train,
Expand Down

0 comments on commit 0139e80

Please sign in to comment.