diff --git a/mnist_gan/Makefile b/mnist_gan/Makefile new file mode 100644 index 00000000..a886421a --- /dev/null +++ b/mnist_gan/Makefile @@ -0,0 +1,36 @@ +TARGET := mnist_gan_generate +SRC := mnist_gan_generate.cpp +LIBS_NAME := armadillo mlpack + +CXX := g++ +CXXFLAGS += -std=c++11 -Wall -Wextra -O3 -DNDEBUG +# Use these CXXFLAGS instead if you want to compile with debugging symbols and +# without optimizations. +# CXXFLAGS += -std=c++11 -Wall -Wextra -g -O0 +LDFLAGS += -fopenmp +LDFLAGS += -lboost_serialization +LDFLAGS += -larmadillo +LDFLAGS += -L. # /path to mlpack library if installed locally. +# path: mlpack/build/lib. +# Add header directories for any includes that aren't on the +# default compiler search path. +INCLFLAGS := -I. +CXXFLAGS += $(INCLFLAGS) + +OBJS := $(SRC:.cpp=.o) +LIBS := $(addprefix -l,$(LIBS_NAME)) +CLEAN_LIST := $(TARGET) $(OBJS) + +# default rule +default: all + +$(TARGET): $(OBJS) + $(CXX) $(CXXFLAGS) $(OBJS) -o $(TARGET) $(LDFLAGS) $(LIBS) + +.PHONY: all +all: $(TARGET) + +.PHONY: clean +clean: + @echo CLEAN $(CLEAN_LIST) + @rm -f $(CLEAN_LIST) \ No newline at end of file diff --git a/mnist_gan/dataset/mnist_first250_training_4s_and_9s.arm b/mnist_gan/dataset/mnist_first250_training_4s_and_9s.arm new file mode 100644 index 00000000..fd28bbbf Binary files /dev/null and b/mnist_gan/dataset/mnist_first250_training_4s_and_9s.arm differ diff --git a/mnist_gan/mnist_gan.cpp b/mnist_gan/mnist_gan.cpp new file mode 100644 index 00000000..47f8390f --- /dev/null +++ b/mnist_gan/mnist_gan.cpp @@ -0,0 +1,213 @@ +#include +#include + +#include +#include +#include +#include +#include + +#include + +using namespace mlpack; +using namespace mlpack::data; +using namespace mlpack::ann; +using namespace mlpack::math; +using namespace mlpack::regression; +using namespace std::placeholders; + + +int main() +{ + size_t dNumKernels = 32; + size_t discriminatorPreTrain = 5; + size_t batchSize = 5; + size_t noiseDim = 100; + size_t generatorUpdateStep = 1; + size_t numSamples = 10; + size_t cycles = 10; + size_t numEpoches = 25; + double stepSize = 0.0003; + double trainRatio = 0.8; + double eps = 1e-8; + double tolerance = 1e-5; + bool shuffle = true; + double multiplier = 10; + int datasetMaxCols = 10; + + std::cout << std::boolalpha + << " batchSize = " << batchSize << std::endl + << " generatorUpdateStep = " << generatorUpdateStep << std::endl + << " noiseDim = " << noiseDim << std::endl + << " numSamples = " << numSamples << std::endl + << " stepSize = " << stepSize << std::endl + << " numEpochs = " << numEpoches << std::endl + << " shuffle = " << shuffle << std::endl; + + arma::mat mnistDataset; + mnistDataset.load("./dataset/mnist_first250_training_4s_and_9s.arm"); + + std::cout << "Dataset Shape: " << (mnistDataset.n_rows, mnistDataset.n_cols) << std::endl; + std::cout << arma::size(mnistDataset) << std::endl; + + mnistDataset = mnistDataset.cols(0, datasetMaxCols-1); + size_t numIterations = mnistDataset.n_cols * numEpoches; + numIterations /= batchSize; + + std::cout << "MnistDataset No. of rows: " << mnistDataset.n_rows << std::endl; + + /** + * @brief Model Architecture: + * + * Discriminator: + * 28x28x1-----------> conv (32 filters of size 5x5, + * stride = 1, padding = 2)----------> 28x28x32 + * 28x28x32----------> ReLU -----------------------------> 28x28x32 + * 28x28x32----------> Mean pooling ---------------------> 14x14x32 + * 14x14x32----------> conv (64 filters of size 5x5, + * stride = 1, padding = 2)------> 14x14x64 + * 14x14x64----------> ReLU -----------------------------> 14x14x64 + * 14x14x64----------> Mean pooling ---------------------> 7x7x64 + * 7x7x64------------> Linear Layer ---------------------> 1024 + * 1024--------------> ReLU -----------------------------> 1024 + * 1024 -------------> Linear ---------------------------> 1 + * + * + * Generator: + * noiseDim---------> Linear ---------------------------> 3136 + * 3136 ------------> BatchNormalizaton ----------------> 3136 + * 3136 ------------> ReLu Layer -----------------------> 3136 + * 56x56x1 ---------> conv(1 filter of size 3x3, + * stride = 2, padding = 1)----> 28x28x(noiseDim/2) + * 28x28x(noiseDim/2)----> BatchNormalizaton -----------> 28x28x(noiseDim/2) + * 28x28x(noiseDim/2)----> ReLu Layer-------------------> 28x28x(noiseDim/2) + * 28x28x(noiseDim/2) ----> BilinearInterpolation ------> 56x56x(noiseDim/2) + * 56x56x(noiseDim/2) -----> conv((noiseDim/2) filters + * of size 3x3,stride = 2, + * padding = 1)----------> 28x28x(noiseDim/4) + * 28x28x(noiseDim/4) ----->BatchNormalization----------> 28x28x(noiseDim/4) + * 28x28x(noiseDim/4) ------> ReLu Layer ---------------> 28x28x(noiseDim/4) + * 28x28x(noiseDim/4) ------> BilinearInterpolation ----> 56x56x(noiseDim/4) + * 56x56x(noiseDim/4) ------> conv((noiseDim/4) filters + * of size 3x3, stride = 2, + * padding = 1)-------> 28x28x1 + * 28x28x1 ----------> tanh layer ----------------------> 28x28x1 + * + * + * Note: Output of a Convolution layer = [(W-K+2P)/S + 1] + * where, W : Size of input volume + * K : Kernel size + * P : Padding + * S : Stride + */ + + // Creating the Discriminator network. + FFN > discriminator; + discriminator.Add >(1, // Number of input activation maps + dNumKernels, // Number of output activation maps + 5, // Filter width + 5, // Filter height + 1, // Stride along width + 1, // Stride along height + 2, // Padding width + 2, // Padding height + 28, // Input widht + 28); // Input height + // Adding first ReLU. + discriminator.Add >(); + // Adding mean pooling layer. + discriminator.Add >(2, 2, 2, 2); + // Adding second convolution layer. + discriminator.Add >(dNumKernels, 2 * dNumKernels, 5, 5, 1, 1, + 2, 2, 14, 14); + // Adding second ReLU. + discriminator.Add >(); + // Adding second mean pooling layer. + discriminator.Add >(2, 2, 2, 2); + // Adding linear layer. + discriminator.Add >(7 * 7 * 2 * dNumKernels, 1024); + // Adding third ReLU. + discriminator.Add >(); + // Adding final layer. + discriminator.Add >(1024, 1); + + // Creating the Generator network. + FFN > generator; + generator.Add >(noiseDim, 3136); + generator.Add >(3136); + generator.Add >(); + generator.Add >(1, // Number of input activation maps. + noiseDim / 2, // Number of output activation maps. + 3, // Filter width. + 3, // Filter height. + 2, // Stride along width. + 2, // Stride along height. + 1, // Padding width. + 1, // Padding height. + 56, // input width. + 56); // input height. + // Adding first batch normalization layer. + generator.Add >(39200); + // Adding first ReLU. + generator.Add >(); + // Adding a bilinear interpolation layer. + generator.Add >(28, 28, 56, 56, noiseDim / 2); + // Adding second convolution layer. + generator.Add >(noiseDim / 2, noiseDim / 4, 3, 3, 2, 2, 1, 1, + 56, 56); + // Adding second batch normalization layer. + generator.Add >(19600); + // Adding second ReLU. + generator.Add >(); + // Adding second bilinear interpolation layer. + generator.Add >(28, 28, 56, 56, noiseDim / 4); + // Adding third convolution layer. + generator.Add >(noiseDim / 4, 1, 3, 3, 2, 2, 1, 1, 56, 56); + // Adding final tanh layer. + generator.Add >(); + + // Creating GAN. + GaussianInitialization gaussian(0, 1); + ens::Adam optimizer(stepSize, // Step size of optimizer. + batchSize, // Batch size. + 0.9, // Exponential decay rate for first moment estimates. + 0.999, // Exponential decay rate for weighted norm estimates. + eps, // Value used to initialize the mean squared gradient parameter. + numIterations, // iterPerCycle// Maximum number of iterations. + tolerance, // Tolerance. + shuffle); // Shuffle. + std::function noiseFunction = []() { + return math::RandNormal(0, 1);}; + GAN >, GaussianInitialization, + std::function > gan(generator, discriminator, + gaussian, noiseFunction, noiseDim, batchSize, generatorUpdateStep, + discriminatorPreTrain, multiplier); + + std::cout << "Training ... " << std::endl; + + const clock_t beginTime = clock(); + // Cycles for monitoring training progress. + for( size_t i = 0; i < cycles; i++) + { + // Training the neural network. For first iteration, weights are random, + // thus using current values as starting point. + gan.Train(mnistDataset, //trainDataset. + optimizer, + ens::PrintLoss(), + ens::ProgressBar(), + ens::Report()); + + optimizer.ResetPolicy() = false; + std::cout << " Model Performance " << + gan.Evaluate(gan.Parameters(), // Parameters of the network. + i, // Index of current input. + batchSize); // Batch size. + } + + std::cout << " Time taken to train -> " << float(clock()-beginTime) / CLOCKS_PER_SEC << "seconds" << std::endl; + + // Let's save the model. + data::Save("./saved_models/ganMnist_25epochs.bin", "ganMnist", gan); + std::cout << "Model saved in mnist_gan/saved_models." << std::endl; + std::cout << "\n"; +} diff --git a/mnist_gan/mnist_gan_generate.cpp b/mnist_gan/mnist_gan_generate.cpp new file mode 100644 index 00000000..358d9975 --- /dev/null +++ b/mnist_gan/mnist_gan_generate.cpp @@ -0,0 +1,116 @@ +#include + +#include +#include + + +#include +#include +#include +#include +#include + +#include + +using namespace mlpack; +using namespace mlpack::ann; + +int main() +{ + size_t discriminatorPreTrain = 5; + size_t batchSize = 5; + size_t noiseDim = 100; + size_t generatorUpdateStep = 1; + size_t numSamples = 10; + double multiplier = 10; + bool loadData = false; + + arma::mat trainData,inputData, validData; + trainData.load("./dataset/mnist_first250_training_4s_and_9s.arm"); + + // If you want to load other mnist data, then uncomment the below lines in the "if" statement to remove and prepare the data for your test. + // if(loadData) + // { + + // inputData.load("File Path"); + + // // Removing the headers. + // inputData = inputData.submat(0, 1, inputData.n_rows - 1, inputData.n_cols - 1); + // inputData /= 255.0; // Note that if you are bringing all the values to 0-1, then in the output csv, you have to multiply all values by 255.0 + + // // Removing the labels. + // inputData = inputData.submat(1, 0, inputData.n_rows - 1, inputData.n_cols - 1); + + // inputData = (inputData - 0.5) * 2; + + // data::Split(inputData, trainData, validData, 0.8); + // } + + arma::arma_rng::set_seed_random(); + + // Define noise function. + std::function noiseFunction = [](){ return math::Random(-8, 8) + + math::RandNormal(0, 1) * 0.01;}; + + // Define generator. + FFN > generator; + + // Define discriminator. + FFN > discriminator; + + // Define GaussinaInitialization. + GaussianInitialization gaussian(0,1); + + // Define GAN class. + GAN >, GaussianInitialization, + std::function > gan(generator, discriminator, + gaussian, noiseFunction, noiseDim, batchSize, generatorUpdateStep, + discriminatorPreTrain, multiplier); + + // Load the saved model. + data::Load("./saved_models/ganMnist_25epochs.bin", "ganMnist", gan); + + /*--------------Sampling-----------------------------------------*/ + + std::cout << "Sampling...." << std::endl; + + // Noise matrix. + arma::mat noise(noiseDim, batchSize); + + // Dimensions of the image. + size_t dim = std::sqrt(trainData.n_rows); + + // Matrix to store the generated data. + arma::mat generatedData(2 * dim, dim * numSamples); + + + for (size_t i = 0; i < numSamples; ++i) + { + arma::mat samples; + + // Create random noise using noise function. + noise.imbue([&]() { return noiseFunction(); }); + + // Pass noise through generator and store output in samples. + gan.Generator().Forward(noise, samples); + + // Reshape and Transpose the samples output. + samples.reshape(dim, dim); + samples = samples.t(); + + // Store the output sample in a dimxdim grid in final output matrix. + generatedData.submat(0, i * dim, dim - 1, i * dim + dim - 1) = samples; + + // Add the image from original train data to compare. + samples = trainData.col(math::RandInt(0, trainData.n_cols)); + samples.reshape(dim, dim); + samples = samples.t(); + generatedData.submat(dim, + i * dim, 2 * dim - 1, i * dim + dim - 1) = samples; + } + // Save the output as csv. + data::Save("./samples_csv_files/sample.csv", generatedData, false, false); + + std::cout << "Output generated!" << std::endl; + std::cout << "\n"; +} diff --git a/mnist_gan/mnist_gan_notebook.ipynb b/mnist_gan/mnist_gan_notebook.ipynb new file mode 100644 index 00000000..dfd66a4e --- /dev/null +++ b/mnist_gan/mnist_gan_notebook.ipynb @@ -0,0 +1,175 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "9e5d496b", + "metadata": {}, + "source": [ + "# Generative Adverserial Networks in mlpack\n", + "\n", + "Deep neural networks are used mainly for supervised learning: classification or regression. Generative Adverserial Networks or GANs, however, use neural networks for a very different purpose: Generative modeling\n", + "\n", + ">A generative adversarial network is a class of machine learning frameworks designed by Ian Goodfellow and his colleagues in 2014. Two neural networks contest with each other in a game. Given a training set, this technique learns to generate new data with the same statistics as the training set. [Wikipedia](https://www.google.com/url?sa=t&rct=j&q=&esrc=s&source=web&cd=&cad=rja&uact=8&ved=2ahUKEwj0oN6n3LXyAhVYIbcAHROSDhQQmhMwKnoECEAQAg&url=https%3A%2F%2Fen.wikipedia.org%2Fwiki%2FGenerative_adversarial_network&usg=AOvVaw09mLp9-GvBO7o23-xIW29P/)\n", + "\n", + "While there are many approaches used for generative modeling, a Generative Adverserial Network takes the following approach: \n", + "\n", + "![GAN Flowchart](https://i.imgur.com/6NMdO9u.png)\n", + "\n", + "There are two neural networks: a *Generator* and a *Discriminator*. The generator generates a \"fake\" sample given a random vector/matrix, and the discriminator attempts to detect whether a given sample is \"real\" (picked from the training data) or \"fake\" (generated by the generator). Training happens in tandem: we train the discriminator for a few epochs, then train the generator for a few epochs, and repeat. This way both the generator and the discriminator get better at doing their jobs. This rather simple approach can lead to some astounding results. The following images ([source](https://machinelearningmastery.com/resources-for-getting-started-with-generative-adversarial-networks/)), for instances, were all generated using GANs:\n", + "\n", + "\"gans_results\"\n", + "\n", + "\n", + "GANs however, can be notoriously difficult to train, and are extremely sensitive to hyperparameters, activation functions and regularization. In this tutorial, we'll train a GAN to generate images of handwritten digits similar to those from the MNIST database.\n", + "\n", + "" + ] + }, + { + "cell_type": "markdown", + "id": "19351cd6", + "metadata": {}, + "source": [ + "# Approach\n", + "Here's what we're going to do:\n", + "\n", + "* Define the problem statement.\n", + "* Create a GAN Training File\n", + " * Load the data.\n", + " * Define the Discriminator network.\n", + " * Define the Generator network.\n", + " * Gaussian Initialization.\n", + " * Define the Noise Function.\n", + " * Create GAN.\n", + " * Train GAN.\n", + "* Create a Image Generator File.\n", + " * Load the trained model.\n", + " * Create random noise.\n", + " * Generate Images using the trained model.\n", + " * Save the outputs in CSV.\n", + "* Look at outputs\n", + " * Convert CSV to Images.\n", + " * If not satisfied with result, train the model with different parameters." + ] + }, + { + "cell_type": "markdown", + "id": "212e3471", + "metadata": {}, + "source": [ + "# Data\n", + "For the datasets, we will be using a small subset of MNIST dataset comprising of only 4's and 9's." + ] + }, + { + "cell_type": "markdown", + "id": "f69e2067", + "metadata": {}, + "source": [ + "# Training File\n", + "Let's create the training file 'mnist_gan.cpp', train the model and save it." + ] + }, + { + "cell_type": "markdown", + "id": "7ff71009", + "metadata": {}, + "source": [ + "# Generate Output CSV\n", + "Let's also create the 'mnist_generate.cpp' file for sampling, generating the outputs from the model and saving the CSV file." + ] + }, + { + "cell_type": "markdown", + "id": "c333fc3f", + "metadata": {}, + "source": [ + "Let's look at the outputs." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "ba6ad46d", + "metadata": {}, + "outputs": [], + "source": [ + "#include \n", + "\n", + "#include \"xwidgets/ximage.hpp\"" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "2dcf56a4", + "metadata": {}, + "outputs": [], + "source": [ + "// Import the generate image script from utils folder.\n", + "#include \"../utils/generateimage.hpp\"" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "04fc2fd6", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "e7cb5d5284844723975c2e993a1f2ce1", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "A Jupyter widget with unique id: e7cb5d5284844723975c2e993a1f2ce1" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "// Let's generate image from csv files.\n", + "GenerateImage(\"./samples_csv_files/sample.csv\", \"./samples_posterior/sample.png\");\n", + "auto im = xw::image_from_file(\"./samples_posterior/sample.png\").finalize();\n", + "im" + ] + }, + { + "cell_type": "markdown", + "id": "23b323fd", + "metadata": {}, + "source": [ + "Thus we can see how to leverage mlpack for unsupervised learning, especially GAN's in this example." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1ab57159", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "C++14", + "language": "C++14", + "name": "xcpp14" + }, + "language_info": { + "codemirror_mode": "text/x-c++src", + "file_extension": ".cpp", + "mimetype": "text/x-c++src", + "name": "c++", + "version": "14" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/mnist_gan/samples_posterior/sample.png b/mnist_gan/samples_posterior/sample.png new file mode 100644 index 00000000..4f7e4328 Binary files /dev/null and b/mnist_gan/samples_posterior/sample.png differ diff --git a/utils/generateimage.hpp b/utils/generateimage.hpp new file mode 100644 index 00000000..0c20ba1e --- /dev/null +++ b/utils/generateimage.hpp @@ -0,0 +1,95 @@ +// Inside C++ notebook we can use: +// GenerateImage("input.csv", "output.png") +// auto im = xw::image_from_file("output.png").finalize() +// im + +#ifndef C_GENERATE_IMAGE_HPP +#define C_GENERATE_IMAGE_HPP + +#define PY_SSIZE_T_CLEAN +#include +#include + +// Here we use the same arguments as we used in the python script, +// since this is what is passed from the C++ notebook to call the python script. +int GenerateImage(const std::string& inFile, + const std::string& outFile = "output.png") +{ + PyObject *pName, *pModule, *pFunc; + PyObject *pArgs, *pValue; + + // This has to be adapted if you run this on your local system, + // so whenever you call the python script it can find the correct + // module -> PYTHONPATH, on lab.mlpack.org we put all the utility + // functions for plotting uinto the utils folder so we add that path + // to the Python search path. + + Py_Initialize(); + PyRun_SimpleString("import sys"); + PyRun_SimpleString("sys.path.append(\"../utils/\")"); + // Name of the python script without the extension. + pName = PyUnicode_DecodeFSDefault("generateimage"); + + pModule = PyImport_Import(pName); + Py_DECREF(pName); + + if (pModule != NULL) + { + // The Python function from the generateimage.py script + // we like to call - cgenerateimage + pFunc = PyObject_GetAttrString(pModule, "cgenerateimage"); + + if (pFunc && PyCallable_Check(pFunc)) + { + // The number of arguments we pass to the python script. + // inFile, outFile='output.png' + // for the example above it's 2 + pArgs = PyTuple_New(2); + + // Now we have to encode the argument to the correct type + // besides width, height everything else is a string. + // So we can use PyUnicode_FromString. + // If the data is an int we can use PyLong_FromLong, + // see the lines below for an example. + PyObject* pValueinFile = PyUnicode_FromString(inFile.c_str()); + // Here we just set the index of the argument. + PyTuple_SetItem(pArgs, 0, pValueinFile); + + PyObject* pValueoutFile = PyUnicode_FromString(outFile.c_str()); + PyTuple_SetItem(pArgs, 1, pValueoutFile); + + // The rest of the c++ part can stay the same. + + pValue = PyObject_CallObject(pFunc, pArgs); + Py_DECREF(pArgs); + if (pValue != NULL) + { + Py_DECREF(pValue); + } + else + { + Py_DECREF(pFunc); + Py_DECREF(pModule); + PyErr_Print(); + fprintf(stderr,"Call failed.\n"); + return 1; + } + } + else + { + if (PyErr_Occurred()) + PyErr_Print(); + } + + Py_XDECREF(pFunc); + Py_DECREF(pModule); + } + else + { + PyErr_Print(); + return 1; + } + + return 0; +} +#endif \ No newline at end of file diff --git a/utils/generateimage.py b/utils/generateimage.py new file mode 100644 index 00000000..0ca38f64 --- /dev/null +++ b/utils/generateimage.py @@ -0,0 +1,18 @@ +""" +@file generate_images.py +@author Roshan Swain +Generates jpg files from csv. +mlpack is free software; you may redistribute it and/or modify it under the +terms of the 3-clause BSD license. You should have received a copy of the +3-clause BSD license along with mlpack. If not, see +http://www.opensource.org/licenses/BSD-3-Clause for more information. +""" + +from PIL import Image +import numpy as np +import matplotlib.pyplot as plt + +def cgenerateimage(inFile, outFile = "output.png"): + dataset = np.genfromtxt(inFile, delimiter = ',', dtype = np.uint8) + im = Image.fromarray(dataset) + im.save(outFile)