diff --git a/generating_hand_written_digits_mnist_with_gan/generating_handwritten_digits_mnist_with_gan.ipynb b/generating_hand_written_digits_mnist_with_gan/generating_handwritten_digits_mnist_with_gan.ipynb index 1a335bab..4516c038 100644 --- a/generating_hand_written_digits_mnist_with_gan/generating_handwritten_digits_mnist_with_gan.ipynb +++ b/generating_hand_written_digits_mnist_with_gan/generating_handwritten_digits_mnist_with_gan.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 34, + "execution_count": 1, "id": "c23013fc", "metadata": {}, "outputs": [], @@ -12,36 +12,63 @@ "#include\n", "#include\n", "#include\n", - "#include" + "#include\n", + "#include" ] }, { "cell_type": "code", "execution_count": 2, - "id": "dccd0b37", + "id": "ad7ab4c6", "metadata": {}, "outputs": [], "source": [ - "#include" + "using namespace mlpack;\n" ] }, { "cell_type": "code", - "execution_count": 19, - "id": "ad7ab4c6", + "execution_count": 3, + "id": "0c8af9e2", "metadata": {}, "outputs": [], "source": [ - "using namespace mlpack;\n", - "using namespace mlpack::ann;\n", - "using namespace arma;\n", - "using namespace std;\n", - "using namespace ens;" + "using namespace mlpack::ann;" ] }, { "cell_type": "code", "execution_count": 4, + "id": "ff1b2e2f", + "metadata": {}, + "outputs": [], + "source": [ + "using namespace arma;" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "4834841f", + "metadata": {}, + "outputs": [], + "source": [ + "using namespace std;" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "d1dc9cb6", + "metadata": {}, + "outputs": [], + "source": [ + "using namespace ens;" + ] + }, + { + "cell_type": "code", + "execution_count": 7, "id": "7ed7a903", "metadata": {}, "outputs": [], @@ -57,7 +84,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 8, "id": "131fe748", "metadata": {}, "outputs": [], @@ -70,7 +97,7 @@ }, { "cell_type": "code", - "execution_count": 7, + 
"execution_count": 9, "id": "697a8d63", "metadata": {}, "outputs": [], @@ -80,7 +107,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 10, "id": "64ee4e72", "metadata": {}, "outputs": [], @@ -110,7 +137,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 13, "id": "a6a484e7", "metadata": {}, "outputs": [], @@ -121,7 +148,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 14, "id": "db267de6", "metadata": {}, "outputs": [], @@ -132,254 +159,176 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": 15, "id": "7eb2907b", "metadata": {}, "outputs": [], "source": [ - "using namespace mlpack::ann;\n", "FFN, RandomInitialization> model;" ] }, { "cell_type": "code", - "execution_count": 39, + "execution_count": 16, "id": "ffc66454", "metadata": {}, + "outputs": [], + "source": [ + "model.Add> (1,6,5,5,1,1,0,0,28,28);\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "03d0646a", + "metadata": {}, + "outputs": [], + "source": [ + "model.Add>();" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "2fff115e", + "metadata": {}, + "outputs": [], + "source": [ + "model.Add>(2,\n", + " 2,\n", + " 2,\n", + " 2,\n", + " true);" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "462e4ce2", + "metadata": {}, + "outputs": [], + "source": [ + "model.Add>(6,\n", + " 16,\n", + " 5,\n", + " 5,\n", + " 1,\n", + " 1,\n", + " 0,\n", + " 0,\n", + " 12,\n", + " 12);" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "73033233", + "metadata": {}, + "outputs": [], + "source": [ + "model.Add>();" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "ee5d88cc", + "metadata": {}, + "outputs": [], + "source": [ + "model.Add>();" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "7859b83e", + "metadata": {}, + "outputs": [], + "source": [ + "model.Add>(2,2,2,2,true);" + ] + }, + { + 
"cell_type": "code", + "execution_count": 23, + "id": "f423475d", + "metadata": {}, + "outputs": [], + "source": [ + "model.Add>(16*4*4, 10);" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "369063ba", + "metadata": {}, + "outputs": [], + "source": [ + "model.Add>();" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "17dfcdef", + "metadata": {}, + "outputs": [], + "source": [ + "ens::Adam optimizer(STEP_SIZE,\n", + " BATCH_SIZE,\n", + " 0.9,\n", + " 0.999,\n", + " 1e-8,\n", + " MAX_ITERATIONS,\n", + " 1e-8,\n", + " true);" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "c52947c7", + "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "In file included from input_line_7:5:\n", - "\u001b[1m/home/viole/anaconda3/envs/notebook/include/mlpack/methods/ann/ffn.hpp:290:36: \u001b[0m\u001b[0;1;31merror: \u001b[0m\u001b[1mno member named 'push_back' in\n", - " 'std::vector,\n", - " arma::Mat > *,\n", - " mlpack::ann::AdaptiveMeanPooling,\n", - " arma::Mat > *, mlpack::ann::Add,\n", - " arma::Mat > *, mlpack::ann::AddMerge,\n", - " arma::Mat> *, mlpack::ann::AlphaDropout,\n", - " arma::Mat > *,\n", - " mlpack::ann::AtrousConvolution,\n", - " mlpack::ann::NaiveConvolution,\n", - " mlpack::ann::NaiveConvolution,\n", - " arma::Mat, arma::Mat > *,\n", - " mlpack::ann::BaseLayer, arma::Mat > *,\n", - " mlpack::ann::BaseLayer, arma::Mat > *,\n", - " mlpack::ann::BaseLayer,\n", - " arma::Mat > *,\n", - " mlpack::ann::BaseLayer, arma::Mat > *,\n", - " mlpack::ann::BaseLayer, arma::Mat > *,\n", - " mlpack::ann::BatchNorm, arma::Mat > *,\n", - " mlpack::ann::BilinearInterpolation,\n", - " arma::Mat > *, mlpack::ann::CELU,\n", - " arma::Mat > *, mlpack::ann::Concat,\n", - " arma::Mat> *, mlpack::ann::Concatenate,\n", - " arma::Mat > *,\n", - " mlpack::ann::ConcatPerformance,\n", - " arma::Mat >, arma::Mat, arma::Mat > *,\n", - " mlpack::ann::Constant, arma::Mat > *,\n", - " 
mlpack::ann::Convolution,\n", - " mlpack::ann::NaiveConvolution,\n", - " mlpack::ann::NaiveConvolution,\n", - " arma::Mat, arma::Mat > *,\n", - " mlpack::ann::CReLU, arma::Mat > *,\n", - " mlpack::ann::DropConnect, arma::Mat > *,\n", - " mlpack::ann::Dropout, arma::Mat > *,\n", - " mlpack::ann::ELU, arma::Mat > *,\n", - " mlpack::ann::FastLSTM, arma::Mat > *,\n", - " mlpack::ann::FlexibleReLU, arma::Mat > *,\n", - " mlpack::ann::GRU, arma::Mat > *,\n", - " mlpack::ann::HardTanH, arma::Mat > *,\n", - " mlpack::ann::Join, arma::Mat > *,\n", - " mlpack::ann::LayerNorm, arma::Mat > *,\n", - " mlpack::ann::LeakyReLU, arma::Mat > *,\n", - " mlpack::ann::Linear, arma::Mat,\n", - " mlpack::ann::NoRegularizer> *,\n", - " mlpack::ann::LinearNoBias, arma::Mat,\n", - " mlpack::ann::NoRegularizer> *,\n", - " mlpack::ann::LogSoftMax, arma::Mat > *,\n", - " mlpack::ann::Lookup, arma::Mat > *,\n", - " mlpack::ann::LSTM, arma::Mat > *,\n", - " mlpack::ann::MaxPooling, arma::Mat > *,\n", - " mlpack::ann::MeanPooling, arma::Mat > *,\n", - " mlpack::ann::MiniBatchDiscrimination,\n", - " arma::Mat > *,\n", - " mlpack::ann::MultiplyConstant, arma::Mat >\n", - " *, mlpack::ann::MultiplyMerge, arma::Mat>\n", - " *, mlpack::ann::NegativeLogLikelihood,\n", - " arma::Mat > *, mlpack::ann::NoisyLinear,\n", - " arma::Mat > *, mlpack::ann::Padding,\n", - " arma::Mat > *, mlpack::ann::PReLU,\n", - " arma::Mat > *, mlpack::ann::Softmax,\n", - " arma::Mat > *,\n", - " mlpack::ann::SpatialDropout, arma::Mat >\n", - " *,\n", - " mlpack::ann::TransposedConvolution,\n", - " mlpack::ann::NaiveConvolution,\n", - " mlpack::ann::NaiveConvolution,\n", - " arma::Mat, arma::Mat > *,\n", - " mlpack::ann::WeightNorm, arma::Mat> *,\n", - " boost::variant,\n", - " arma::Mat, mlpack::ann::NoRegularizer> *,\n", - " mlpack::ann::Glimpse, arma::Mat > *,\n", - " mlpack::ann::Highway, arma::Mat> *,\n", - " mlpack::ann::MultiheadAttention,\n", - " arma::Mat, mlpack::ann::NoRegularizer> *,\n", - " 
mlpack::ann::Recurrent, arma::Mat> *,\n", - " mlpack::ann::RecurrentAttention, arma::Mat\n", - " > *, mlpack::ann::ReinforceNormal,\n", - " arma::Mat > *,\n", - " mlpack::ann::Reparametrization, arma::Mat\n", - " > *, mlpack::ann::Select, arma::Mat > *,\n", - " mlpack::ann::Sequential, arma::Mat, false>\n", - " *, mlpack::ann::Sequential, arma::Mat,\n", - " true> *, mlpack::ann::Subview, arma::Mat >\n", - " *, mlpack::ann::VRClassReward, arma::Mat >\n", - " *, mlpack::ann::VirtualBatchNorm,\n", - " arma::Mat > *, mlpack::ann::RBF,\n", - " arma::Mat, mlpack::ann::GaussianFunction> *,\n", - " mlpack::ann::BaseLayer, arma::Mat > *,\n", - " mlpack::ann::PositionalEncoding, arma::Mat\n", - " > *> >,\n", - " std::allocator,\n", - " arma::Mat > *,\n", - " mlpack::ann::AdaptiveMeanPooling,\n", - " arma::Mat > *, mlpack::ann::Add,\n", - " arma::Mat > *, mlpack::ann::AddMerge,\n", - " arma::Mat> *, mlpack::ann::AlphaDropout,\n", - " arma::Mat > *,\n", - " mlpack::ann::AtrousConvolution,\n", - " mlpack::ann::NaiveConvolution,\n", - " mlpack::ann::NaiveConvolution,\n", - " arma::Mat, arma::Mat > *,\n", - " mlpack::ann::BaseLayer, arma::Mat > *,\n", - " mlpack::ann::BaseLayer, arma::Mat > *,\n", - " mlpack::ann::BaseLayer,\n", - " arma::Mat > *,\n", - " mlpack::ann::BaseLayer, arma::Mat > *,\n", - " mlpack::ann::BaseLayer, arma::Mat > *,\n", - " mlpack::ann::BatchNorm, arma::Mat > *,\n", - " mlpack::ann::BilinearInterpolation,\n", - " arma::Mat > *, mlpack::ann::CELU,\n", - " arma::Mat > *, mlpack::ann::Concat,\n", - " arma::Mat> *, mlpack::ann::Concatenate,\n", - " arma::Mat > *,\n", - " mlpack::ann::ConcatPerformance,\n", - " arma::Mat >, arma::Mat, arma::Mat > *,\n", - " mlpack::ann::Constant, arma::Mat > *,\n", - " mlpack::ann::Convolution,\n", - " mlpack::ann::NaiveConvolution,\n", - " mlpack::ann::NaiveConvolution,\n", - " arma::Mat, arma::Mat > *,\n" + "\n", + "error: Mat::operator(): index out of bounds\n" ] }, { - "name": "stderr", - "output_type": "stream", - "text": 
[ - " mlpack::ann::CReLU, arma::Mat > *,\n", - " mlpack::ann::DropConnect, arma::Mat > *,\n", - " mlpack::ann::Dropout, arma::Mat > *,\n", - " mlpack::ann::ELU, arma::Mat > *,\n", - " mlpack::ann::FastLSTM, arma::Mat > *,\n", - " mlpack::ann::FlexibleReLU, arma::Mat > *,\n", - " mlpack::ann::GRU, arma::Mat > *,\n", - " mlpack::ann::HardTanH, arma::Mat > *,\n", - " mlpack::ann::Join, arma::Mat > *,\n", - " mlpack::ann::LayerNorm, arma::Mat > *,\n", - " mlpack::ann::LeakyReLU, arma::Mat > *,\n", - " mlpack::ann::Linear, arma::Mat,\n", - " mlpack::ann::NoRegularizer> *,\n", - " mlpack::ann::LinearNoBias, arma::Mat,\n", - " mlpack::ann::NoRegularizer> *,\n", - " mlpack::ann::LogSoftMax, arma::Mat > *,\n", - " mlpack::ann::Lookup, arma::Mat > *,\n", - " mlpack::ann::LSTM, arma::Mat > *,\n", - " mlpack::ann::MaxPooling, arma::Mat > *,\n", - " mlpack::ann::MeanPooling, arma::Mat > *,\n", - " mlpack::ann::MiniBatchDiscrimination,\n", - " arma::Mat > *,\n", - " mlpack::ann::MultiplyConstant, arma::Mat >\n", - " *, mlpack::ann::MultiplyMerge, arma::Mat>\n", - " *, mlpack::ann::NegativeLogLikelihood,\n", - " arma::Mat > *, mlpack::ann::NoisyLinear,\n", - " arma::Mat > *, mlpack::ann::Padding,\n", - " arma::Mat > *, mlpack::ann::PReLU,\n", - " arma::Mat > *, mlpack::ann::Softmax,\n", - " arma::Mat > *,\n", - " mlpack::ann::SpatialDropout, arma::Mat >\n", - " *,\n", - " mlpack::ann::TransposedConvolution,\n", - " mlpack::ann::NaiveConvolution,\n", - " mlpack::ann::NaiveConvolution,\n", - " arma::Mat, arma::Mat > *,\n", - " mlpack::ann::WeightNorm, arma::Mat> *,\n", - " boost::variant,\n", - " arma::Mat, mlpack::ann::NoRegularizer> *,\n", - " mlpack::ann::Glimpse, arma::Mat > *,\n", - " mlpack::ann::Highway, arma::Mat> *,\n", - " mlpack::ann::MultiheadAttention,\n", - " arma::Mat, mlpack::ann::NoRegularizer> *,\n", - " mlpack::ann::Recurrent, arma::Mat> *,\n", - " mlpack::ann::RecurrentAttention, arma::Mat\n", - " > *, mlpack::ann::ReinforceNormal,\n", - " arma::Mat > *,\n", - " 
mlpack::ann::Reparametrization, arma::Mat\n", - " > *, mlpack::ann::Select, arma::Mat > *,\n", - " mlpack::ann::Sequential, arma::Mat, false>\n", - " *, mlpack::ann::Sequential, arma::Mat,\n", - " true> *, mlpack::ann::Subview, arma::Mat >\n", - " *, mlpack::ann::VRClassReward, arma::Mat >\n", - " *, mlpack::ann::VirtualBatchNorm,\n", - " arma::Mat > *, mlpack::ann::RBF,\n", - " arma::Mat, mlpack::ann::GaussianFunction> *,\n", - " mlpack::ann::BaseLayer, arma::Mat > *,\n", - " mlpack::ann::PositionalEncoding, arma::Mat\n", - " > *> > > >'\u001b[0m\n", - " void Add(Args... args) { network.push_back(new LayerType(args...)); }\n", - "\u001b[0;1;32m ~~~~~~~ ^\n", - "\u001b[0m\u001b[1minput_line_46:2:8: \u001b[0m\u001b[0;1;30mnote: \u001b[0min instantiation of function template specialization\n", - " 'mlpack::ann::FFN,\n", - " arma::Mat >,\n", - " mlpack::ann::RandomInitialization>::Add,\n", - " mlpack::ann::NaiveConvolution,\n", - " mlpack::ann::NaiveConvolution,\n", - " arma::Mat, arma::Mat >, int, int, int, int, int,\n", - " int, int, int, int, int>' requested here\u001b[0m\n", - " model.Add> (1,6,5,5,1,1,0,0,28,28);\n", - "\u001b[0;1;32m ^\n", - "\u001b[0m" - ] - }, - { - "ename": "Interpreter Error", - "evalue": "", + "ename": "Standard Exception", + "evalue": "Mat::operator(): index out of bounds", "output_type": "error", "traceback": [ - "Interpreter Error: " + "Standard Exception: Mat::operator(): index out of bounds" ] } ], "source": [ - "model.Add> (1,6,5,5,1,1,0,0,28,28);\n", - " " + "model.Train(trainX,\n", + " trainY,\n", + " optimizer,\n", + " ens::PrintLoss(),\n", + " ens::ProgressBar(),\n", + " ens::EarlyStopAtMinLoss(\n", + " [&](const arma::mat&){\n", + " double validationLoss = model.Evaluate(validX, validY);\n", + " std::cout << \"Validation Loss\" << validationLoss << \".\" << \"\\n\";\n", + " return validationLoss;\n", + " }))" ] }, { "cell_type": "code", "execution_count": null, - "id": "03d0646a", + "id": "d6060587", "metadata": {}, "outputs": 
[], "source": [] diff --git a/generating_hand_written_digits_mnist_with_gan/testNotebook.ipynb b/generating_hand_written_digits_mnist_with_gan/testNotebook.ipynb new file mode 100644 index 00000000..01a866d4 --- /dev/null +++ b/generating_hand_written_digits_mnist_with_gan/testNotebook.ipynb @@ -0,0 +1,288 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 2, + "id": "4ffd91b1", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[1minput_line_9:11:1: \u001b[0m\u001b[0;1;31merror: \u001b[0m\u001b[1mfunction definition is not allowed here\u001b[0m\n", + "{\n", + "\u001b[0;1;32m^\n", + "\u001b[0m\u001b[1minput_line_9:20:1: \u001b[0m\u001b[0;1;31merror: \u001b[0m\u001b[1mfunction definition is not allowed here\u001b[0m\n", + "{\n", + "\u001b[0;1;32m^\n", + "\u001b[0m" + ] + }, + { + "ename": "Interpreter Error", + "evalue": "", + "output_type": "error", + "traceback": [ + "Interpreter Error: " + ] + } + ], + "source": [ + "\n", + "/**\n", + " * An example of using Convolutional Neural Network (CNN) for\n", + " * solving Digit Recognizer problem from Kaggle website.\n", + " *\n", + " * The full description of a problem as well as datasets for training\n", + " * and testing are available here https://www.kaggle.com/c/digit-recognizer\n", + " *\n", + " * mlpack is free software; you may redistribute it and/or modify it under the\n", + " * terms of the 3-clause BSD license. You should have received a copy of the\n", + " * 3-clause BSD license along with mlpack. 
If not, see\n", + " * http://www.opensource.org/licenses/BSD-3-Clause for more information.\n", + " *\n", + " * @author Daivik Nema\n", + " */\n", + "\n", + "#include \n", + "#include \n", + "\n", + "#include \n", + "#include \n", + "\n", + "#include \n", + "\n", + "#if ((ENS_VERSION_MAJOR < 2) || ((ENS_VERSION_MAJOR == 2) && (ENS_VERSION_MINOR < 13)))\n", + " #error \"need ensmallen version 2.13.0 or later\"\n", + "#endif\n", + "\n", + "using namespace mlpack;\n", + "using namespace mlpack::ann;\n", + "\n", + "using namespace arma;\n", + "using namespace std;\n", + "\n", + "using namespace ens;\n", + "\n", + "arma::Row getLabels(arma::mat predOut)\n", + "{\n", + " arma::Row predLabels(predOut.n_cols);\n", + " for (arma::uword i = 0; i < predOut.n_cols; ++i)\n", + " {\n", + " predLabels(i) = predOut.col(i).index_max();\n", + " }\n", + " return predLabels;\n", + "}\n", + "\n", + "int main()\n", + "{\n", + " // Dataset is randomly split into validation\n", + " // and training parts with the following ratio.\n", + " constexpr double RATIO = 0.1;\n", + "\n", + " // Allow an infinite number of iterations until we are stopped by EarlyStopAtMinLoss\n", + " constexpr int MAX_ITERATIONS = 0;\n", + "\n", + " // Step size of the optimizer.\n", + " constexpr double STEP_SIZE = 1.2e-3;\n", + "\n", + " // Number of data points in each iteration of SGD.\n", + " constexpr int BATCH_SIZE = 50;\n", + "\n", + " cout << \"Reading data ...\" << endl;\n", + "\n", + " // Labeled dataset that contains data for training is loaded from CSV file.\n", + " // Rows represent features, columns represent data points.\n", + " mat dataset;\n", + "\n", + " // The original file can be downloaded from\n", + " // https://www.kaggle.com/c/digit-recognizer/data\n", + " data::Load(\"/home/viole/swaingotnochill/examples/generating_hand_written_digits_mnist_with_gan/digit-recognizer/train.csv\", dataset, true);\n", + "\n", + " // Split the dataset into training and validation sets.\n", + " mat train, valid;\n", + " 
data::Split(dataset, train, valid, RATIO);\n", + "\n", + " // The train and valid datasets contain both - the features as well as the\n", + " // class labels. Split these into separate mats.\n", + " const mat trainX = train.submat(1, 0, train.n_rows - 1, train.n_cols - 1);\n", + " const mat validX = valid.submat(1, 0, valid.n_rows - 1, valid.n_cols - 1);\n", + "\n", + " // Labels should specify the class of a data point and be in the interval [0,\n", + " // numClasses).\n", + "\n", + " // Create labels for training and validation datasets.\n", + " const mat trainY = train.row(0);\n", + " const mat validY = valid.row(0);\n", + "\n", + " // Specify the NN model. NegativeLogLikelihood is the output layer that\n", + " // is used for classification problems. RandomInitialization means that\n", + " // initial weights are generated randomly in the interval from -1 to 1.\n", + " FFN, RandomInitialization> model;\n", + "\n", + " // Specify the model architecture.\n", + " // In this example, the CNN architecture is chosen similar to LeNet-5.\n", + " // The architecture follows a Conv-ReLU-Pool-Conv-ReLU-Pool-Dense schema. We\n", + " // have used leaky ReLU activation instead of vanilla ReLU. Standard\n", + " // max-pooling has been used for pooling. The first convolution uses 6 filters\n", + " // of size 5x5 (and a stride of 1). The second convolution uses 16 filters of\n", + " // size 5x5 (stride = 1). The final dense layer is connected to a softmax to\n", + " // ensure that we get a valid probability distribution over the output classes\n", + "\n", + " // Layers schema.\n", + " // 28x28x1 --- conv (6 filters of size 5x5. stride = 1) ---> 24x24x6\n", + " // 24x24x6 --------------- Leaky ReLU ---------------------> 24x24x6\n", + " // 24x24x6 --- max pooling (over 2x2 fields. stride = 2) --> 12x12x6\n", + " // 12x12x6 --- conv (16 filters of size 5x5. 
stride = 1) --> 8x8x16\n", + " // 8x8x16 --------------- Leaky ReLU ---------------------> 8x8x16\n", + " // 8x8x16 --- max pooling (over 2x2 fields. stride = 2) --> 4x4x16\n", + " // 4x4x16 ------------------- Dense ----------------------> 10\n", + "\n", + " // Add the first convolution layer.\n", + " model.Add>(1, // Number of input activation maps.\n", + " 6, // Number of output activation maps.\n", + " 5, // Filter width.\n", + " 5, // Filter height.\n", + " 1, // Stride along width.\n", + " 1, // Stride along height.\n", + " 0, // Padding width.\n", + " 0, // Padding height.\n", + " 28, // Input width.\n", + " 28 // Input height.\n", + " );\n", + "\n", + " // Add first ReLU.\n", + " model.Add>();\n", + "\n", + " // Add first pooling layer. Pools over 2x2 fields in the input.\n", + " model.Add>(2, // Width of field.\n", + " 2, // Height of field.\n", + " 2, // Stride along width.\n", + " 2, // Stride along height.\n", + " true);\n", + "\n", + " // Add the second convolution layer.\n", + " model.Add>(6, // Number of input activation maps.\n", + " 16, // Number of output activation maps.\n", + " 5, // Filter width.\n", + " 5, // Filter height.\n", + " 1, // Stride along width.\n", + " 1, // Stride along height.\n", + " 0, // Padding width.\n", + " 0, // Padding height.\n", + " 12, // Input width.\n", + " 12 // Input height.\n", + " );\n", + "\n", + " // Add the second ReLU.\n", + " model.Add>();\n", + "\n", + " // Add the second pooling layer.\n", + " model.Add>(2, 2, 2, 2, true);\n", + "\n", + " // Add the final dense layer.\n", + " model.Add>(16 * 4 * 4, 10);\n", + " model.Add>();\n", + "\n", + " cout << \"Start training ...\" << endl;\n", + "\n", + " // Set parameters for the Adam optimizer.\n", + " ens::Adam optimizer(\n", + " STEP_SIZE, // Step size of the optimizer.\n", + " BATCH_SIZE, // Batch size. 
Number of data points that are used in each\n", + " // iteration.\n", + " 0.9, // Exponential decay rate for the first moment estimates.\n", + " 0.999, // Exponential decay rate for the weighted infinity norm estimates.\n", + " 1e-8, // Value used to initialise the mean squared gradient parameter.\n", + " MAX_ITERATIONS, // Max number of iterations.\n", + " 1e-8, // Tolerance.\n", + " true);\n", + "\n", + " // Train the CNN model. If this is the first iteration, weights are\n", + " // randomly initialized between -1 and 1. Otherwise, the values of weights\n", + " // from the previous iteration are used.\n", + " model.Train(trainX,\n", + " trainY,\n", + " optimizer,\n", + " ens::PrintLoss(),\n", + " ens::ProgressBar(),\n", + " // Stop the training using Early Stop at min loss.\n", + " ens::EarlyStopAtMinLoss(\n", + " [&](const arma::mat& /* param */)\n", + " {\n", + " double validationLoss = model.Evaluate(validX, validY);\n", + " std::cout << \"Validation loss: \" << validationLoss\n", + " << \".\" << std::endl;\n", + " return validationLoss;\n", + " }));\n", + "\n", + " // Matrix to store the predictions on train and validation datasets.\n", + " mat predOut;\n", + " // Get predictions on training data points.\n", + " model.Predict(trainX, predOut);\n", + " // Calculate accuracy on training data points.\n", + " arma::Row predLabels = getLabels(predOut);\n", + " double trainAccuracy =\n", + " arma::accu(predLabels == trainY) / (double) trainY.n_elem * 100;\n", + " // Get predictions on validating data points.\n", + " model.Predict(validX, predOut);\n", + " // Calculate accuracy on validating data points.\n", + " predLabels = getLabels(predOut);\n", + " double validAccuracy =\n", + " arma::accu(predLabels == validY) / (double) validY.n_elem * 100;\n", + "\n", + " std::cout << \"Accuracy: train = \" << trainAccuracy << \"%,\"\n", + " << \"\\t valid = \" << validAccuracy << \"%\" << std::endl;\n", + "\n", + " mlpack::data::Save(\"model.bin\", \"model\", model, 
false);\n", + "\n", + " std::cout << \"Predicting ...\" << std::endl;\n", + "\n", + " // Load test dataset\n", + " // The original file can be downloaded from\n", + " // https://www.kaggle.com/c/digit-recognizer/data\n", + " data::Load(\"/home/viole/swaingotnochill/examples/generating_hand_written_digits_mnist_with_gan/digit-recognizer/test.csv\", dataset, true);\n", + " dataset.shed_row(dataset.n_rows - 1); // Remove labels before predicting.\n", + "\n", + " // Matrix to store the predictions on test dataset.\n", + " mat testPredOut;\n", + " // Get predictions on test data points.\n", + " model.Predict(dataset, testPredOut);\n", + " // Generate labels for the test dataset.\n", + " Row testPred = getLabels(testPredOut);\n", + " std::cout << \"Saving predicted labels to \\\"results.csv.\\\"...\" << std::endl;\n", + "\n", + " // Saving results into a Kaggle-compatible CSV file.\n", + " testPred.save(\"results.csv\", arma::csv_ascii);\n", + " std::cout << \"Neural network model is saved to \\\"model.bin\\\"\" << std::endl;\n", + " std::cout << \"Finished\" << std::endl;\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5dbbc367", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "C++14", + "language": "C++14", + "name": "xcpp14" + }, + "language_info": { + "codemirror_mode": "text/x-c++src", + "file_extension": ".cpp", + "mimetype": "text/x-c++src", + "name": "c++", + "version": "14" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}