diff --git a/generating_hand_written_digits_mnist_with_gan/generating_handwritten_digits_mnist_with_gan.ipynb b/generating_hand_written_digits_mnist_with_gan/generating_handwritten_digits_mnist_with_gan.ipynb new file mode 100644 index 00000000..1a335bab --- /dev/null +++ b/generating_hand_written_digits_mnist_with_gan/generating_handwritten_digits_mnist_with_gan.ipynb @@ -0,0 +1,404 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 34, + "id": "c23013fc", + "metadata": {}, + "outputs": [], + "source": [ + "#include\n", + "#include\n", + "#include\n", + "#include\n", + "#include\n", + "#include" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "dccd0b37", + "metadata": {}, + "outputs": [], + "source": [ + "#include" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "ad7ab4c6", + "metadata": {}, + "outputs": [], + "source": [ + "using namespace mlpack;\n", + "using namespace mlpack::ann;\n", + "using namespace arma;\n", + "using namespace std;\n", + "using namespace ens;" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "7ed7a903", + "metadata": {}, + "outputs": [], + "source": [ + "arma::Row getLabels(arma::mat predOut){\n", + " arma::Row predLabels(predOut.n_cols);\n", + " for( arma::uword i = 0; i < predOut.n_cols; ++i){\n", + " predLabels(i) = predOut.col(i).index_max();\n", + " }\n", + " return predLabels;\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "131fe748", + "metadata": {}, + "outputs": [], + "source": [ + "constexpr double RATIO = 0.1;\n", + "constexpr int MAX_ITERATIONS = 0;\n", + "constexpr double STEP_SIZE = 1.2e-3;\n", + "constexpr int BATCH_SIZE = 50;" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "697a8d63", + "metadata": {}, + "outputs": [], + "source": [ + "arma::mat dataset;" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "64ee4e72", + "metadata": {}, + "outputs": [], + "source": [ + "data::Load(\"/home/viole/swaingotnochill/examples/generating_hand_written_digits_mnist_with_gan/digit-recognizer/train.csv\", dataset, true);" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "d8a3d8da", + "metadata": {}, + "outputs": [], + "source": [ + "arma::mat train, valid;" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "da10fdac", + "metadata": {}, + "outputs": [], + "source": [ + "data::Split(dataset, train, valid, RATIO);" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "a6a484e7", + "metadata": {}, + "outputs": [], + "source": [ + "const arma::mat trainX = train.submat(1, 0, train.n_rows - 1, train.n_cols - 1);\n", + "const arma::mat validX = valid.submat(1, 0, valid.n_rows - 1, valid.n_cols - 1);" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "db267de6", + "metadata": {}, + "outputs": [], + "source": [ + "const arma::mat trainY = train.row(0);\n", + "const arma::mat validY = valid.row(0);" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "7eb2907b", + "metadata": {}, + "outputs": [], + "source": [ + "using namespace mlpack::ann;\n", + "FFN, RandomInitialization> model;" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "id": "ffc66454", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "In file included from input_line_7:5:\n", + "\u001b[1m/home/viole/anaconda3/envs/notebook/include/mlpack/methods/ann/ffn.hpp:290:36: \u001b[0m\u001b[0;1;31merror: \u001b[0m\u001b[1mno member named 'push_back' in\n", + " 'std::vector,\n", + " arma::Mat > *,\n", + " mlpack::ann::AdaptiveMeanPooling,\n", + " arma::Mat > *, mlpack::ann::Add,\n", + " arma::Mat > *, mlpack::ann::AddMerge,\n", + " arma::Mat> *, mlpack::ann::AlphaDropout,\n", + " arma::Mat > *,\n", + " mlpack::ann::AtrousConvolution,\n", + " mlpack::ann::NaiveConvolution,\n", + " mlpack::ann::NaiveConvolution,\n", + " arma::Mat, arma::Mat > *,\n", + " mlpack::ann::BaseLayer, arma::Mat > *,\n", + " mlpack::ann::BaseLayer, arma::Mat > *,\n", + " mlpack::ann::BaseLayer,\n", + " arma::Mat > *,\n", + " mlpack::ann::BaseLayer, arma::Mat > *,\n", + " mlpack::ann::BaseLayer, arma::Mat > *,\n", + " mlpack::ann::BatchNorm, arma::Mat > *,\n", + " mlpack::ann::BilinearInterpolation,\n", + " arma::Mat > *, mlpack::ann::CELU,\n", + " arma::Mat > *, mlpack::ann::Concat,\n", + " arma::Mat> *, mlpack::ann::Concatenate,\n", + " arma::Mat > *,\n", + " mlpack::ann::ConcatPerformance,\n", + " arma::Mat >, arma::Mat, arma::Mat > *,\n", + " mlpack::ann::Constant, arma::Mat > *,\n", + " mlpack::ann::Convolution,\n", + " mlpack::ann::NaiveConvolution,\n", + " mlpack::ann::NaiveConvolution,\n", + " arma::Mat, arma::Mat > *,\n", + " mlpack::ann::CReLU, arma::Mat > *,\n", + " mlpack::ann::DropConnect, arma::Mat > *,\n", + " mlpack::ann::Dropout, arma::Mat > *,\n", + " mlpack::ann::ELU, arma::Mat > *,\n", + " mlpack::ann::FastLSTM, arma::Mat > *,\n", + " mlpack::ann::FlexibleReLU, arma::Mat > *,\n", + " mlpack::ann::GRU, arma::Mat > *,\n", + " mlpack::ann::HardTanH, arma::Mat > *,\n", + " mlpack::ann::Join, arma::Mat > *,\n", + " mlpack::ann::LayerNorm, arma::Mat > *,\n", + " mlpack::ann::LeakyReLU, arma::Mat > *,\n", + " mlpack::ann::Linear, arma::Mat,\n", + " mlpack::ann::NoRegularizer> *,\n", + " mlpack::ann::LinearNoBias, arma::Mat,\n", + " mlpack::ann::NoRegularizer> *,\n", + " mlpack::ann::LogSoftMax, arma::Mat > *,\n", + " mlpack::ann::Lookup, arma::Mat > *,\n", + " mlpack::ann::LSTM, arma::Mat > *,\n", + " mlpack::ann::MaxPooling, arma::Mat > *,\n", + " mlpack::ann::MeanPooling, arma::Mat > *,\n", + " mlpack::ann::MiniBatchDiscrimination,\n", + " arma::Mat > *,\n", + " mlpack::ann::MultiplyConstant, arma::Mat >\n", + " *, mlpack::ann::MultiplyMerge, arma::Mat>\n", + " *, mlpack::ann::NegativeLogLikelihood,\n", + " arma::Mat > *, mlpack::ann::NoisyLinear,\n", + " arma::Mat > *, mlpack::ann::Padding,\n", + " arma::Mat > *, mlpack::ann::PReLU,\n", + " arma::Mat > *, mlpack::ann::Softmax,\n", + " arma::Mat > *,\n", + " mlpack::ann::SpatialDropout, arma::Mat >\n", + " *,\n", + " mlpack::ann::TransposedConvolution,\n", + " mlpack::ann::NaiveConvolution,\n", + " mlpack::ann::NaiveConvolution,\n", + " arma::Mat, arma::Mat > *,\n", + " mlpack::ann::WeightNorm, arma::Mat> *,\n", + " boost::variant,\n", + " arma::Mat, mlpack::ann::NoRegularizer> *,\n", + " mlpack::ann::Glimpse, arma::Mat > *,\n", + " mlpack::ann::Highway, arma::Mat> *,\n", + " mlpack::ann::MultiheadAttention,\n", + " arma::Mat, mlpack::ann::NoRegularizer> *,\n", + " mlpack::ann::Recurrent, arma::Mat> *,\n", + " mlpack::ann::RecurrentAttention, arma::Mat\n", + " > *, mlpack::ann::ReinforceNormal,\n", + " arma::Mat > *,\n", + " mlpack::ann::Reparametrization, arma::Mat\n", + " > *, mlpack::ann::Select, arma::Mat > *,\n", + " mlpack::ann::Sequential, arma::Mat, false>\n", + " *, mlpack::ann::Sequential, arma::Mat,\n", + " true> *, mlpack::ann::Subview, arma::Mat >\n", + " *, mlpack::ann::VRClassReward, arma::Mat >\n", + " *, mlpack::ann::VirtualBatchNorm,\n", + " arma::Mat > *, mlpack::ann::RBF,\n", + " arma::Mat, mlpack::ann::GaussianFunction> *,\n", + " mlpack::ann::BaseLayer, arma::Mat > *,\n", + " mlpack::ann::PositionalEncoding, arma::Mat\n", + " > *> >,\n", + " std::allocator,\n", + " arma::Mat > *,\n", + " mlpack::ann::AdaptiveMeanPooling,\n", + " arma::Mat > *, mlpack::ann::Add,\n", + " arma::Mat > *, mlpack::ann::AddMerge,\n", + " arma::Mat> *, mlpack::ann::AlphaDropout,\n", + " arma::Mat > *,\n", + " mlpack::ann::AtrousConvolution,\n", + " mlpack::ann::NaiveConvolution,\n", + " mlpack::ann::NaiveConvolution,\n", + " arma::Mat, arma::Mat > *,\n", + " mlpack::ann::BaseLayer, arma::Mat > *,\n", + " mlpack::ann::BaseLayer, arma::Mat > *,\n", + " mlpack::ann::BaseLayer,\n", + " arma::Mat > *,\n", + " mlpack::ann::BaseLayer, arma::Mat > *,\n", + " mlpack::ann::BaseLayer, arma::Mat > *,\n", + " mlpack::ann::BatchNorm, arma::Mat > *,\n", + " mlpack::ann::BilinearInterpolation,\n", + " arma::Mat > *, mlpack::ann::CELU,\n", + " arma::Mat > *, mlpack::ann::Concat,\n", + " arma::Mat> *, mlpack::ann::Concatenate,\n", + " arma::Mat > *,\n", + " mlpack::ann::ConcatPerformance,\n", + " arma::Mat >, arma::Mat, arma::Mat > *,\n", + " mlpack::ann::Constant, arma::Mat > *,\n", + " mlpack::ann::Convolution,\n", + " mlpack::ann::NaiveConvolution,\n", + " mlpack::ann::NaiveConvolution,\n", + " arma::Mat, arma::Mat > *,\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " mlpack::ann::CReLU, arma::Mat > *,\n", + " mlpack::ann::DropConnect, arma::Mat > *,\n", + " mlpack::ann::Dropout, arma::Mat > *,\n", + " mlpack::ann::ELU, arma::Mat > *,\n", + " mlpack::ann::FastLSTM, arma::Mat > *,\n", + " mlpack::ann::FlexibleReLU, arma::Mat > *,\n", + " mlpack::ann::GRU, arma::Mat > *,\n", + " mlpack::ann::HardTanH, arma::Mat > *,\n", + " mlpack::ann::Join, arma::Mat > *,\n", + " mlpack::ann::LayerNorm, arma::Mat > *,\n", + " mlpack::ann::LeakyReLU, arma::Mat > *,\n", + " mlpack::ann::Linear, arma::Mat,\n", + " mlpack::ann::NoRegularizer> *,\n", + " mlpack::ann::LinearNoBias, arma::Mat,\n", + " mlpack::ann::NoRegularizer> *,\n", + " mlpack::ann::LogSoftMax, arma::Mat > *,\n", + " mlpack::ann::Lookup, arma::Mat > *,\n", + " mlpack::ann::LSTM, arma::Mat > *,\n", + " mlpack::ann::MaxPooling, arma::Mat > *,\n", + " mlpack::ann::MeanPooling, arma::Mat > *,\n", + " mlpack::ann::MiniBatchDiscrimination,\n", + " arma::Mat > *,\n", + " mlpack::ann::MultiplyConstant, arma::Mat >\n", + " *, mlpack::ann::MultiplyMerge, arma::Mat>\n", + " *, mlpack::ann::NegativeLogLikelihood,\n", + " arma::Mat > *, mlpack::ann::NoisyLinear,\n", + " arma::Mat > *, mlpack::ann::Padding,\n", + " arma::Mat > *, mlpack::ann::PReLU,\n", + " arma::Mat > *, mlpack::ann::Softmax,\n", + " arma::Mat > *,\n", + " mlpack::ann::SpatialDropout, arma::Mat >\n", + " *,\n", + " mlpack::ann::TransposedConvolution,\n", + " mlpack::ann::NaiveConvolution,\n", + " mlpack::ann::NaiveConvolution,\n", + " arma::Mat, arma::Mat > *,\n", + " mlpack::ann::WeightNorm, arma::Mat> *,\n", + " boost::variant,\n", + " arma::Mat, mlpack::ann::NoRegularizer> *,\n", + " mlpack::ann::Glimpse, arma::Mat > *,\n", + " mlpack::ann::Highway, arma::Mat> *,\n", + " mlpack::ann::MultiheadAttention,\n", + " arma::Mat, mlpack::ann::NoRegularizer> *,\n", + " mlpack::ann::Recurrent, arma::Mat> *,\n", + " mlpack::ann::RecurrentAttention, arma::Mat\n", + " > *, mlpack::ann::ReinforceNormal,\n", + " arma::Mat > *,\n", + " mlpack::ann::Reparametrization, arma::Mat\n", + " > *, mlpack::ann::Select, arma::Mat > *,\n", + " mlpack::ann::Sequential, arma::Mat, false>\n", + " *, mlpack::ann::Sequential, arma::Mat,\n", + " true> *, mlpack::ann::Subview, arma::Mat >\n", + " *, mlpack::ann::VRClassReward, arma::Mat >\n", + " *, mlpack::ann::VirtualBatchNorm,\n", + " arma::Mat > *, mlpack::ann::RBF,\n", + " arma::Mat, mlpack::ann::GaussianFunction> *,\n", + " mlpack::ann::BaseLayer, arma::Mat > *,\n", + " mlpack::ann::PositionalEncoding, arma::Mat\n", + " > *> > > >'\u001b[0m\n", + " void Add(Args... args) { network.push_back(new LayerType(args...)); }\n", + "\u001b[0;1;32m ~~~~~~~ ^\n", + "\u001b[0m\u001b[1minput_line_46:2:8: \u001b[0m\u001b[0;1;30mnote: \u001b[0min instantiation of function template specialization\n", + " 'mlpack::ann::FFN,\n", + " arma::Mat >,\n", + " mlpack::ann::RandomInitialization>::Add,\n", + " mlpack::ann::NaiveConvolution,\n", + " mlpack::ann::NaiveConvolution,\n", + " arma::Mat, arma::Mat >, int, int, int, int, int,\n", + " int, int, int, int, int>' requested here\u001b[0m\n", + " model.Add> (1,6,5,5,1,1,0,0,28,28);\n", + "\u001b[0;1;32m ^\n", + "\u001b[0m" + ] + }, + { + "ename": "Interpreter Error", + "evalue": "", + "output_type": "error", + "traceback": [ + "Interpreter Error: " + ] + } + ], + "source": [ + "model.Add> (1,6,5,5,1,1,0,0,28,28);\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "03d0646a", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "C++14", + "language": "C++14", + "name": "xcpp14" + }, + "language_info": { + "codemirror_mode": "text/x-c++src", + "file_extension": ".cpp", + "mimetype": "text/x-c++src", + "name": "c++", + "version": "14" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}