diff --git a/generating_hand_written_digits_mnist_with_gan/generating_handwritten_digits_mnist_with_gan.ipynb b/generating_hand_written_digits_mnist_with_gan/generating_handwritten_digits_mnist_with_gan.ipynb new file mode 100644 index 00000000..4516c038 --- /dev/null +++ b/generating_hand_written_digits_mnist_with_gan/generating_handwritten_digits_mnist_with_gan.ipynb @@ -0,0 +1,353 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "c23013fc", + "metadata": {}, + "outputs": [], + "source": [ + "#include\n", + "#include\n", + "#include\n", + "#include\n", + "#include\n", + "#include\n", + "#include" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "ad7ab4c6", + "metadata": {}, + "outputs": [], + "source": [ + "using namespace mlpack;\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "0c8af9e2", + "metadata": {}, + "outputs": [], + "source": [ + "using namespace mlpack::ann;" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "ff1b2e2f", + "metadata": {}, + "outputs": [], + "source": [ + "using namespace arma;" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "4834841f", + "metadata": {}, + "outputs": [], + "source": [ + "using namespace std;" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "d1dc9cb6", + "metadata": {}, + "outputs": [], + "source": [ + "using namespace ens;" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "7ed7a903", + "metadata": {}, + "outputs": [], + "source": [ + "arma::Row getLabels(arma::mat predOut){\n", + " arma::Row predLabels(predOut.n_cols);\n", + " for( arma::uword i = 0; i < predOut.n_cols; ++i){\n", + " predLabels(i) = predOut.col(i).index_max();\n", + " }\n", + " return predLabels;\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "131fe748", + "metadata": {}, + "outputs": [], + "source": [ + "constexpr double RATIO = 0.1;\n", + "constexpr int MAX_ITERATIONS = 0;\n", + "constexpr double STEP_SIZE = 1.2e-3;\n", + "constexpr int BATCH_SIZE = 50;" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "697a8d63", + "metadata": {}, + "outputs": [], + "source": [ + "arma::mat dataset;" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "64ee4e72", + "metadata": {}, + "outputs": [], + "source": [ + "data::Load(\"/home/viole/swaingotnochill/examples/generating_hand_written_digits_mnist_with_gan/digit-recognizer/train.csv\", dataset, true);" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "d8a3d8da", + "metadata": {}, + "outputs": [], + "source": [ + "arma::mat train, valid;" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "da10fdac", + "metadata": {}, + "outputs": [], + "source": [ + "data::Split(dataset, train, valid, RATIO);" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "a6a484e7", + "metadata": {}, + "outputs": [], + "source": [ + "const arma::mat trainX = train.submat(1, 0, train.n_rows - 1, train.n_cols - 1);\n", + "const arma::mat validX = valid.submat(1, 0, valid.n_rows - 1, valid.n_cols - 1);" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "db267de6", + "metadata": {}, + "outputs": [], + "source": [ + "const arma::mat trainY = train.row(0);\n", + "const arma::mat validY = valid.row(0);" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "7eb2907b", + "metadata": {}, + "outputs": [], + "source": [ + "FFN, RandomInitialization> model;" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "ffc66454", + "metadata": {}, + "outputs": [], + "source": [ + "model.Add> (1,6,5,5,1,1,0,0,28,28);\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "03d0646a", + "metadata": {}, + "outputs": [], + "source": [ + "model.Add>();" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "2fff115e", + "metadata": {}, + "outputs": [], + "source": [ + "model.Add>(2,\n", + " 2,\n", + " 2,\n", + " 2,\n", + " true);" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "462e4ce2", + "metadata": {}, + "outputs": [], + "source": [ + "model.Add>(6,\n", + " 16,\n", + " 5,\n", + " 5,\n", + " 1,\n", + " 1,\n", + " 0,\n", + " 0,\n", + " 12,\n", + " 12);" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "73033233", + "metadata": {}, + "outputs": [], + "source": [ + "model.Add>();" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "ee5d88cc", + "metadata": {}, + "outputs": [], + "source": [ + "model.Add>();" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "7859b83e", + "metadata": {}, + "outputs": [], + "source": [ + "model.Add>(2,2,2,2,true);" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "f423475d", + "metadata": {}, + "outputs": [], + "source": [ + "model.Add>(16*4*4, 10);" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "369063ba", + "metadata": {}, + "outputs": [], + "source": [ + "model.Add>();" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "17dfcdef", + "metadata": {}, + "outputs": [], + "source": [ + "ens::Adam optimizer(STEP_SIZE,\n", + " BATCH_SIZE,\n", + " 0.9,\n", + " 0.999,\n", + " 1e-8,\n", + " MAX_ITERATIONS,\n", + " 1e-8,\n", + " true);" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "c52947c7", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "error: Mat::operator(): index out of bounds\n" + ] + }, + { + "ename": "Standard Exception", + "evalue": "Mat::operator(): index out of bounds", + "output_type": "error", + "traceback": [ + "Standard Exception: Mat::operator(): index out of bounds" + ] + } + ], + "source": [ + "model.Train(trainX,\n", + " trainY,\n", + " optimizer,\n", + " ens::PrintLoss(),\n", + " ens::ProgressBar(),\n", + " ens::EarlyStopAtMinLoss(\n", + " [&](const arma::mat&){\n", + " double validationLoss = model.Evaluate(validX, validY);\n", + " std::cout << \"Validation Loss\" << validationLoss << \".\" << \"\\n\";\n", + " return validationLoss;\n", + " }))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d6060587", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "C++14", + "language": "C++14", + "name": "xcpp14" + }, + "language_info": { + "codemirror_mode": "text/x-c++src", + "file_extension": ".cpp", + "mimetype": "text/x-c++src", + "name": "c++", + "version": "14" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/generating_hand_written_digits_mnist_with_gan/mnist_first250_training_4s_and_9s.arm b/generating_hand_written_digits_mnist_with_gan/mnist_first250_training_4s_and_9s.arm new file mode 100644 index 00000000..fd28bbbf Binary files /dev/null and b/generating_hand_written_digits_mnist_with_gan/mnist_first250_training_4s_and_9s.arm differ diff --git a/generating_hand_written_digits_mnist_with_gan/testNotebook.ipynb b/generating_hand_written_digits_mnist_with_gan/testNotebook.ipynb new file mode 100644 index 00000000..067de33d --- /dev/null +++ b/generating_hand_written_digits_mnist_with_gan/testNotebook.ipynb @@ -0,0 +1,317 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "4ffd91b1", + "metadata": {}, + "outputs": [], + "source": [ + "#include\n", + "#include \n", + "\n", + "#include \n", + "#include \n", + "#include \n", + "#include \n", + "#include \n", + "#include\n", + "\n", + "#include \n", + "\n", + "// #include \"catch.hpp\"\n", + "// #include \"test_catch_tools.hpp\"\n", + "// #include \"serialization.hpp\"\n", + "\n", + "// using namespace mlpack;\n", + "// using namespace mlpack::ann;\n", + "// using namespace mlpack::math;\n", + "// using namespace mlpack::regression;\n", + "// using namespace std::placeholders;\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "d2721956", + "metadata": {}, + "outputs": [], + "source": [ + "using namespace mlpack::ann;" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "6a336fdb", + "metadata": {}, + "outputs": [], + "source": [ + "using namespace mlpack;" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "8f61e83e", + "metadata": {}, + "outputs": [], + "source": [ + "using namespace mlpack::math;" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "fec32cd8", + "metadata": {}, + "outputs": [], + "source": [ + "using namespace mlpack::regression;" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "d28df478", + "metadata": {}, + "outputs": [], + "source": [ + "using namespace std::placeholders;" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "5dbbc367", + "metadata": {}, + "outputs": [], + "source": [ + " size_t dNumKernels = 32;\n", + " size_t discriminatorPreTrain = 5;\n", + " size_t batchSize = 5;\n", + " size_t noiseDim = 100;\n", + " size_t generatorUpdateStep = 1;\n", + " size_t numSamples = 10;\n", + " double stepSize = 0.0003;\n", + " double eps = 1e-8;\n", + " size_t numEpoches = 1;\n", + " double tolerance = 1e-5;\n", + " int datasetMaxCols = 10;\n", + " bool shuffle = true;\n", + " double multiplier = 10;" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "7f3190fc", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " batchSize = 5\n", + " generatorUpdateStep = 1\n", + " noiseDim = 100\n", + " numSamples = 10\n", + " stepSize = 0.0003\n", + " numEpoches = 1\n", + " tolerance = 1e-05\n", + " shuffle = true\n" + ] + } + ], + "source": [ + "\n", + " std::cout << std::boolalpha\n", + " << \" batchSize = \" << batchSize << std::endl\n", + " << \" generatorUpdateStep = \" << generatorUpdateStep << std::endl\n", + " << \" noiseDim = \" << noiseDim << std::endl\n", + " << \" numSamples = \" << numSamples << std::endl\n", + " << \" stepSize = \" << stepSize << std::endl\n", + " << \" numEpoches = \" << numEpoches << std::endl\n", + " << \" tolerance = \" << tolerance << std::endl\n", + " << \" shuffle = \" << shuffle << std::endl;\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4dbbb38d", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "In file included from input_line_5:1:\n", + "In file included from /home/viole/anaconda3/envs/notebook/include/xeus/xinterpreter.hpp:13:\n", + "In file included from /home/viole/anaconda3/envs/notebook/bin/../lib/gcc/x86_64-conda-linux-gnu/7.5.0/../../../../x86_64-conda-linux-gnu/include/c++/7.5.0/functional:58:\n", + "\u001b[1m/home/viole/anaconda3/envs/notebook/bin/../lib/gcc/x86_64-conda-linux-gnu/7.5.0/../../../../x86_64-conda-linux-gnu/include/c++/7.5.0/bits/std_function.h:264:34: \u001b[0m\u001b[0;1;31merror: \u001b[0m\u001b[1mno matching constructor for initialization of '(lambda at input_line_16:43:43)'\u001b[0m\n", + " { ::new (__functor._M_access()) _Functor(std::move(__f)); }\n", + "\u001b[0;1;32m ^ ~~~~~~~~~~~~~~\n", + "\u001b[0m\u001b[1m/home/viole/anaconda3/envs/notebook/bin/../lib/gcc/x86_64-conda-linux-gnu/7.5.0/../../../../x86_64-conda-linux-gnu/include/c++/7.5.0/bits/std_function.h:239:4: \u001b[0m\u001b[0;1;30mnote: \u001b[0min instantiation of member function 'std::_Function_base::_Base_manager<(lambda at\n", + " input_line_16:43:43)>::_M_init_functor' requested here\u001b[0m\n", + " { _M_init_functor(__functor, std::move(__f), _Local_storage()); }\n", + "\u001b[0;1;32m ^\n", + "\u001b[0m\u001b[1m/home/viole/anaconda3/envs/notebook/bin/../lib/gcc/x86_64-conda-linux-gnu/7.5.0/../../../../x86_64-conda-linux-gnu/include/c++/7.5.0/bits/std_function.h:693:19: \u001b[0m\u001b[0;1;30mnote: \u001b[0min instantiation of member function 'std::_Function_base::_Base_manager<(lambda at\n", + " input_line_16:43:43)>::_M_init_functor' requested here\u001b[0m\n", + " _My_handler::_M_init_functor(_M_functor, std::move(__f));\n", + "\u001b[0;1;32m ^\n", + "\u001b[0m\u001b[1minput_line_16:43:43: \u001b[0m\u001b[0;1;30mnote: \u001b[0min instantiation of function template specialization 'std::function::function<(lambda at\n", + " input_line_16:43:43), void, void>' requested here\u001b[0m\n", + " std::function noiseFunction = [] () {\n", + "\u001b[0;1;32m ^\n", + "\u001b[0m" + ] + } + ], + "source": [ + " arma::mat trainData;\n", + " trainData.load(\"mnist_first250_training_4s_and_9s.arm\");\n", + " std::cout << arma::size(trainData) << std::endl;\n", + "\n", + " trainData = trainData.cols(0, datasetMaxCols - 1);\n", + "\n", + " size_t numIterations = trainData.n_cols * numEpoches;\n", + " numIterations /= batchSize;\n", + "\n", + " std::cout << \"Dataset loaded (\" << trainData.n_rows << \", \"\n", + " << trainData.n_cols << \")\" << std::endl;\n", + " std::cout << trainData.n_rows << \"--------\" << trainData.n_cols << std::endl;\n", + "\n", + " // Create the Discriminator network.\n", + " FFN > discriminator;\n", + " discriminator.Add >(1, dNumKernels, 5, 5, 1, 1, 2, 2, 28, 28);\n", + " discriminator.Add >();\n", + " discriminator.Add >(2, 2, 2, 2);\n", + " discriminator.Add >(dNumKernels, 2 * dNumKernels, 5, 5, 1, 1,\n", + " 2, 2, 14, 14);\n", + " discriminator.Add >();\n", + " discriminator.Add >(2, 2, 2, 2);\n", + " discriminator.Add >(7 * 7 * 2 * dNumKernels, 1024);\n", + " discriminator.Add >();\n", + " discriminator.Add >(1024, 1);\n", + "\n", + " // Create the Generator network.\n", + " FFN > generator;\n", + " generator.Add >(noiseDim, 3136);\n", + " generator.Add >(3136);\n", + " generator.Add >();\n", + " generator.Add >(1, noiseDim / 2, 3, 3, 2, 2, 1, 1, 56, 56);\n", + " generator.Add >(39200);\n", + " generator.Add >();\n", + " generator.Add >(28, 28, 56, 56, noiseDim / 2);\n", + " generator.Add >(noiseDim / 2, noiseDim / 4, 3, 3, 2, 2, 1, 1,\n", + " 56, 56);\n", + " generator.Add >(19600);\n", + " generator.Add >();\n", + " generator.Add >(28, 28, 56, 56, noiseDim / 4);\n", + " generator.Add >(noiseDim / 4, 1, 3, 3, 2, 2, 1, 1, 56, 56);\n", + " generator.Add >();\n", + "\n", + " // Create GAN.\n", + " GaussianInitialization gaussian(0, 1);\n", + " ens::Adam optimizer(stepSize, batchSize, 0.9, 0.999, eps, numIterations,\n", + " tolerance, shuffle);\n", + " std::function noiseFunction = [] () {\n", + " return math::RandNormal(0, 1);};\n", + " GAN >, GaussianInitialization,\n", + " std::function > gan(generator, discriminator,\n", + " gaussian, noiseFunction, noiseDim, batchSize, generatorUpdateStep,\n", + " discriminatorPreTrain, multiplier);\n", + "\n", + " std::cout << \"Training...\" << std::endl;\n", + " std::stringstream stream;\n", + " double objVal = gan.Train(trainData, optimizer, ens::ProgressBar(70, stream));\n", + "// REQUIRE(stream.str().length() > 0);\n", + "// REQUIRE(std::isfinite(objVal) == true);\n", + "\n", + "// // Generate samples.\n", + "// std::cout << \"Sampling...\" << std::endl;\n", + "// arma::mat noise(noiseDim, batchSize);\n", + "// size_t dim = std::sqrt(trainData.n_rows);\n", + "// arma::mat generatedData(2 * dim, dim * numSamples);\n", + "\n", + "// for (size_t i = 0; i < numSamples; ++i)\n", + "// {\n", + "// arma::mat samples;\n", + "// noise.imbue( [&]() { return noiseFunction(); } );\n", + "\n", + "// gan.Generator().Forward(noise, samples);\n", + "// samples.reshape(dim, dim);\n", + "// samples = samples.t();\n", + "\n", + "// generatedData.submat(0, i * dim, dim - 1, i * dim + dim - 1) = samples;\n", + "\n", + "// samples = trainData.col(math::RandInt(0, trainData.n_cols));\n", + "// samples.reshape(dim, dim);\n", + "// samples = samples.t();\n", + "\n", + "// generatedData.submat(dim,\n", + "// i * dim, 2 * dim - 1, i * dim + dim - 1) = samples;\n", + "// }\n", + "\n", + "// std::cout << \"Output generated!\" << std::endl;\n", + "\n", + "// // Check that Serialization is working correctly.\n", + "// arma::mat orgPredictions;\n", + "// gan.Predict(noise, orgPredictions);\n", + "\n", + "// GAN >, GaussianInitialization,\n", + "// std::function > ganText(generator, discriminator,\n", + "// gaussian, noiseFunction, noiseDim, batchSize, generatorUpdateStep,\n", + "// discriminatorPreTrain, multiplier);\n", + "\n", + "// GAN >, GaussianInitialization,\n", + "// std::function > ganXml(generator, discriminator,\n", + "// gaussian, noiseFunction, noiseDim, batchSize, generatorUpdateStep,\n", + "// discriminatorPreTrain, multiplier);\n", + "\n", + "// GAN >, GaussianInitialization,\n", + "// std::function > ganBinary(generator, discriminator,\n", + "// gaussian, noiseFunction, noiseDim, batchSize, generatorUpdateStep,\n", + "// discriminatorPreTrain, multiplier);\n", + "\n", + "// SerializeObjectAll(gan, ganXml, ganText, ganBinary);\n", + "\n", + "// arma::mat predictions, xmlPredictions, textPredictions, binaryPredictions;\n", + "// gan.Predict(noise, predictions);\n", + "// ganXml.Predict(noise, xmlPredictions);\n", + "// ganText.Predict(noise, textPredictions);\n", + "// ganBinary.Predict(noise, binaryPredictions);\n", + "\n", + "// CheckMatrices(orgPredictions, predictions);\n", + "// CheckMatrices(orgPredictions, xmlPredictions);\n", + "// CheckMatrices(orgPredictions, textPredictions);\n", + "// CheckMatrices(orgPredictions, binaryPredictions);" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f2af403d", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "C++14", + "language": "C++14", + "name": "xcpp14" + }, + "language_info": { + "codemirror_mode": "text/x-c++src", + "file_extension": ".cpp", + "mimetype": "text/x-c++src", + "name": "c++", + "version": "14" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}