diff --git a/generating_hand_written_digits_mnist_with_gan/mnist_first250_training_4s_and_9s.arm b/generating_hand_written_digits_mnist_with_gan/mnist_first250_training_4s_and_9s.arm new file mode 100644 index 00000000..fd28bbbf Binary files /dev/null and b/generating_hand_written_digits_mnist_with_gan/mnist_first250_training_4s_and_9s.arm differ diff --git a/generating_hand_written_digits_mnist_with_gan/testNotebook.ipynb b/generating_hand_written_digits_mnist_with_gan/testNotebook.ipynb index 01a866d4..067de33d 100644 --- a/generating_hand_written_digits_mnist_with_gan/testNotebook.ipynb +++ b/generating_hand_written_digits_mnist_with_gan/testNotebook.ipynb @@ -2,268 +2,297 @@ "cells": [ { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "id": "4ffd91b1", "metadata": {}, + "outputs": [], + "source": [ + "#include\n", + "#include \n", + "\n", + "#include \n", + "#include \n", + "#include \n", + "#include \n", + "#include \n", + "#include\n", + "\n", + "#include \n", + "\n", + "// #include \"catch.hpp\"\n", + "// #include \"test_catch_tools.hpp\"\n", + "// #include \"serialization.hpp\"\n", + "\n", + "// using namespace mlpack;\n", + "// using namespace mlpack::ann;\n", + "// using namespace mlpack::math;\n", + "// using namespace mlpack::regression;\n", + "// using namespace std::placeholders;\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "d2721956", + "metadata": {}, + "outputs": [], + "source": [ + "using namespace mlpack::ann;" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "6a336fdb", + "metadata": {}, + "outputs": [], + "source": [ + "using namespace mlpack;" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "8f61e83e", + "metadata": {}, + "outputs": [], + "source": [ + "using namespace mlpack::math;" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "fec32cd8", + "metadata": {}, + "outputs": [], + "source": [ + "using namespace mlpack::regression;" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "d28df478", + "metadata": {}, + "outputs": [], + "source": [ + "using namespace std::placeholders;" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "5dbbc367", + "metadata": {}, + "outputs": [], + "source": [ + " size_t dNumKernels = 32;\n", + " size_t discriminatorPreTrain = 5;\n", + " size_t batchSize = 5;\n", + " size_t noiseDim = 100;\n", + " size_t generatorUpdateStep = 1;\n", + " size_t numSamples = 10;\n", + " double stepSize = 0.0003;\n", + " double eps = 1e-8;\n", + " size_t numEpoches = 1;\n", + " double tolerance = 1e-5;\n", + " int datasetMaxCols = 10;\n", + " bool shuffle = true;\n", + " double multiplier = 10;" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "7f3190fc", + "metadata": {}, "outputs": [ { - "name": "stderr", + "name": "stdout", "output_type": "stream", "text": [ - "\u001b[1minput_line_9:11:1: \u001b[0m\u001b[0;1;31merror: \u001b[0m\u001b[1mfunction definition is not allowed here\u001b[0m\n", - "{\n", - "\u001b[0;1;32m^\n", - "\u001b[0m\u001b[1minput_line_9:20:1: \u001b[0m\u001b[0;1;31merror: \u001b[0m\u001b[1mfunction definition is not allowed here\u001b[0m\n", - "{\n", - "\u001b[0;1;32m^\n", - "\u001b[0m" + " batchSize = 5\n", + " generatorUpdateStep = 1\n", + " noiseDim = 100\n", + " numSamples = 10\n", + " stepSize = 0.0003\n", + " numEpoches = 1\n", + " tolerance = 1e-05\n", + " shuffle = true\n" ] - }, + } + ], + "source": [ + "\n", + " std::cout << std::boolalpha\n", + " << \" batchSize = \" << 
batchSize << std::endl\n", + " << \" generatorUpdateStep = \" << generatorUpdateStep << std::endl\n", + " << \" noiseDim = \" << noiseDim << std::endl\n", + " << \" numSamples = \" << numSamples << std::endl\n", + " << \" stepSize = \" << stepSize << std::endl\n", + " << \" numEpoches = \" << numEpoches << std::endl\n", + " << \" tolerance = \" << tolerance << std::endl\n", + " << \" shuffle = \" << shuffle << std::endl;\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4dbbb38d", + "metadata": {}, + "outputs": [ { - "ename": "Interpreter Error", - "evalue": "", - "output_type": "error", - "traceback": [ - "Interpreter Error: " + "name": "stderr", + "output_type": "stream", + "text": [ + "In file included from input_line_5:1:\n", + "In file included from /home/viole/anaconda3/envs/notebook/include/xeus/xinterpreter.hpp:13:\n", + "In file included from /home/viole/anaconda3/envs/notebook/bin/../lib/gcc/x86_64-conda-linux-gnu/7.5.0/../../../../x86_64-conda-linux-gnu/include/c++/7.5.0/functional:58:\n", + "\u001b[1m/home/viole/anaconda3/envs/notebook/bin/../lib/gcc/x86_64-conda-linux-gnu/7.5.0/../../../../x86_64-conda-linux-gnu/include/c++/7.5.0/bits/std_function.h:264:34: \u001b[0m\u001b[0;1;31merror: \u001b[0m\u001b[1mno matching constructor for initialization of '(lambda at input_line_16:43:43)'\u001b[0m\n", + " { ::new (__functor._M_access()) _Functor(std::move(__f)); }\n", + "\u001b[0;1;32m ^ ~~~~~~~~~~~~~~\n", + "\u001b[0m\u001b[1m/home/viole/anaconda3/envs/notebook/bin/../lib/gcc/x86_64-conda-linux-gnu/7.5.0/../../../../x86_64-conda-linux-gnu/include/c++/7.5.0/bits/std_function.h:239:4: \u001b[0m\u001b[0;1;30mnote: \u001b[0min instantiation of member function 'std::_Function_base::_Base_manager<(lambda at\n", + " input_line_16:43:43)>::_M_init_functor' requested here\u001b[0m\n", + " { _M_init_functor(__functor, std::move(__f), _Local_storage()); }\n", + "\u001b[0;1;32m ^\n", + "\u001b[0m\u001b[1m/home/viole/anaconda3/envs/notebook/bin/../lib/gcc/x86_64-conda-linux-gnu/7.5.0/../../../../x86_64-conda-linux-gnu/include/c++/7.5.0/bits/std_function.h:693:19: \u001b[0m\u001b[0;1;30mnote: \u001b[0min instantiation of member function 'std::_Function_base::_Base_manager<(lambda at\n", + " input_line_16:43:43)>::_M_init_functor' requested here\u001b[0m\n", + " _My_handler::_M_init_functor(_M_functor, std::move(__f));\n", + "\u001b[0;1;32m ^\n", + "\u001b[0m\u001b[1minput_line_16:43:43: \u001b[0m\u001b[0;1;30mnote: \u001b[0min instantiation of function template specialization 'std::function::function<(lambda at\n", + " input_line_16:43:43), void, void>' requested here\u001b[0m\n", + " std::function noiseFunction = [] () {\n", + "\u001b[0;1;32m ^\n", + "\u001b[0m" ] } ], "source": [ - "\n", - "/**\n", - " * An example of using Convolutional Neural Network (CNN) for\n", - " * solving Digit Recognizer problem from Kaggle website.\n", - " *\n", - " * The full description of a problem as well as datasets for training\n", - " * and testing are available here https://www.kaggle.com/c/digit-recognizer\n", - " *\n", - " * mlpack is free software; you may redistribute it and/or modify it under the\n", - " * terms of the 3-clause BSD license. You should have received a copy of the\n", - " * 3-clause BSD license along with mlpack. 
If not, see\n", - " * http://www.opensource.org/licenses/BSD-3-Clause for more information.\n", - " *\n", - " * @author Daivik Nema\n", - " */\n", - "\n", - "#include \n", - "#include \n", - "\n", - "#include \n", - "#include \n", - "\n", - "#include \n", - "\n", - "#if ((ENS_VERSION_MAJOR < 2) || ((ENS_VERSION_MAJOR == 2) && (ENS_VERSION_MINOR < 13)))\n", - " #error \"need ensmallen version 2.13.0 or later\"\n", - "#endif\n", - "\n", - "using namespace mlpack;\n", - "using namespace mlpack::ann;\n", - "\n", - "using namespace arma;\n", - "using namespace std;\n", - "\n", - "using namespace ens;\n", - "\n", - "arma::Row getLabels(arma::mat predOut)\n", - "{\n", - " arma::Row predLabels(predOut.n_cols);\n", - " for (arma::uword i = 0; i < predOut.n_cols; ++i)\n", - " {\n", - " predLabels(i) = predOut.col(i).index_max();\n", - " }\n", - " return predLabels;\n", - "}\n", - "\n", - "int main()\n", - "{\n", - " // Dataset is randomly split into validation\n", - " // and training parts with following ratio.\n", - " constexpr double RATIO = 0.1;\n", - "\n", - " // Allow infinite number of iterations until we stopped by EarlyStopAtMinLoss\n", - " constexpr int MAX_ITERATIONS = 0;\n", - "\n", - " // Step size of the optimizer.\n", - " constexpr double STEP_SIZE = 1.2e-3;\n", - "\n", - " // Number of data points in each iteration of SGD.\n", - " constexpr int BATCH_SIZE = 50;\n", - "\n", - " cout << \"Reading data ...\" << endl;\n", - "\n", - " // Labeled dataset that contains data for training is loaded from CSV file.\n", - " // Rows represent features, columns represent data points.\n", - " mat dataset;\n", - "\n", - " // The original file can be downloaded from\n", - " // https://www.kaggle.com/c/digit-recognizer/data\n", - " data::Load(\"/home/viole/swaingotnochill/examples/generating_hand_written_digits_mnist_with_gan/digit-recognizer/train.csv\", dataset, true);\n", - "\n", - " // Split the dataset into training and validation sets.\n", - " mat train, valid;\n", - " data::Split(dataset, train, valid, RATIO);\n", - "\n", - " // The train and valid datasets contain both - the features as well as the\n", - " // class labels. Split these into separate mats.\n", - " const mat trainX = train.submat(1, 0, train.n_rows - 1, train.n_cols - 1);\n", - " const mat validX = valid.submat(1, 0, valid.n_rows - 1, valid.n_cols - 1);\n", - "\n", - " // Labels should specify the class of a data point and be in the interval [0,\n", - " // numClasses).\n", - "\n", - " // Create labels for training and validatiion datasets.\n", - " const mat trainY = train.row(0);\n", - " const mat validY = valid.row(0);\n", - "\n", - " // Specify the NN model. NegativeLogLikelihood is the output layer that\n", - " // is used for classification problem. RandomInitialization means that\n", - " // initial weights are generated randomly in the interval from -1 to 1.\n", - " FFN, RandomInitialization> model;\n", - "\n", - " // Specify the model architecture.\n", - " // In this example, the CNN architecture is chosen similar to LeNet-5.\n", - " // The architecture follows a Conv-ReLU-Pool-Conv-ReLU-Pool-Dense schema. We\n", - " // have used leaky ReLU activation instead of vanilla ReLU. Standard\n", - " // max-pooling has been used for pooling. The first convolution uses 6 filters\n", - " // of size 5x5 (and a stride of 1). The second convolution uses 16 filters of\n", - " // size 5x5 (stride = 1). 
The final dense layer is connected to a softmax to\n", - " // ensure that we get a valid probability distribution over the output classes\n", - "\n", - " // Layers schema.\n", - " // 28x28x1 --- conv (6 filters of size 5x5. stride = 1) ---> 24x24x6\n", - " // 24x24x6 --------------- Leaky ReLU ---------------------> 24x24x6\n", - " // 24x24x6 --- max pooling (over 2x2 fields. stride = 2) --> 12x12x6\n", - " // 12x12x6 --- conv (16 filters of size 5x5. stride = 1) --> 8x8x16\n", - " // 8x8x16 --------------- Leaky ReLU ---------------------> 8x8x16\n", - " // 8x8x16 --- max pooling (over 2x2 fields. stride = 2) --> 4x4x16\n", - " // 4x4x16 ------------------- Dense ----------------------> 10\n", - "\n", - " // Add the first convolution layer.\n", - " model.Add>(1, // Number of input activation maps.\n", - " 6, // Number of output activation maps.\n", - " 5, // Filter width.\n", - " 5, // Filter height.\n", - " 1, // Stride along width.\n", - " 1, // Stride along height.\n", - " 0, // Padding width.\n", - " 0, // Padding height.\n", - " 28, // Input width.\n", - " 28 // Input height.\n", - " );\n", - "\n", - " // Add first ReLU.\n", - " model.Add>();\n", - "\n", - " // Add first pooling layer. Pools over 2x2 fields in the input.\n", - " model.Add>(2, // Width of field.\n", - " 2, // Height of field.\n", - " 2, // Stride along width.\n", - " 2, // Stride along height.\n", - " true);\n", - "\n", - " // Add the second convolution layer.\n", - " model.Add>(6, // Number of input activation maps.\n", - " 16, // Number of output activation maps.\n", - " 5, // Filter width.\n", - " 5, // Filter height.\n", - " 1, // Stride along width.\n", - " 1, // Stride along height.\n", - " 0, // Padding width.\n", - " 0, // Padding height.\n", - " 12, // Input width.\n", - " 12 // Input height.\n", - " );\n", - "\n", - " // Add the second ReLU.\n", - " model.Add>();\n", - "\n", - " // Add the second pooling layer.\n", - " model.Add>(2, 2, 2, 2, true);\n", - "\n", - " // Add the final dense layer.\n", - " model.Add>(16 * 4 * 4, 10);\n", - " model.Add>();\n", - "\n", - " cout << \"Start training ...\" << endl;\n", - "\n", - " // Set parameters for the Adam optimizer.\n", - " ens::Adam optimizer(\n", - " STEP_SIZE, // Step size of the optimizer.\n", - " BATCH_SIZE, // Batch size. Number of data points that are used in each\n", - " // iteration.\n", - " 0.9, // Exponential decay rate for the first moment estimates.\n", - " 0.999, // Exponential decay rate for the weighted infinity norm estimates.\n", - " 1e-8, // Value used to initialise the mean squared gradient parameter.\n", - " MAX_ITERATIONS, // Max number of iterations.\n", - " 1e-8, // Tolerance.\n", - " true);\n", - "\n", - " // Train the CNN model. If this is the first iteration, weights are\n", - " // randomly initialized between -1 and 1. 
Otherwise, the values of weights\n", - " // from the previous iteration are used.\n", - " model.Train(trainX,\n", - " trainY,\n", - " optimizer,\n", - " ens::PrintLoss(),\n", - " ens::ProgressBar(),\n", - " // Stop the training using Early Stop at min loss.\n", - " ens::EarlyStopAtMinLoss(\n", - " [&](const arma::mat& /* param */)\n", - " {\n", - " double validationLoss = model.Evaluate(validX, validY);\n", - " std::cout << \"Validation loss: \" << validationLoss\n", - " << \".\" << std::endl;\n", - " return validationLoss;\n", - " }));\n", - "\n", - " // Matrix to store the predictions on train and validation datasets.\n", - " mat predOut;\n", - " // Get predictions on training data points.\n", - " model.Predict(trainX, predOut);\n", - " // Calculate accuracy on training data points.\n", - " arma::Row predLabels = getLabels(predOut);\n", - " double trainAccuracy =\n", - " arma::accu(predLabels == trainY) / (double) trainY.n_elem * 100;\n", - " // Get predictions on validating data points.\n", - " model.Predict(validX, predOut);\n", - " // Calculate accuracy on validating data points.\n", - " predLabels = getLabels(predOut);\n", - " double validAccuracy =\n", - " arma::accu(predLabels == validY) / (double) validY.n_elem * 100;\n", - "\n", - " std::cout << \"Accuracy: train = \" << trainAccuracy << \"%,\"\n", - " << \"\\t valid = \" << validAccuracy << \"%\" << std::endl;\n", - "\n", - " mlpack::data::Save(\"model.bin\", \"model\", model, false);\n", - "\n", - " std::cout << \"Predicting ...\" << std::endl;\n", - "\n", - " // Load test dataset\n", - " // The original file could be download from\n", - " // https://www.kaggle.com/c/digit-recognizer/data\n", - " data::Load(\"/home/viole/swaingotnochill/examples/generating_hand_written_digits_mnist_with_gan/digit-recognizer/test.csv\", dataset, true);\n", - " dataset.shed_row(dataset.n_rows - 1); // Remove labels before predicting.\n", - "\n", - " // Matrix to store the predictions on test dataset.\n", - " mat testPredOut;\n", - " // Get predictions on test data points.\n", - " model.Predict(dataset, testPredOut);\n", - " // Generate labels for the test dataset.\n", - " Row testPred = getLabels(testPredOut);\n", - " std::cout << \"Saving predicted labels to \\\"results.csv.\\\"...\" << std::endl;\n", - "\n", - " // Saving results into Kaggle compatibe CSV file.\n", - " testPred.save(\"results.csv\", arma::csv_ascii);\n", - " std::cout << \"Neural network model is saved to \\\"model.bin\\\"\" << std::endl;\n", - " std::cout << \"Finished\" << std::endl;\n", - "}" + " arma::mat trainData;\n", + " trainData.load(\"mnist_first250_training_4s_and_9s.arm\");\n", + " std::cout << arma::size(trainData) << std::endl;\n", + "\n", + " trainData = trainData.cols(0, datasetMaxCols - 1);\n", + "\n", + " size_t numIterations = trainData.n_cols * numEpoches;\n", + " numIterations /= batchSize;\n", + "\n", + " std::cout << \"Dataset loaded (\" << trainData.n_rows << \", \"\n", + " << trainData.n_cols << \")\" << std::endl;\n", + " std::cout << trainData.n_rows << \"--------\" << trainData.n_cols << std::endl;\n", + "\n", + " // Create the Discriminator network.\n", + " FFN > discriminator;\n", + " discriminator.Add >(1, dNumKernels, 5, 5, 1, 1, 2, 2, 28, 28);\n", + " discriminator.Add >();\n", + " discriminator.Add >(2, 2, 2, 2);\n", + " discriminator.Add >(dNumKernels, 2 * dNumKernels, 5, 5, 1, 1,\n", + " 2, 2, 14, 14);\n", + " discriminator.Add >();\n", + " discriminator.Add >(2, 2, 2, 2);\n", + " discriminator.Add >(7 * 7 * 2 * dNumKernels, 1024);\n", + " 
discriminator.Add >();\n", + " discriminator.Add >(1024, 1);\n", + "\n", + " // Create the Generator network.\n", + " FFN > generator;\n", + " generator.Add >(noiseDim, 3136);\n", + " generator.Add >(3136);\n", + " generator.Add >();\n", + " generator.Add >(1, noiseDim / 2, 3, 3, 2, 2, 1, 1, 56, 56);\n", + " generator.Add >(39200);\n", + " generator.Add >();\n", + " generator.Add >(28, 28, 56, 56, noiseDim / 2);\n", + " generator.Add >(noiseDim / 2, noiseDim / 4, 3, 3, 2, 2, 1, 1,\n", + " 56, 56);\n", + " generator.Add >(19600);\n", + " generator.Add >();\n", + " generator.Add >(28, 28, 56, 56, noiseDim / 4);\n", + " generator.Add >(noiseDim / 4, 1, 3, 3, 2, 2, 1, 1, 56, 56);\n", + " generator.Add >();\n", + "\n", + " // Create GAN.\n", + " GaussianInitialization gaussian(0, 1);\n", + " ens::Adam optimizer(stepSize, batchSize, 0.9, 0.999, eps, numIterations,\n", + " tolerance, shuffle);\n", + " std::function noiseFunction = [] () {\n", + " return math::RandNormal(0, 1);};\n", + " GAN >, GaussianInitialization,\n", + " std::function > gan(generator, discriminator,\n", + " gaussian, noiseFunction, noiseDim, batchSize, generatorUpdateStep,\n", + " discriminatorPreTrain, multiplier);\n", + "\n", + " std::cout << \"Training...\" << std::endl;\n", + " std::stringstream stream;\n", + " double objVal = gan.Train(trainData, optimizer, ens::ProgressBar(70, stream));\n", + "// REQUIRE(stream.str().length() > 0);\n", + "// REQUIRE(std::isfinite(objVal) == true);\n", + "\n", + "// // Generate samples.\n", + "// std::cout << \"Sampling...\" << std::endl;\n", + "// arma::mat noise(noiseDim, batchSize);\n", + "// size_t dim = std::sqrt(trainData.n_rows);\n", + "// arma::mat generatedData(2 * dim, dim * numSamples);\n", + "\n", + "// for (size_t i = 0; i < numSamples; ++i)\n", + "// {\n", + "// arma::mat samples;\n", + "// noise.imbue( [&]() { return noiseFunction(); } );\n", + "\n", + "// gan.Generator().Forward(noise, samples);\n", + "// samples.reshape(dim, dim);\n", + "// samples = samples.t();\n", + "\n", + "// generatedData.submat(0, i * dim, dim - 1, i * dim + dim - 1) = samples;\n", + "\n", + "// samples = trainData.col(math::RandInt(0, trainData.n_cols));\n", + "// samples.reshape(dim, dim);\n", + "// samples = samples.t();\n", + "\n", + "// generatedData.submat(dim,\n", + "// i * dim, 2 * dim - 1, i * dim + dim - 1) = samples;\n", + "// }\n", + "\n", + "// std::cout << \"Output generated!\" << std::endl;\n", + "\n", + "// // Check that Serialization is working correctly.\n", + "// arma::mat orgPredictions;\n", + "// gan.Predict(noise, orgPredictions);\n", + "\n", + "// GAN >, GaussianInitialization,\n", + "// std::function > ganText(generator, discriminator,\n", + "// gaussian, noiseFunction, noiseDim, batchSize, generatorUpdateStep,\n", + "// discriminatorPreTrain, multiplier);\n", + "\n", + "// GAN >, GaussianInitialization,\n", + "// std::function > ganXml(generator, discriminator,\n", + "// gaussian, noiseFunction, noiseDim, batchSize, generatorUpdateStep,\n", + "// discriminatorPreTrain, multiplier);\n", + "\n", + "// GAN >, GaussianInitialization,\n", + "// std::function > ganBinary(generator, discriminator,\n", + "// gaussian, noiseFunction, noiseDim, batchSize, generatorUpdateStep,\n", + "// discriminatorPreTrain, multiplier);\n", + "\n", + "// SerializeObjectAll(gan, ganXml, ganText, ganBinary);\n", + "\n", + "// arma::mat predictions, xmlPredictions, textPredictions, binaryPredictions;\n", + "// gan.Predict(noise, predictions);\n", + "// ganXml.Predict(noise, xmlPredictions);\n", + "// 
ganText.Predict(noise, textPredictions);\n", + "// ganBinary.Predict(noise, binaryPredictions);\n", + "\n", + "// CheckMatrices(orgPredictions, predictions);\n", + "// CheckMatrices(orgPredictions, xmlPredictions);\n", + "// CheckMatrices(orgPredictions, textPredictions);\n", + "// CheckMatrices(orgPredictions, binaryPredictions);" ] }, { "cell_type": "code", "execution_count": null, - "id": "5dbbc367", + "id": "f2af403d", "metadata": {}, "outputs": [], "source": []
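/*
 * Hedged reconstruction of the GAN cell added by this diff. The notebook JSON
 * above lost everything between angle brackets in transit, so the `#include`
 * targets and the layer template arguments (e.g. `discriminator.Add >(...)`)
 * are missing. The sketch below restates that cell as a standalone mlpack 3.x
 * program: the numeric parameters, variable names, and training call are taken
 * from the diff itself, while the header paths and layer/loss type names
 * (Convolution, ReLULayer, MeanPooling, Linear, BatchNorm,
 * BilinearInterpolation, TanHLayer, SigmoidCrossEntropyError) are assumptions
 * filled in from the mlpack 3.x ANN API, not recovered from the original file.
 */
#include <iostream>
#include <functional>

#include <mlpack/core.hpp>
#include <mlpack/methods/ann/ffn.hpp>
#include <mlpack/methods/ann/gan/gan.hpp>
#include <mlpack/methods/ann/layer/layer.hpp>
#include <mlpack/methods/ann/init_rules/gaussian_init.hpp>
#include <mlpack/methods/ann/loss_functions/sigmoid_cross_entropy_error.hpp>

#include <ensmallen.hpp>

using namespace mlpack;
using namespace mlpack::ann;

int main()
{
  // Hyperparameters copied from the notebook cell.
  const size_t dNumKernels = 32, discriminatorPreTrain = 5, batchSize = 5;
  const size_t noiseDim = 100, generatorUpdateStep = 1, numEpoches = 1;
  const double stepSize = 0.0003, eps = 1e-8, tolerance = 1e-5, multiplier = 10;
  const int datasetMaxCols = 10;
  const bool shuffle = true;

  // Load the bundled subset of MNIST 4s and 9s; columns are flattened 28x28
  // images, and only the first datasetMaxCols columns are kept for a quick run.
  arma::mat trainData;
  trainData.load("mnist_first250_training_4s_and_9s.arm");
  trainData = trainData.cols(0, datasetMaxCols - 1);
  const size_t numIterations = trainData.n_cols * numEpoches / batchSize;

  // Discriminator: two convolution + pooling stages, then two linear layers
  // producing a single real/fake score.
  FFN<SigmoidCrossEntropyError<> > discriminator;
  discriminator.Add<Convolution<> >(1, dNumKernels, 5, 5, 1, 1, 2, 2, 28, 28);
  discriminator.Add<ReLULayer<> >();
  discriminator.Add<MeanPooling<> >(2, 2, 2, 2);
  discriminator.Add<Convolution<> >(dNumKernels, 2 * dNumKernels, 5, 5, 1, 1,
      2, 2, 14, 14);
  discriminator.Add<ReLULayer<> >();
  discriminator.Add<MeanPooling<> >(2, 2, 2, 2);
  discriminator.Add<Linear<> >(7 * 7 * 2 * dNumKernels, 1024);
  discriminator.Add<ReLULayer<> >();
  discriminator.Add<Linear<> >(1024, 1);

  // Generator: project the noise vector, then upsample back to 28x28 through
  // convolution + bilinear-interpolation blocks, ending in a tanh output.
  FFN<SigmoidCrossEntropyError<> > generator;
  generator.Add<Linear<> >(noiseDim, 3136);
  generator.Add<BatchNorm<> >(3136);
  generator.Add<ReLULayer<> >();
  generator.Add<Convolution<> >(1, noiseDim / 2, 3, 3, 2, 2, 1, 1, 56, 56);
  generator.Add<BatchNorm<> >(39200);
  generator.Add<ReLULayer<> >();
  generator.Add<BilinearInterpolation<> >(28, 28, 56, 56, noiseDim / 2);
  generator.Add<Convolution<> >(noiseDim / 2, noiseDim / 4, 3, 3, 2, 2, 1, 1,
      56, 56);
  generator.Add<BatchNorm<> >(19600);
  generator.Add<ReLULayer<> >();
  generator.Add<BilinearInterpolation<> >(28, 28, 56, 56, noiseDim / 4);
  generator.Add<Convolution<> >(noiseDim / 4, 1, 3, 3, 2, 2, 1, 1, 56, 56);
  generator.Add<TanHLayer<> >();

  // Wire the two networks into a standard GAN and train with Adam, drawing
  // generator inputs from a standard normal noise function.
  GaussianInitialization gaussian(0, 1);
  ens::Adam optimizer(stepSize, batchSize, 0.9, 0.999, eps, numIterations,
      tolerance, shuffle);
  std::function<double()> noiseFunction = []() { return math::RandNormal(0, 1); };
  GAN<FFN<SigmoidCrossEntropyError<> >, GaussianInitialization,
      std::function<double()> > gan(generator, discriminator, gaussian,
      noiseFunction, noiseDim, batchSize, generatorUpdateStep,
      discriminatorPreTrain, multiplier);

  std::cout << "Training..." << std::endl;
  const double objVal = gan.Train(trainData, optimizer, ens::ProgressBar(70));
  std::cout << "Final objective: " << objVal << std::endl;
}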