diff --git a/.gitignore b/.gitignore
index 6fc3b1f2..4708e5bb 100644
--- a/.gitignore
+++ b/.gitignore
@@ -13,5 +13,6 @@ data/*
 *.jpg
 *.png
 *.txt
+*.bin
 .travis/configs.hpp
 Testing/*
diff --git a/CMakeLists.txt b/CMakeLists.txt
index d5a32184..0a464a49 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -215,7 +215,9 @@ set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin/)
 # Recurse into each model.
 set(DIRS
   utils/
+  ensmallen_utils/
   dataloader/
+  models/
   tests/
 )
diff --git a/dataloader/dataloader.hpp b/dataloader/dataloader.hpp
index d389e58f..28c1cb44 100644
--- a/dataloader/dataloader.hpp
+++ b/dataloader/dataloader.hpp
@@ -20,6 +20,7 @@
 #include
 #include
 #include
+#include

 #ifndef MODELS_DATALOADER_HPP
diff --git a/dataloader/dataloader_impl.hpp b/dataloader/dataloader_impl.hpp
index e0069f4c..5538cd1d 100644
--- a/dataloader/dataloader_impl.hpp
+++ b/dataloader/dataloader_impl.hpp
@@ -383,7 +383,7 @@ template<
   // We use two endls here, as one of them will be replaced by the print
   // command below.
-  Log::Info << "Found " << imagesDirectory.size() << " belonging to " <<
+  mlpack::Log::Info << "Found " << imagesDirectory.size() << " belonging to " <<
       label << " class." << std::endl << std::endl;

   size_t loadedImages = 0;
@@ -395,7 +395,6 @@
     {
       continue;
     }
-
     mlpack::data::ImageInfo imageInfo(imageWidth, imageHeight, imageDepth);

     // Load the image.
@@ -406,12 +405,14 @@
     mlpack::data::Load(imageName.string(), image, imageInfo);

     // Add object to training set.
-    dataset.insert_cols(0, image);
-    labels.insert_cols(0, arma::vec(1).fill(label));
-
-    loadedImages++;
-    mlpack::Log::Info << "Loaded " << loadedImages << " out of " <<
-        imagesDirectory.size() << "\r" << std::endl;
+    if (image.n_rows == dataset.n_rows || dataset.n_elem == 0)
+    {
+      labels.insert_cols(0, arma::vec(1).fill(label));
+      dataset.insert_cols(0, image);
+      loadedImages++;
+      mlpack::Log::Info << "Loaded " << loadedImages << " out of " <<
+          imagesDirectory.size() << "\r" << std::endl;
+    }
   }
 }
@@ -485,6 +486,7 @@ template<
       validationData.n_rows - 1);

   augmentations.Transform(trainFeatures, imageWidth, imageHeight, imageDepth);
+  augmentations.Transform(validFeatures, imageWidth, imageHeight, imageDepth);

   mlpack::Log::Info << "Found " << totalClasses << " classes." << std::endl;
diff --git a/dataloader/preprocessor.hpp b/dataloader/preprocessor.hpp
index d41d7b23..85887a0b 100644
--- a/dataloader/preprocessor.hpp
+++ b/dataloader/preprocessor.hpp
@@ -55,6 +55,44 @@ class PreProcessor
   {
     // Nothing to do here. Added to match the rest of the codebase.
   }
+
+  /**
+   * Converts images to the channel-first format used in PyTorch. Performs the
+   * same function as torch.transforms.ToTensor().
+   *
+   * @param trainFeatures Input features that will be converted into channel-first format.
+   * @param imageWidth Width of the images in the dataset.
+   * @param imageHeight Height of the images in the dataset.
+   * @param imageDepth Depth / number of channels of the images in the dataset.
+   * @param normalize Whether to scale pixel values to the range [0, 1].
+   */
+  static void ChannelFirstImages(DatasetX& trainFeatures,
+                                 const size_t imageWidth,
+                                 const size_t imageHeight,
+                                 const size_t imageDepth,
+                                 bool normalize = true)
+  {
+    for (size_t idx = 0; idx < trainFeatures.n_cols; idx++)
+    {
+      // Create a copy of the current image so that the image isn't affected.
+      arma::cube inputTemp(trainFeatures.col(idx).memptr(), 3, 224, 224);
+
+      size_t currentOffset = 0;
+      for (size_t i = 0; i < inputTemp.n_slices; i++)
+      {
+        trainFeatures.col(idx)(arma::span(currentOffset, currentOffset +
+            inputTemp.slice(i).n_elem - 1), arma::span()) =
+            arma::vectorise(inputTemp.slice(i).t());
+        currentOffset += inputTemp.slice(i).n_elem;
+      }
+    }
+
+    if (normalize)
+    {
+      // Convert each element to uint8 first and then divide by 255.
+      for (size_t i = 0; i < trainFeatures.n_elem; i++)
+        trainFeatures(i) = ((uint8_t)(trainFeatures(i)) / 255.0);
+    }
+  }
 };

 #endif
diff --git a/ensmallen_utils/CMakeLists.txt b/ensmallen_utils/CMakeLists.txt
new file mode 100644
index 00000000..57b77395
--- /dev/null
+++ b/ensmallen_utils/CMakeLists.txt
@@ -0,0 +1,16 @@
+cmake_minimum_required(VERSION 3.1.0 FATAL_ERROR)
+project(ensmallen_utils)
+
+set(DIR_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/)
+
+set(SOURCES
+  print_metric.hpp
+  periodic_save.hpp)
+
+foreach(file ${SOURCES})
+  set(DIR_SRCS ${DIR_SRCS} ${CMAKE_CURRENT_SOURCE_DIR}/${file})
+endforeach()
+
+# Append sources (with directory name) to list of all models sources (used at
+# the parent scope).
+set(DIRS ${DIRS} ${DIR_SRCS} PARENT_SCOPE)
diff --git a/ensmallen_utils/periodic_save.hpp b/ensmallen_utils/periodic_save.hpp
new file mode 100644
index 00000000..ad10a832
--- /dev/null
+++ b/ensmallen_utils/periodic_save.hpp
@@ -0,0 +1,102 @@
+/**
+ * @file periodic_save.hpp
+ * @author Kartik Dutt
+ *
+ * Definition of the PeriodicSave callback.
+ *
+ * mlpack is free software; you may redistribute it and/or modify it under the
+ * terms of the 3-clause BSD license. You should have received a copy of the
+ * 3-clause BSD license along with mlpack. If not, see
+ * http://www.opensource.org/licenses/BSD-3-Clause for more information.
+ */
+
+#ifndef ENSMALLEN_CALLBACKS_PERIODIC_SAVE_HPP
+#define ENSMALLEN_CALLBACKS_PERIODIC_SAVE_HPP
+
+#include
+#include
+
+namespace ens {
+
+/**
+ * Saves the model being trained periodically.
+ *
+ * @tparam AnnType Type of the model which will be saved periodically.
+ */
+template<typename AnnType>
+class PeriodicSave
+{
+ public:
+  /**
+   * Constructor for the PeriodicSave class.
+   *
+   * @param network Network which will be saved periodically.
+   * @param filePath Base path / folder where weights will be saved.
+   * @param modelPrefix Weights will be stored as modelPrefix_epoch_loss.bin.
+   * @param period Period after which the model will be saved.
+   * @param silent Boolean to determine whether or not to print the saving
+   *     of the model.
+   * @param output Output stream where output will be directed.
+   */
+  PeriodicSave(AnnType& network,
+               const std::string filePath = "./",
+               const std::string modelPrefix = "model",
+               const size_t period = 1,
+               const bool silent = false,
+               std::ostream& output = arma::get_cout_stream()) :
+               network(network),
+               filePath(filePath),
+               modelPrefix(modelPrefix),
+               period(period),
+               silent(silent),
+               output(output)
+  {
+    // Nothing to do here.
+  }
+
+  template<typename OptimizerType, typename FunctionType, typename MatType>
+  bool EndEpoch(OptimizerType& /* optimizer */,
+                FunctionType& /* function */,
+                const MatType& /* coordinates */,
+                const size_t epoch,
+                const double objective)
+  {
+    if (epoch % period == 0)
+    {
+      std::string objectiveString = std::to_string(objective);
+      std::replace(objectiveString.begin(), objectiveString.end(), '.', '_');
+      std::string modelName = modelPrefix + "_" + std::to_string(epoch) + "_" +
+          objectiveString;
+      mlpack::data::Save(filePath + modelName + ".bin", modelPrefix, network);
+      if (!silent)
+        output << "Model saved as " << modelName << std::endl;
+    }
+
+    return false;
+  }
+
+ private:
+  // Reference to the model which will be saved periodically.
+  AnnType& network;
+
+  // Locally held string that depicts the path for saving the model.
+  std::string filePath;
+
+  // Locally held string that depicts the prefix name of the model being trained.
+  std::string modelPrefix;
+
+  // Period after which the model is saved.
+  size_t period;
+
+  // Locally held boolean to determine whether to print success / failure output
+  // when the model is saved.
+  bool silent;
+
+  // The output stream that all data is to be sent to; example: std::cout.
+  std::ostream& output;
+};
+
+} // namespace ens
+
+#endif
diff --git a/ensmallen_utils/print_metric.hpp b/ensmallen_utils/print_metric.hpp
new file mode 100644
index 00000000..b5099c55
--- /dev/null
+++ b/ensmallen_utils/print_metric.hpp
@@ -0,0 +1,107 @@
+/**
+ * @file print_metric.hpp
+ * @author Kartik Dutt
+ *
+ * Definition of the PrintMetric class.
+ *
+ * mlpack is free software; you may redistribute it and/or modify it under the
+ * terms of the 3-clause BSD license. You should have received a copy of the
+ * 3-clause BSD license along with mlpack. If not, see
+ * http://www.opensource.org/licenses/BSD-3-Clause for more information.
+ */
+
+#ifndef ENSMALLEN_CALLBACKS_PRINT_METRIC_HPP
+#define ENSMALLEN_CALLBACKS_PRINT_METRIC_HPP
+
+#include
+#include
+
+namespace ens {
+
+/**
+ * Prints a metric on the training / validation set.
+ *
+ * @tparam AnnType Type of the model which will be used for evaluating
+ *     the metric.
+ * @tparam MetricType Metric class which must have a static `Evaluate` function
+ *     that will be called at the end of the epoch.
+ * @tparam InputType Arma type of the dataset features.
+ * @tparam OutputType Arma type of the dataset labels.
+ */
+template<typename AnnType,
+         typename MetricType,
+         typename InputType = arma::mat,
+         typename OutputType = arma::mat>
+class PrintMetric
+{
+ public:
+  /**
+   * Constructor for the PrintMetric class.
+   *
+   * @param network Network on which the metric will be evaluated.
+   * @param features Input features on which the model will be evaluated.
+   * @param responses Ground truth labels for the model.
+   * @param metricName Metric name which will be printed after each epoch.
+   * @param trainData Boolean to determine whether the dataset corresponds to
+   *     training data or validation data.
+   * @param output Output stream where output will be directed.
+   */
+  PrintMetric(AnnType& network,
+              const InputType& features,
+              const OutputType& responses,
+              const std::string metricName = "metric",
+              const bool trainData = false,
+              std::ostream& output = arma::get_cout_stream()) :
+              network(network),
+              features(features),
+              responses(responses),
+              metricName(metricName),
+              trainData(trainData),
+              output(output)
+  {
+    // Nothing to do here.
+  }
+
+  template<typename OptimizerType, typename FunctionType, typename MatType>
+  bool EndEpoch(OptimizerType& /* optimizer */,
+                FunctionType& /* function */,
+                const MatType& /* coordinates */,
+                const size_t /* epoch */,
+                const double /* objective */)
+  {
+    OutputType predictions;
+    network.Predict(features, predictions);
+    const double localObjective = MetricType::Evaluate(predictions, responses);
+    if (!std::isnan(localObjective))
+    {
+      std::string outputString = (trainData == true) ? "Train " : "Validation ";
+      outputString = outputString + metricName + " : " +
+          std::to_string(localObjective);
+      output << outputString << std::endl;
+    }
+    return false;
+  }
+
+ private:
+  // Reference to the model which will be evaluated using the metric.
+  AnnType& network;
+
+  // Dataset which will be used for evaluating the metric.
+  InputType features;
+
+  // Dataset labels / responses that will be used for evaluating the metric.
+  OutputType responses;
+
+  // Locally held string that depicts the name of the metric.
+  std::string metricName;
+
+  // Locally held boolean to determine whether evaluation is done on train data
+  // or validation data.
+  bool trainData;
+
+  // The output stream that all data is to be sent to; example: std::cout.
+  std::ostream& output;
+};
+
+} // namespace ens
+
+#endif
diff --git a/models/CMakeLists.txt b/models/CMakeLists.txt
new file mode 100644
index 00000000..c4bd5a8c
--- /dev/null
+++ b/models/CMakeLists.txt
@@ -0,0 +1,16 @@
+cmake_minimum_required(VERSION 3.1.0 FATAL_ERROR)
+project(models)
+
+add_subdirectory(darknet)
+
+# Add directory name to sources.
+set(DIR_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/)
+
+foreach(file ${SOURCES})
+  set(DIR_SRCS ${DIR_SRCS} ${CMAKE_CURRENT_SOURCE_DIR}/${file})
+endforeach()
+
+set(DIR_SRCS ${DIR_SRCS} ${CMAKE_CURRENT_SOURCE_DIR}/layer.hpp)
+# Append sources (with directory name) to list of all models sources (used at
+# the parent scope).
+set(DIRS ${DIRS} ${DIR_SRCS} PARENT_SCOPE)
diff --git a/models/darknet/CMakeLists.txt b/models/darknet/CMakeLists.txt
new file mode 100644
index 00000000..92bbdf82
--- /dev/null
+++ b/models/darknet/CMakeLists.txt
@@ -0,0 +1,16 @@
+cmake_minimum_required(VERSION 3.1.0 FATAL_ERROR)
+project(darknet)
+
+set(DIR_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/)
+include_directories("${CMAKE_CURRENT_SOURCE_DIR}/../../")
+
+set(SOURCES
+  darknet.hpp
+  darknet_impl.hpp
+)
+
+foreach(file ${SOURCES})
+  set(DIR_SRCS ${DIR_SRCS} ${CMAKE_CURRENT_SOURCE_DIR}/${file})
+endforeach()
+
+set(DIRS ${DIRS} ${DIR_SRCS} PARENT_SCOPE)
diff --git a/models/darknet/darknet.hpp b/models/darknet/darknet.hpp
new file mode 100644
index 00000000..7b24401c
--- /dev/null
+++ b/models/darknet/darknet.hpp
@@ -0,0 +1,300 @@
+/**
+ * @file darknet.hpp
+ * @author Kartik Dutt
+ *
+ * Definition of the DarkNet models.
+ *
+ * For more information, kindly refer to the following papers.
+ *
+ * Paper for DarkNet-19.
+ *
+ * @code
+ * @article{Redmon2016,
+ *   author = {Joseph Redmon, Ali Farhadi},
+ *   title = {YOLO9000: Better, Faster, Stronger},
+ *   year = {2016},
+ *   url = {https://pjreddie.com/media/files/papers/YOLO9000.pdf}
+ * }
+ * @endcode
+ *
+ * Paper for DarkNet-53.
+ *
+ * @code
+ * @article{Redmon2019,
+ *   author = {Joseph Redmon, Ali Farhadi},
+ *   title = {YOLOv3: An Incremental Improvement},
+ *   year = {2019},
+ *   url = {https://pjreddie.com/media/files/papers/YOLOv3.pdf}
+ * }
+ * @endcode
+ *
+ * mlpack is free software; you may redistribute it and/or modify it under the
+ * terms of the 3-clause BSD license.
+ * You should have received a copy of the
+ * 3-clause BSD license along with mlpack. If not, see
+ * http://www.opensource.org/licenses/BSD-3-Clause for more information.
+ */
+
+#ifndef MODELS_DARKNET_HPP
+#define MODELS_DARKNET_HPP
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+namespace mlpack {
+namespace ann /** Artificial Neural Network. */ {
+
+/**
+ * Definition of a DarkNet CNN.
+ *
+ * @tparam OutputLayerType The output layer type used to evaluate the network.
+ * @tparam InitializationRuleType Rule used to initialize the weight matrix.
+ * @tparam DarkNetVersion Version of DarkNet (19 or 53).
+ */
+template<
+  typename OutputLayerType = CrossEntropyError<>,
+  typename InitializationRuleType = RandomInitialization,
+  size_t DarkNetVersion = 19
+>
+class DarkNet
+{
+ public:
+  //! Create the DarkNet model.
+  DarkNet();
+
+  /**
+   * DarkNet constructor; initializes the input shape and number of classes.
+   *
+   * @param inputChannel Number of input channels of the input image.
+   * @param inputWidth Width of the input image.
+   * @param inputHeight Height of the input image.
+   * @param numClasses Optional number of classes to classify images into,
+   *     only to be specified if includeTop is true.
+   * @param weights One of 'none', 'imagenet' (pre-trained on ImageNet) or a
+   *     path to weights.
+   * @param includeTop Must be set to true if weights are set.
+   */
+  DarkNet(const size_t inputChannel,
+          const size_t inputWidth,
+          const size_t inputHeight,
+          const size_t numClasses = 1000,
+          const std::string& weights = "none",
+          const bool includeTop = true);
+
+  /**
+   * DarkNet constructor; initializes the input shape and number of classes.
+   *
+   * @param inputShape A three-valued tuple indicating the input shape.
+   *     The first value is the number of channels (channels-first), the
+   *     second value is the input width and the third value is the input
+   *     height.
+   * @param numClasses Optional number of classes to classify images into,
+   *     only to be specified if includeTop is true.
+   * @param weights One of 'none', 'imagenet' (pre-trained on ImageNet) or a
+   *     path to weights.
+   * @param includeTop Must be set to true if weights are set.
+   */
+  DarkNet(const std::tuple<size_t, size_t, size_t> inputShape,
+          const size_t numClasses = 1000,
+          const std::string& weights = "none",
+          const bool includeTop = true);
+
+  //! Get the layers of the model.
+  FFN<OutputLayerType, InitializationRuleType>& GetModel() { return darkNet; }
+
+  //! Load weights into the model.
+  void LoadModel(const std::string& filePath);
+
+  //! Save weights of the model.
+  void SaveModel(const std::string& filePath);
+
+ private:
+  /**
+   * Adds a convolution block.
+   *
+   * @tparam SequentialType Layer type to which the convolution block will
+   *     be added.
+   *
+   * @param inSize Number of input maps.
+   * @param outSize Number of output maps.
+   * @param kernelWidth Width of the filter/kernel.
+   * @param kernelHeight Height of the filter/kernel.
+   * @param strideWidth Stride of filter application in the x direction.
+   * @param strideHeight Stride of filter application in the y direction.
+   * @param padW Padding width of the input.
+   * @param padH Padding height of the input.
+   * @param batchNorm Boolean to determine whether a batch normalization
+   *     layer is added.
+   * @param negativeSlope Negative slope hyper-parameter for LeakyReLU.
+   * @param baseLayer Layer to which the convolution block will be added; if
+   *     NULL, it is added to the darkNet FFN.
+ */ + template> + void ConvolutionBlock(const size_t inSize, + const size_t outSize, + const size_t kernelWidth, + const size_t kernelHeight, + const size_t strideWidth = 1, + const size_t strideHeight = 1, + const size_t padW = 0, + const size_t padH = 0, + const bool batchNorm = true, + const double negativeSlope = 1e-1, + SequentialType* baseLayer = NULL) + { + Sequential<>* bottleNeck = new Sequential<>(); + bottleNeck->Add(new Convolution<>(inSize, outSize, kernelWidth, + kernelHeight, strideWidth, strideHeight, padW, padH, inputWidth, + inputHeight)); + + // Update inputWidth and input Height. + mlpack::Log::Info << "Conv Layer. "; + mlpack::Log::Info << "(" << inputWidth << ", " << inputHeight << + ", " << inSize << ") ----> "; + + inputWidth = ConvOutSize(inputWidth, kernelWidth, strideWidth, padW); + inputHeight = ConvOutSize(inputHeight, kernelHeight, strideHeight, padH); + mlpack::Log::Info << "(" << inputWidth << ", " << inputHeight << + ", " << outSize << ")" << std::endl; + + if (batchNorm) + bottleNeck->Add(new BatchNorm<>(outSize, 1e-5, false)); + + bottleNeck->Add(new LeakyReLU<>(negativeSlope)); + + if (baseLayer != NULL) + baseLayer->Add(bottleNeck); + else + darkNet.Add(bottleNeck); + } + + /** + * Adds Pooling Block. + * + * @param factor The factor by which input dimensions will be divided. + * @param type One of "max" or "mean". Determines whether add mean pooling + * layer or max pooling layer. + */ + void PoolingBlock(const size_t factor = 2, + const std::string type = "max") + { + if (type == "max") + { + darkNet.Add(new AdaptiveMaxPooling<>(std::ceil(inputWidth * 1.0 / factor), + std::ceil(inputHeight * 1.0 / factor))); + } + else + { + darkNet.Add(new AdaptiveMeanPooling<>(std::ceil(inputWidth * 1.0 / + factor), std::ceil(inputHeight * 1.0 / factor))); + } + + mlpack::Log::Info << "Pooling Layer. "; + mlpack::Log::Info << "(" << inputWidth << ", " << inputHeight << + ") ----> "; + + // Update inputWidth and inputHeight. + inputWidth = std::ceil(inputWidth * 1.0 / factor); + inputHeight = std::ceil(inputHeight * 1.0 / factor); + mlpack::Log::Info << "(" << inputWidth << ", " << inputHeight << + ")" << std::endl; + } + + /** + * Adds bottleneck block for DarkNet 19. + * + * It's represented as: + * ConvolutionLayer(inputChannel, inputChannel * 2, stride) + * | + * ConvolutionLayer(inputChannel * 2, inputChannel, 1) + * | + * ConvolutionLayer(inputChannel, inputChannel * 2, stride) + * + * @param inputChannel Input channel in the convolution block. + * @param kernelWidth Width of the filter/kernel. + * @param kernelHeight Height of the filter/kernel. + * @param padWidth Padding in convolutional layer. + * @param padHeight Padding in convolutional layer. + */ + void DarkNet19SequentialBlock(const size_t inputChannel, + const size_t kernelWidth, + const size_t kernelHeight, + const size_t padWidth, + const size_t padHeight) + { + ConvolutionBlock(inputChannel, inputChannel * 2, + kernelWidth, kernelHeight, 1, 1, padWidth, padHeight, true); + ConvolutionBlock(inputChannel * 2, inputChannel, + 1, 1, 1, 1, 0, 0, true); + ConvolutionBlock(inputChannel, inputChannel * 2, + kernelWidth, kernelHeight, 1, 1, padWidth, padHeight, true); + } + + /** + * Adds residual bottleneck block for DarkNet 53. + * + * @param inputChannel Input channel in the bottle-neck. + * @param kernelWidth Width of the filter/kernel. + * @param kernelHeight Height of the filter/kernel. + * @param padWidth Padding in convolutional layer. + * @param padHeight Padding in convolutional layer. 
+ */ + void DarkNet53ResidualBlock(const size_t inputChannel, + const size_t kernelWidth = 3, + const size_t kernelHeight = 3, + const size_t padWidth = 1, + const size_t padHeight = 1) + { + mlpack::Log::Info << "Residual Block Begin." << std::endl; + Residual<>* residualBlock = new Residual<>(); + ConvolutionBlock(inputChannel, inputChannel / 2, + 1, 1, 1, 1, 0, 0, true, 1e-2, residualBlock); + ConvolutionBlock(inputChannel / 2, inputChannel, kernelWidth, + kernelHeight, 1, 1, padWidth, padWidth, true, 1e-2, residualBlock); + darkNet.Add(residualBlock); + mlpack::Log::Info << "Residual Block end." << std::endl; + } + + /** + * Return the convolution output size. + * + * @param size The size of the input (row or column). + * @param k The size of the filter (width or height). + * @param s The stride size (x or y direction). + * @param padding The size of the padding (width or height) on one side. + * @return The convolution output size. + */ + size_t ConvOutSize(const size_t size, + const size_t k, + const size_t s, + const size_t padding) + { + return std::floor(size + 2 * padding - k) / s + 1; + } + + //! Locally stored DarkNet Model. + FFN darkNet; + + //! Locally stored width of the image. + size_t inputWidth; + + //! Locally stored height of the image. + size_t inputHeight; + + //! Locally stored number of channels in the image. + size_t inputChannel; + + //! Locally stored number of output classes. + size_t numClasses; + + //! Locally stored type of pre-trained weights. + std::string weights; +}; // DarkNet class. + +} // namespace ann +} // namespace mlpack + +# include "darknet_impl.hpp" + +#endif diff --git a/models/darknet/darknet_impl.hpp b/models/darknet/darknet_impl.hpp new file mode 100644 index 00000000..ec2567b7 --- /dev/null +++ b/models/darknet/darknet_impl.hpp @@ -0,0 +1,187 @@ +/** + * @file darknet_impl.hpp + * @author Kartik Dutt + * + * Implementation of DarkNet using mlpack. + * + * mlpack is free software; you may redistribute it and/or modify it under the + * terms of the 3-clause BSD license. You should have received a copy of the + * 3-clause BSD license along with mlpack. If not, see + * http://www.opensource.org/licenses/BSD-3-Clause for more information. + */ +#ifndef MODELS_DARKNET_IMPL_HPP +#define MODELS_DARKNET_IMPL_HPP + +#include "darknet.hpp" + +namespace mlpack { +namespace ann { + +template< + typename OutputLayerType, + typename InitializationRuleType, + size_t DarkNetVersion +> +DarkNet::DarkNet() : + inputChannel(0), + inputWidth(0), + inputHeight(0), + numClasses(0), + weights("none") +{ + // Nothing to do here. +} + +template< + typename OutputLayerType, + typename InitializationRuleType, + size_t DarkNetVersion +> +DarkNet::DarkNet( + const size_t inputChannel, + const size_t inputWidth, + const size_t inputHeight, + const size_t numClasses, + const std::string& weights, + const bool includeTop) : + DarkNet( + std::tuple( + inputChannel, + inputWidth, + inputHeight), + numClasses, + weights, + includeTop) +{ + // Nothing to do here. +} + +template< + typename OutputLayerType, + typename InitializationRuleType, + size_t DarkNetVersion +> +DarkNet::DarkNet( + const std::tuple inputShape, + const size_t numClasses, + const std::string& weights, + const bool includeTop) : + inputChannel(std::get<0>(inputShape)), + inputWidth(std::get<1>(inputShape)), + inputHeight(std::get<2>(inputShape)), + numClasses(numClasses), + weights(weights) +{ + mlpack::Log::Assert(DarkNetVersion == 19 || DarkNetVersion == 53, + "Incorrect DarkNet version. 
+      Possible values are 19 and 53. \
+      Trying to find version : " + std::to_string(DarkNetVersion) + ".");
+
+  if (weights == "imagenet")
+  {
+    // Download weights here.
+    LoadModel("./../weights/darknet/darknet" + std::to_string(DarkNetVersion) +
+        "_imagenet.bin");
+    return;
+  }
+  else if (weights != "none")
+  {
+    LoadModel(weights);
+    return;
+  }
+
+  if (DarkNetVersion == 19)
+  {
+    darkNet.Add(new IdentityLayer<>());
+
+    // Convolution and activation function in a block.
+    ConvolutionBlock(inputChannel, 32, 3, 3, 1, 1, 1, 1, true);
+    PoolingBlock();
+    ConvolutionBlock(32, 64, 3, 3, 1, 1, 1, 1, true);
+    PoolingBlock();
+    DarkNet19SequentialBlock(64, 3, 3, 1, 1);
+    PoolingBlock();
+    DarkNet19SequentialBlock(128, 3, 3, 1, 1);
+    PoolingBlock();
+    DarkNet19SequentialBlock(256, 3, 3, 1, 1);
+    ConvolutionBlock(512, 256, 1, 1, 1, 1, 1, 1, true);
+    ConvolutionBlock(256, 512, 3, 3, 1, 1, 1, 1, true);
+    PoolingBlock();
+    DarkNet19SequentialBlock(512, 3, 3, 1, 1);
+    ConvolutionBlock(1024, 512, 1, 1, 1, 1, 1, 1, true);
+    ConvolutionBlock(512, 1024, 3, 3, 1, 1, 1, 1, true);
+
+    if (includeTop)
+    {
+      darkNet.Add(new Convolution<>(1024, numClasses, 1, 1,
+          1, 1, 0, 0, inputWidth, inputHeight));
+      darkNet.Add(new AdaptiveMeanPooling<>(1, 1));
+      darkNet.Add(new LogSoftMax<>());
+    }
+
+    darkNet.ResetParameters();
+  }
+  else if (DarkNetVersion == 53)
+  {
+    darkNet.Add(new IdentityLayer<>());
+    ConvolutionBlock(inputChannel, 32, 3, 3, 1, 1, 1, 1, true, 1e-2);
+    ConvolutionBlock(32, 64, 3, 3, 2, 2, 1, 1, true, 1e-2);
+
+    // Construct the residual stages from the configuration below.
+    size_t curChannels = 64;
+
+    // Residual block configuration for DarkNet 53.
+    std::vector<size_t> residualBlockConfig = {1, 2, 8, 8, 4};
+    for (size_t blockCount : residualBlockConfig)
+    {
+      for (size_t i = 0; i < blockCount; i++)
+        DarkNet53ResidualBlock(curChannels);
+
+      if (blockCount != 4)
+      {
+        ConvolutionBlock(curChannels, curChannels * 2, 3, 3,
+            2, 2, 1, 1, true, 1e-2);
+        curChannels = curChannels * 2;
+      }
+    }
+
+    if (includeTop)
+    {
+      darkNet.Add(new AdaptiveMeanPooling<>(1, 1));
+      darkNet.Add(new Linear<>(curChannels, numClasses));
+    }
+
+    darkNet.ResetParameters();
+  }
+}
+
+template<
+  typename OutputLayerType,
+  typename InitializationRuleType,
+  size_t DarkNetVersion
+>
+void DarkNet<
+    OutputLayerType, InitializationRuleType, DarkNetVersion
+>::LoadModel(const std::string& filePath)
+{
+  data::Load(filePath, "DarkNet", darkNet);
+  Log::Info << "Loaded model" << std::endl;
+}
+
+template<
+  typename OutputLayerType,
+  typename InitializationRuleType,
+  size_t DarkNetVersion
+>
+void DarkNet<
+    OutputLayerType, InitializationRuleType, DarkNetVersion
+>::SaveModel(const std::string& filePath)
+{
+  Log::Info << "Saving model." << std::endl;
+  data::Save(filePath, "DarkNet", darkNet);
+  Log::Info << "Model saved in " << filePath << "." << std::endl;
+}
+
+} // namespace ann
+} // namespace mlpack
+
+#endif
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index ded2245a..dacc08f8 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -8,6 +8,7 @@ include_directories("${CMAKE_CURRENT_SOURCE_DIR}/../")

 add_executable(models_test
   augmentation_tests.cpp
+  ffn_model_tests.cpp
   dataloader_tests.cpp
   utils_tests.cpp
 )
diff --git a/tests/ffn_model_tests.cpp b/tests/ffn_model_tests.cpp
new file mode 100644
index 00000000..84e42eb7
--- /dev/null
+++ b/tests/ffn_model_tests.cpp
@@ -0,0 +1,45 @@
+/**
+ * @file ffn_model_tests.cpp
+ * @author Kartik Dutt
+ *
+ * Tests for various functionalities and performance of models.
+ *
+ * mlpack is free software; you may redistribute it and/or modify it under the
+ * terms of the 3-clause BSD license. You should have received a copy of the
+ * 3-clause BSD license along with mlpack. If not, see
+ * http://www.opensource.org/licenses/BSD-3-Clause for more information.
+ */
+#define BOOST_TEST_DYN_LINK
+#include
+#include
+#include
+#include
+#include
+
+// Use namespaces for convenience.
+using namespace boost::unit_test;
+
+BOOST_AUTO_TEST_SUITE(FFNModelsTests);
+
+/**
+ * Simple test for the DarkNet model.
+ */
+BOOST_AUTO_TEST_CASE(DarknetModelTest)
+{
+  mlpack::ann::DarkNet<> darknetModel(3, 224, 224, 1000);
+  arma::mat input(224 * 224 * 3, 1), output;
+  input.ones();
+
+  // Check output shape.
+  darknetModel.GetModel().Predict(input, output);
+  BOOST_REQUIRE_EQUAL(output.n_cols, 1);
+  BOOST_REQUIRE_EQUAL(output.n_rows, 1000);
+
+  // Repeat for DarkNet-53.
+  mlpack::ann::DarkNet<mlpack::ann::CrossEntropyError<>,
+      mlpack::ann::RandomInitialization, 53> darknet53(3, 224, 224, 1000);
+  darknet53.GetModel().Predict(input, output);
+  BOOST_REQUIRE_EQUAL(output.n_cols, 1);
+  BOOST_REQUIRE_EQUAL(output.n_rows, 1000);
+}
+
+BOOST_AUTO_TEST_SUITE_END();
diff --git a/utils/CMakeLists.txt b/utils/CMakeLists.txt
index e7a48950..0eab52d3 100644
--- a/utils/CMakeLists.txt
+++ b/utils/CMakeLists.txt
@@ -3,7 +3,9 @@ project(utils)

 set(DIR_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/)

-set(SOURCES utils.hpp)
+set(SOURCES
+  utils.hpp
+  ensmallen_utils.hpp)

 foreach(file ${SOURCES})
   set(DIR_SRCS ${DIR_SRCS} ${CMAKE_CURRENT_SOURCE_DIR}/${file})
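
Usage sketch (not part of the patch): the code below illustrates how the new ens::PrintMetric and ens::PeriodicSave callbacks are intended to be passed to FFN::Train(). The tiny regression network, the random data, and the SquaredErrorMetric class are hypothetical placeholders added only for illustration; any class exposing a static Evaluate(predictions, responses) member can serve as the MetricType. Include paths assume the headers from this patch are on the include path.

// Minimal, self-contained sketch of the new callbacks (illustrative only).
#include <mlpack/core.hpp>
#include <mlpack/methods/ann/ffn.hpp>
#include <mlpack/methods/ann/layer/layer.hpp>
#include <mlpack/methods/ann/loss_functions/mean_squared_error.hpp>
#include <ensmallen.hpp>

#include "ensmallen_utils/print_metric.hpp"
#include "ensmallen_utils/periodic_save.hpp"

// Hypothetical metric; PrintMetric only needs a static Evaluate() like this.
class SquaredErrorMetric
{
 public:
  static double Evaluate(const arma::mat& predictions,
                         const arma::mat& responses)
  {
    return arma::accu(arma::square(predictions - responses)) /
        predictions.n_cols;
  }
};

int main()
{
  // Toy regression data: 10 features, 256 samples.
  arma::mat trainX = arma::randu<arma::mat>(10, 256);
  arma::mat trainY = arma::sum(trainX, 0);

  using ModelType = mlpack::ann::FFN<mlpack::ann::MeanSquaredError<>>;
  ModelType model;
  model.Add<mlpack::ann::Linear<>>(10, 16);
  model.Add<mlpack::ann::ReLULayer<>>();
  model.Add<mlpack::ann::Linear<>>(16, 1);

  // 5 epochs of 8 mini-batches each (256 samples / batch size 32).
  ens::SGD<> optimizer(0.01, 32, 40);

  // PrintMetric reports the metric on the given set after every epoch;
  // PeriodicSave writes modelPrefix_<epoch>_<objective>.bin after every epoch.
  model.Train(trainX, trainY, optimizer,
      ens::PrintMetric<ModelType, SquaredErrorMetric, arma::mat, arma::mat>(
          model, trainX, trainY, "squared-error", true),
      ens::PeriodicSave<ModelType>(model, "./", "model", 1));

  return 0;
}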