From 602549f67caf4d45196b45f204f9a469aa09b20f Mon Sep 17 00:00:00 2001 From: kartikdutt18 Date: Thu, 9 Jul 2020 10:58:52 +0530 Subject: [PATCH 1/8] Add yolo-v1 --- models/CMakeLists.txt | 1 + models/yolo/CMakeLists.txt | 16 +++ models/yolo/yolo.hpp | 242 +++++++++++++++++++++++++++++++++++++ models/yolo/yolo_impl.hpp | 168 +++++++++++++++++++++++++ 4 files changed, 427 insertions(+) create mode 100644 models/yolo/CMakeLists.txt create mode 100644 models/yolo/yolo.hpp create mode 100644 models/yolo/yolo_impl.hpp diff --git a/models/CMakeLists.txt b/models/CMakeLists.txt index c4bd5a8c..27ca3206 100644 --- a/models/CMakeLists.txt +++ b/models/CMakeLists.txt @@ -2,6 +2,7 @@ cmake_minimum_required(VERSION 3.1.0 FATAL_ERROR) project(models) add_subdirectory(darknet) +add_subdirectory(yolo) # Add directory name to sources. set(DIR_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/) diff --git a/models/yolo/CMakeLists.txt b/models/yolo/CMakeLists.txt new file mode 100644 index 00000000..180c33a7 --- /dev/null +++ b/models/yolo/CMakeLists.txt @@ -0,0 +1,16 @@ +cmake_minimum_required(VERSION 3.1.0 FATAL_ERROR) +project(yolo) + +set(DIR_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/) +include_directories("${CMAKE_CURRENT_SOURCE_DIR}/../../") + +set(SOURCES + yolo.hpp + yolo_impl.hpp +) + +foreach(file ${SOURCES}) + set(DIR_SRCS ${DIR_SRCS} ${CMAKE_CURRENT_SOURCE_DIR}/${file}) +endforeach() + +set(DIRS ${DIRS} ${DIR_SRCS} PARENT_SCOPE) diff --git a/models/yolo/yolo.hpp b/models/yolo/yolo.hpp new file mode 100644 index 00000000..ea422257 --- /dev/null +++ b/models/yolo/yolo.hpp @@ -0,0 +1,242 @@ +/** + * @file yolo.hpp + * @author Kartik Dutt + * + * Definition of Yolo models. + * + * For more information, kindly refer to the following paper. + * + * Paper for YOLOv1. 
+ * + * @code + * @article{Redmon2016, + * author = {Joseph Redmon, Santosh Divvala, Ross Girshick, Ali Farhadi}, + * title = {You Only Look Once : Unified, Real-Time Object Detection}, + * year = {2016}, + * url = {https://arxiv.org/pdf/1506.02640.pdf} + * } + * @endcode + * + * mlpack is free software; you may redistribute it and/or modify it under the + * terms of the 3-clause BSD license. You should have received a copy of the + * 3-clause BSD license along with mlpack. If not, see + * http://www.opensource.org/licenses/BSD-3-Clause for more information. + */ + +#ifndef MODELS_YOLO_HPP +#define MODELS_YOLO_HPP + +#include +#include +#include +#include +#include +#include +#include + +namespace mlpack { +namespace ann /** Artificial Neural Network. */{ + +/** + * Definition of YOLO object detection models. + * + * @tparam OutputLayerType The output layer type used to evaluate the network. + * @tparam InitializationRuleType Rule used to initialize the weight matrix. + * @tparam YOLOVersion Version of YOLO model. + */ +template< + typename OutputLayerType = NegativeLogLikelihood<>, + typename InitializationRuleType = RandomInitialization, + std::string YOLOVersion = "v1-tiny" +> +class YOLO +{ + public: + //! Create the YOLO model. + YOLO(); + + /** + * YOLO constructor initializes input shape and number of classes. + * + * @param inputChannels Number of input channels of the input image. + * @param inputWidth Width of the input image. + * @param inputHeight Height of the input image. + * @param numClasses Optional number of classes to classify images into, + * only to be specified if includeTop is true. + * @param numBoxes Number of bounding boxes per image. + * @param featureSizeWidth Width of output feature map. + * @param featureSizeHeight Height of output feature map. + * @param weights One of 'none', 'voc'(pre-training on VOC-2012) or path to weights. + * @param includeTop Must be set to true if weights are set. 
+ */ + YOLO(const size_t inputChannel, + const size_t inputWidth, + const size_t inputHeight, + const size_t numClasses = 20, + const size_t numBoxes = 2, + const size_t featureSizeWidth = 7, + const size_t featureSizeHeight = 7, + const std::string& weights = "none", + const bool includeTop = true); + + /** + * YOLO constructor intializes input shape and number of classes. + * + * @param inputShape A three-valued tuple indicating input shape. + * First value is number of Channels (Channels-First). + * Second value is input height. + * Third value is input width.. + * @param numClasses Optional number of classes to classify images into, + * only to be specified if includeTop is true. + * @param numBoxes Number of bounding boxes per image. + * @param featureShape A twp-valued tuple indicating width and height of output feature + * map. + * @param weights One of 'none', 'cifar10'(pre-training on CIFAR10) or path to weights. + */ + YOLO(const std::tuple inputShape, + const size_t numClasses = 1000, + const size_t numBoxes = 2, + const std::tuple featureShape, + const std::string& weights = "none", + const bool includeTop = true); + + //! Get Layers of the model. + FFN& GetModel() { return yolo; } + + //! Load weights into the model. + void LoadModel(const std::string& filePath); + + //! Save weights for the model. + void SaveModel(const std::string& filePath); + + private: + /** + * Adds Convolution Block. + * + * @tparam SequentialType Layer type in which convolution block will + * be added. + * + * @param inSize Number of input maps. + * @param outSize Number of output maps. + * @param kernelWidth Width of the filter/kernel. + * @param kernelHeight Height of the filter/kernel. + * @param strideWidth Stride of filter application in the x direction. + * @param strideHeight Stride of filter application in the y direction. + * @param padW Padding width of the input. + * @param padH Padding height of the input. 
+ * @param batchNorm Boolean to determine whether a batch normalization + * layer is added. + * @param baseLayer Layer in which Convolution block will be added, if + * NULL added to YOLO FFN. + */ + template> + void ConvolutionBlock(const size_t inSize, + const size_t outSize, + const size_t kernelWidth, + const size_t kernelHeight, + const size_t strideWidth = 1, + const size_t strideHeight = 1, + const size_t padW = 0, + const size_t padH = 0, + const bool batchNorm = false, + SequentialType* baseLayer = NULL) + { + Sequential<>* bottleNeck = new Sequential<>(); + bottleNeck->Add(new Convolution<>(inSize, outSize, kernelWidth, + kernelHeight, strideWidth, strideHeight, padW, padH, inputWidth, + inputHeight)); + + // Update inputWidth and input Height. + inputWidth = ConvOutSize(inputWidth, kernelWidth, strideWidth, padW); + inputHeight = ConvOutSize(inputHeight, kernelHeight, strideHeight, padH); + + if (batchNorm) + { + bottleNeck->Add(new BatchNorm<>(outSize, 1e-8, false)); + } + + bottleNeck->Add(new LeakyReLU<>()); + + if (baseLayer != NULL) + baseLayer->Add(bottleNeck); + else + yolo.Add(bottleNeck); + } + + /** + * Adds Pooling Block. + * + * @param factor The factor by which input dimensions will be divided. + * @param type One of "max" or "mean". Determines whether add mean pooling + * layer or max pooling layer. + */ + void PoolingBlock(const size_t factor = 2, + const std::string type = "max") + { + if (type == "max") + { + yolo.Add(new AdaptiveMaxPooling<>(std::ceil(inputWidth * 1.0 / factor), + std::ceil(inputHeight * 1.0 / factor))); + } + else + { + yolo.Add(new AdaptiveMeanPooling<>(std::ceil(inputWidth * 1.0 / + factor), std::ceil(inputHeight * 1.0 / factor))); + } + + // Update inputWidth and inputHeight. + inputWidth = std::ceil(inputWidth * 1.0 / factor); + inputHeight = std::ceil(inputHeight * 1.0 / factor); + } + + /** + * Return the convolution output size. + * + * @param size The size of the input (row or column). 
+ * @param k The size of the filter (width or height). + * @param s The stride size (x or y direction). + * @param padding The size of the padding (width or height) on one side. + * @return The convolution output size. + */ + size_t ConvOutSize(const size_t size, + const size_t k, + const size_t s, + const size_t padding) + { + return std::floor(size + 2 * padding - k) / s + 1; + } + + //! Locally stored YOLO Model. + FFN yolo; + + //! Locally stored width of the image. + size_t inputWidth; + + //! Locally stored height of the image. + size_t inputHeight; + + //! Locally stored number of channels in the image. + size_t inputChannel; + + //! Locally stored number of output classes. + size_t numClasses; + + //! Locally stored number of output bounding boxes. + size_t numBoxes; + + //! Locally stored width of output feature map. + size_t featureWidth; + + //! Locally stored height of output feature map. + size_t featureHeight; + + //! Locally stored type of pre-trained weights. + std::string weights; +}; // YOLO class. + +} // namespace ann +} // namespace mlpack + +# include "yolo_impl.hpp" + +#endif diff --git a/models/yolo/yolo_impl.hpp b/models/yolo/yolo_impl.hpp new file mode 100644 index 00000000..436941c1 --- /dev/null +++ b/models/yolo/yolo_impl.hpp @@ -0,0 +1,168 @@ +/** + * @file yolo_impl.hpp + * @author Kartik Dutt + * + * Implementation of YOLO models using mlpack. + * + * mlpack is free software; you may redistribute it and/or modify it under the + * terms of the 3-clause BSD license. You should have received a copy of the + * 3-clause BSD license along with mlpack. If not, see + * http://www.opensource.org/licenses/BSD-3-Clause for more information. 
+ */ +#ifndef MODELS_YOLO_IMPL_HPP +#define MODELS_YOLO_IMPL_HPP + +#include "yolo.hpp" + +namespace mlpack { +namespace ann { + +template< + typename OutputLayerType, + typename InitializationRuleType, + std::string YOLOVersion = "v1-tiny" +> +YOLO::YOLO() : + inputChannel(0), + inputWidth(0), + inputHeight(0), + numClasses(0), + numBoxes(0), + featureWidth(0), + featureHeight(0), + weights("none") +{ + // Nothing to do here. +} + +template< + typename OutputLayerType, + typename InitializationRuleType, + std::string YOLOVersion = "v1-tiny" +> +YOLO::YOLO( + const size_t inputChannel, + const size_t inputWidth, + const size_t inputHeight, + const size_t numClasses, + const size_t numBoxes, + const size_t featureWidth, + const size_t featureHeight, + const std::string& weights, + const bool includeTop) : + YOLO( + std::tuple( + inputChannel, + inputWidth, + inputHeight), + numClasses, + numBoxes, + std::tuple(featureWidth, featureHeight), + weights, + includeTop) +{ + // Nothing to do here. +} + +template< + typename OutputLayerType, + typename InitializationRuleType, + std::string YOLOVersion +> +YOLO::YOLO( + const std::tuple inputShape, + const size_t numClasses, + const size_t numBoxes, + const std::tuple featureShape, + const std::string& weights, + const bool includeTop) : + inputChannel(std::get<0>(inputShape)), + inputWidth(std::get<1>(inputShape)), + inputHeight(std::get<2>(inputShape)), + numClasses(numClasses), + numBoxes(numBoxes), + featureWidth(std::get<0>(featureShape)), + featureHeight(std::get<1>(featureShape)), + weights(weights) +{ + std::set supportedVersion({"v1-tiny"}); + mlpack::Log::Assert(supportedVersion.count(YOLOVersion), + "Unsupported YOLO version. Trying to find :", YOLOVersion); + + if (weights == "voc") + { + // Download weights here. 
+ LoadModel("./../weights/YOLO/yolo" + YOLOVersion + "_voc.bin"); + return; + } + else if (weights != "none") + { + LoadModel(weights); + return; + } + + if (YOLOVersion == "v1-tiny") + { + yolo.Add(new IdentityLayer<>()); + + // Convolution and activation function in a block. + ConvolutionBlock(inputChannel, 16, 3, 3, 1, 1, 1, 1, true); + PoolingBlock(2); + + size_t numBlocks = 5; + size_t outChannels = 16; + for (size_t blockId = 0; blockId < 4; blockId++) + { + ConvolutionBlock(outChannels, outChannels * 2, 3, 3, 1, 1, 1, 1, true); + PoolingBlock(2); + outChannels *= 2; + } + + numBlocks = 2; + for (size_t blockId = 0; blockId < numBlocks; blockId++) + ConvolutionBlock(outChannels, outChannels, 3, 3, 1, 1, 1, 1, true); + + if (includeTop) + { + yolo.Add(new Linear<>(inputWidth * inputHeight * outChannels, 4096)); + yolo.Add(new LeakyReLU<>()); + yolo.Add(4096, featureWidth * featureHeight * (5 * numBoxes + numClasses)); + yolo.Add(new Sigmoid<>()); + // See if we need to reshape here. + } + } + yolo.ResetParameters(); + } +} + +template< + typename OutputLayerType, + typename InitializationRuleType, + std::string YOLOVersion +> +void YOLO< + OutputLayerType, InitializationRuleType, YOLOVersion +>::LoadModel(const std::string& filePath) +{ + data::Load(filePath, "yolo" + YOLOVersion, yolo); + Log::Info << "Loaded model." << std::endl; +} + +template< + typename OutputLayerType, + typename InitializationRuleType, + std::string YOLOVersion +> +void YOLO< + OutputLayerType, InitializationRuleType, YOLOVersion +>::SaveModel(const std::string& filePath) +{ + Log::Info<< "Saving model." << std::endl; + data::Save(filePath, "yolo" + YOLOVerson, yolo); + Log::Info << "Model saved in " << filePath << "." 
<< std::endl; +} + +} // namespace ann +} // namespace mlpack + +#endif \ No newline at end of file From b2e7bcebd2a7f09e1196b9569b36c263db4c1371 Mon Sep 17 00:00:00 2001 From: kartikdutt18 Date: Sat, 22 Aug 2020 20:02:38 +0530 Subject: [PATCH 2/8] Style fixes --- models/yolo/yolo_impl.hpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/models/yolo/yolo_impl.hpp b/models/yolo/yolo_impl.hpp index 436941c1..8d918b3f 100644 --- a/models/yolo/yolo_impl.hpp +++ b/models/yolo/yolo_impl.hpp @@ -126,11 +126,11 @@ YOLO::YOLO( { yolo.Add(new Linear<>(inputWidth * inputHeight * outChannels, 4096)); yolo.Add(new LeakyReLU<>()); - yolo.Add(4096, featureWidth * featureHeight * (5 * numBoxes + numClasses)); + yolo.Add(4096, featureWidth * featureHeight * (5 * + numBoxes + numClasses)); yolo.Add(new Sigmoid<>()); - // See if we need to reshape here. } - } + yolo.ResetParameters(); } } @@ -165,4 +165,4 @@ void YOLO< } // namespace ann } // namespace mlpack -#endif \ No newline at end of file +#endif From 34f09632a690fb597c1f453600e724598516244c Mon Sep 17 00:00:00 2001 From: kartikdutt18 Date: Tue, 25 Aug 2020 15:13:03 +0530 Subject: [PATCH 3/8] Model fixes to match output --- models/yolo/yolo_impl.hpp | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/models/yolo/yolo_impl.hpp b/models/yolo/yolo_impl.hpp index 8d918b3f..d4c0bfff 100644 --- a/models/yolo/yolo_impl.hpp +++ b/models/yolo/yolo_impl.hpp @@ -111,23 +111,22 @@ YOLO::YOLO( size_t numBlocks = 5; size_t outChannels = 16; - for (size_t blockId = 0; blockId < 4; blockId++) + for (size_t blockId = 0; blockId < numBlocks; blockId++) { ConvolutionBlock(outChannels, outChannels * 2, 3, 3, 1, 1, 1, 1, true); PoolingBlock(2); outChannels *= 2; } - numBlocks = 2; - for (size_t blockId = 0; blockId < numBlocks; blockId++) - ConvolutionBlock(outChannels, outChannels, 3, 3, 1, 1, 1, 1, true); + ConvolutionBlock(outChannels, outChannels * 2, 3, 3, 1, 1, 1, 1, true); + outChannels 
*= 2; + ConvolutionBlock(outChannels, 256, 3, 3, 1, 1, 1, 1, true); + outChannels 256; if (includeTop) { - yolo.Add(new Linear<>(inputWidth * inputHeight * outChannels, 4096)); - yolo.Add(new LeakyReLU<>()); - yolo.Add(4096, featureWidth * featureHeight * (5 * - numBoxes + numClasses)); + yolo.Add(new Linear<>(inputWidth * inputHeight * outChannels, + featureWidth * featureHeight * (5 * numBoxes + numClasses))); yolo.Add(new Sigmoid<>()); } From fa6eeff4c11350a475e8c0b5c80ace46fd78224f Mon Sep 17 00:00:00 2001 From: kartikdutt18 Date: Tue, 25 Aug 2020 22:24:23 +0530 Subject: [PATCH 4/8] Last changes, this works... add tests for verification --- models/yolo/yolo.hpp | 31 ++++++++++++++++------ models/yolo/yolo_impl.hpp | 54 +++++++++++++++++++-------------------- 2 files changed, 50 insertions(+), 35 deletions(-) diff --git a/models/yolo/yolo.hpp b/models/yolo/yolo.hpp index ea422257..624c5f69 100644 --- a/models/yolo/yolo.hpp +++ b/models/yolo/yolo.hpp @@ -31,8 +31,7 @@ #include #include #include -#include -#include + namespace mlpack { namespace ann /** Artificial Neural Network. */{ @@ -46,8 +45,7 @@ namespace ann /** Artificial Neural Network. */{ */ template< typename OutputLayerType = NegativeLogLikelihood<>, - typename InitializationRuleType = RandomInitialization, - std::string YOLOVersion = "v1-tiny" + typename InitializationRuleType = RandomInitialization > class YOLO { @@ -61,6 +59,7 @@ class YOLO * @param inputChannels Number of input channels of the input image. * @param inputWidth Width of the input image. * @param inputHeight Height of the input image. + * @param yoloVersion Version of YOLO model. * @param numClasses Optional number of classes to classify images into, * only to be specified if includeTop is true. * @param numBoxes Number of bounding boxes per image. 
@@ -72,6 +71,7 @@ class YOLO YOLO(const size_t inputChannel, const size_t inputWidth, const size_t inputHeight, + const std::string yoloVersion = "v1-tiny", const size_t numClasses = 20, const size_t numBoxes = 2, const size_t featureSizeWidth = 7, @@ -85,7 +85,8 @@ class YOLO * @param inputShape A three-valued tuple indicating input shape. * First value is number of Channels (Channels-First). * Second value is input height. - * Third value is input width.. + * Third value is input width. + * @param yoloVersion Version of YOLO model. * @param numClasses Optional number of classes to classify images into, * only to be specified if includeTop is true. * @param numBoxes Number of bounding boxes per image. @@ -94,9 +95,10 @@ class YOLO * @param weights One of 'none', 'cifar10'(pre-training on CIFAR10) or path to weights. */ YOLO(const std::tuple inputShape, + const std::string yoloVersion = "v1-tiny", const size_t numClasses = 1000, const size_t numBoxes = 2, - const std::tuple featureShape, + const std::tuple featureShape = {7, 7}, const std::string& weights = "none", const bool includeTop = true); @@ -146,16 +148,21 @@ class YOLO kernelHeight, strideWidth, strideHeight, padW, padH, inputWidth, inputHeight)); - // Update inputWidth and input Height. + mlpack::Log::Info << "Conv Layer. "; + mlpack::Log::Info << "(" << inputWidth << ", " << inputHeight << + ", " << inSize << ") ----> "; + inputWidth = ConvOutSize(inputWidth, kernelWidth, strideWidth, padW); inputHeight = ConvOutSize(inputHeight, kernelHeight, strideHeight, padH); + mlpack::Log::Info << "(" << inputWidth << ", " << inputHeight << + ", " << outSize << ")" << std::endl; if (batchNorm) { bottleNeck->Add(new BatchNorm<>(outSize, 1e-8, false)); } - bottleNeck->Add(new LeakyReLU<>()); + bottleNeck->Add(new LeakyReLU<>(0.01)); if (baseLayer != NULL) baseLayer->Add(bottleNeck); @@ -184,9 +191,14 @@ class YOLO factor), std::ceil(inputHeight * 1.0 / factor))); } + mlpack::Log::Info << "Pooling Layer. 
"; + mlpack::Log::Info << "(" << inputWidth << ", " << inputHeight << + ") ----> "; // Update inputWidth and inputHeight. inputWidth = std::ceil(inputWidth * 1.0 / factor); inputHeight = std::ceil(inputHeight * 1.0 / factor); + + mlpack::Log::Info << "(" << inputWidth << ", " << inputHeight << ")" << std::endl; } /** @@ -232,6 +244,9 @@ class YOLO //! Locally stored type of pre-trained weights. std::string weights; + + //! Locally stored version of yolo model. + std::string yoloVersion; }; // YOLO class. } // namespace ann diff --git a/models/yolo/yolo_impl.hpp b/models/yolo/yolo_impl.hpp index d4c0bfff..4f1bda91 100644 --- a/models/yolo/yolo_impl.hpp +++ b/models/yolo/yolo_impl.hpp @@ -19,10 +19,9 @@ namespace ann { template< typename OutputLayerType, - typename InitializationRuleType, - std::string YOLOVersion = "v1-tiny" + typename InitializationRuleType > -YOLO::YOLO() : +YOLO::YOLO() : inputChannel(0), inputWidth(0), inputHeight(0), @@ -30,31 +29,33 @@ YOLO::YOLO() : numBoxes(0), featureWidth(0), featureHeight(0), - weights("none") + weights("none"), + yoloVersion("none") { // Nothing to do here. 
} template< typename OutputLayerType, - typename InitializationRuleType, - std::string YOLOVersion = "v1-tiny" + typename InitializationRuleType > -YOLO::YOLO( +YOLO::YOLO( const size_t inputChannel, const size_t inputWidth, const size_t inputHeight, + const std::string yoloVersion, const size_t numClasses, const size_t numBoxes, const size_t featureWidth, const size_t featureHeight, const std::string& weights, const bool includeTop) : - YOLO( + YOLO( std::tuple( inputChannel, inputWidth, inputHeight), + yoloVersion, numClasses, numBoxes, std::tuple(featureWidth, featureHeight), @@ -66,11 +67,11 @@ YOLO::YOLO( template< typename OutputLayerType, - typename InitializationRuleType, - std::string YOLOVersion + typename InitializationRuleType > -YOLO::YOLO( +YOLO::YOLO( const std::tuple inputShape, + const std::string yoloVersion, const size_t numClasses, const size_t numBoxes, const std::tuple featureShape, @@ -83,16 +84,17 @@ YOLO::YOLO( numBoxes(numBoxes), featureWidth(std::get<0>(featureShape)), featureHeight(std::get<1>(featureShape)), - weights(weights) + weights(weights), + yoloVersion(yoloVersion) { - std::set supportedVersion({"v1-tiny"}); - mlpack::Log::Assert(supportedVersion.count(YOLOVersion), - "Unsupported YOLO version. Trying to find :", YOLOVersion); + std::set supportedVersion({"v1-tiny"}); + mlpack::Log::Assert(supportedVersion.count(yoloVersion), + "Unsupported YOLO version. Trying to find :" + yoloVersion); if (weights == "voc") { // Download weights here. 
- LoadModel("./../weights/YOLO/yolo" + YOLOVersion + "_voc.bin"); + LoadModel("./../weights/YOLO/yolo" + yoloVersion + "_voc.bin"); return; } else if (weights != "none") @@ -101,7 +103,7 @@ YOLO::YOLO( return; } - if (YOLOVersion == "v1-tiny") + if (yoloVersion == "v1-tiny") { yolo.Add(new IdentityLayer<>()); @@ -121,13 +123,13 @@ YOLO::YOLO( ConvolutionBlock(outChannels, outChannels * 2, 3, 3, 1, 1, 1, 1, true); outChannels *= 2; ConvolutionBlock(outChannels, 256, 3, 3, 1, 1, 1, 1, true); - outChannels 256; + outChannels = 256; if (includeTop) { yolo.Add(new Linear<>(inputWidth * inputHeight * outChannels, featureWidth * featureHeight * (5 * numBoxes + numClasses))); - yolo.Add(new Sigmoid<>()); + yolo.Add(new SigmoidLayer<>()); } yolo.ResetParameters(); @@ -136,28 +138,26 @@ YOLO::YOLO( template< typename OutputLayerType, - typename InitializationRuleType, - std::string YOLOVersion + typename InitializationRuleType > void YOLO< - OutputLayerType, InitializationRuleType, YOLOVersion + OutputLayerType, InitializationRuleType >::LoadModel(const std::string& filePath) { - data::Load(filePath, "yolo" + YOLOVersion, yolo); + data::Load(filePath, "yolo" + yoloVersion, yolo); Log::Info << "Loaded model." << std::endl; } template< typename OutputLayerType, - typename InitializationRuleType, - std::string YOLOVersion + typename InitializationRuleType > void YOLO< - OutputLayerType, InitializationRuleType, YOLOVersion + OutputLayerType, InitializationRuleType >::SaveModel(const std::string& filePath) { Log::Info<< "Saving model." << std::endl; - data::Save(filePath, "yolo" + YOLOVerson, yolo); + data::Save(filePath, "yolo" + yoloVersion, yolo); Log::Info << "Model saved in " << filePath << "." 
<< std::endl; } From 0a8b2bcec8396ca068ba080234dc81937e25d725 Mon Sep 17 00:00:00 2001 From: kartikdutt18 Date: Wed, 26 Aug 2020 16:08:24 +0530 Subject: [PATCH 5/8] Add tests --- tests/ffn_model_tests.cpp | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tests/ffn_model_tests.cpp b/tests/ffn_model_tests.cpp index 84e42eb7..d1c02e6d 100644 --- a/tests/ffn_model_tests.cpp +++ b/tests/ffn_model_tests.cpp @@ -14,6 +14,7 @@ #include #include #include +#include #include // Use namespaces for convenience. @@ -42,4 +43,19 @@ BOOST_AUTO_TEST_CASE(DarknetModelTest) BOOST_REQUIRE_EQUAL(output.n_rows, 1000); } +/** + * Simple test for YOLOv1 model. + */ +BOOST_AUTO_TEST_CASE(YOLOV1ModelTest) +{ + mlpack::ann::YOLO<> yolo(3, 448, 448); + arma::mat input(448 * 448 * 3, 1), output; + input.ones(); + + // Check output shape. + yolo.GetModel().Predict(input, output); + BOOST_REQUIRE_EQUAL(output.n_cols, 1); + BOOST_REQUIRE_EQUAL(output.n_rows, 7 * 7 * (5 * 2 + 20)); +} + BOOST_AUTO_TEST_SUITE_END(); From 5ed980c4f3d88d8c842c0b202ca8bb5945e3800f Mon Sep 17 00:00:00 2001 From: kartikdutt18 <39593019+kartikdutt18@users.noreply.github.com> Date: Wed, 26 Aug 2020 20:35:07 +0530 Subject: [PATCH 6/8] Update models/yolo/yolo.hpp Co-authored-by: sy0814k --- models/yolo/yolo.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/yolo/yolo.hpp b/models/yolo/yolo.hpp index 624c5f69..92207359 100644 --- a/models/yolo/yolo.hpp +++ b/models/yolo/yolo.hpp @@ -91,7 +91,7 @@ class YOLO * only to be specified if includeTop is true. * @param numBoxes Number of bounding boxes per image. * @param featureShape A twp-valued tuple indicating width and height of output feature - * map. + * map. * @param weights One of 'none', 'cifar10'(pre-training on CIFAR10) or path to weights. 
*/ YOLO(const std::tuple inputShape, From 4c34764aa97b5bfe37ec473107a2a8e6fed450b1 Mon Sep 17 00:00:00 2001 From: kartikdutt18 <39593019+kartikdutt18@users.noreply.github.com> Date: Wed, 26 Aug 2020 20:35:16 +0530 Subject: [PATCH 7/8] Update models/yolo/yolo.hpp Co-authored-by: sy0814k --- models/yolo/yolo.hpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/models/yolo/yolo.hpp b/models/yolo/yolo.hpp index 92207359..b648c799 100644 --- a/models/yolo/yolo.hpp +++ b/models/yolo/yolo.hpp @@ -158,9 +158,7 @@ class YOLO ", " << outSize << ")" << std::endl; if (batchNorm) - { bottleNeck->Add(new BatchNorm<>(outSize, 1e-8, false)); - } bottleNeck->Add(new LeakyReLU<>(0.01)); From f1c3bbf741adba1a43243091968257ac71119ac0 Mon Sep 17 00:00:00 2001 From: kartikdutt18 Date: Thu, 27 Aug 2020 18:58:00 +0530 Subject: [PATCH 8/8] Style fixes --- models/yolo/yolo.hpp | 21 ++++++++++----------- models/yolo/yolo_impl.hpp | 8 ++++---- 2 files changed, 14 insertions(+), 15 deletions(-) diff --git a/models/yolo/yolo.hpp b/models/yolo/yolo.hpp index b648c799..d64986c5 100644 --- a/models/yolo/yolo.hpp +++ b/models/yolo/yolo.hpp @@ -61,7 +61,7 @@ class YOLO * @param inputHeight Height of the input image. * @param yoloVersion Version of YOLO model. * @param numClasses Optional number of classes to classify images into, - * only to be specified if includeTop is true. + * only to be specified if includeTop is true. * @param numBoxes Number of bounding boxes per image. * @param featureSizeWidth Width of output feature map. * @param featureSizeHeight Height of output feature map. @@ -83,16 +83,15 @@ class YOLO * YOLO constructor intializes input shape and number of classes. * * @param inputShape A three-valued tuple indicating input shape. - * First value is number of Channels (Channels-First). - * Second value is input height. - * Third value is input width. + * First value is number of Channels (Channels-First). + * Second value is input width. Third value is input height. 
* @param yoloVersion Version of YOLO model. * @param numClasses Optional number of classes to classify images into, - * only to be specified if includeTop is true. + * only to be specified if includeTop is true. * @param numBoxes Number of bounding boxes per image. * @param featureShape A twp-valued tuple indicating width and height of output feature - * map. - * @param weights One of 'none', 'cifar10'(pre-training on CIFAR10) or path to weights. + * map. + * @param weights One of 'none', 'voc'(pre-training on VOC) or path to weights. */ YOLO(const std::tuple inputShape, const std::string yoloVersion = "v1-tiny", @@ -116,7 +115,7 @@ class YOLO * Adds Convolution Block. * * @tparam SequentialType Layer type in which convolution block will - * be added. + * be added. * * @param inSize Number of input maps. * @param outSize Number of output maps. @@ -127,9 +126,9 @@ class YOLO * @param padW Padding width of the input. * @param padH Padding height of the input. * @param batchNorm Boolean to determine whether a batch normalization - * layer is added. + * layer is added. * @param baseLayer Layer in which Convolution block will be added, if - * NULL added to YOLO FFN. + * NULL added to YOLO FFN. */ template> void ConvolutionBlock(const size_t inSize, @@ -173,7 +172,7 @@ class YOLO * * @param factor The factor by which input dimensions will be divided. * @param type One of "max" or "mean". Determines whether add mean pooling - * layer or max pooling layer. + * layer or max pooling layer. 
*/ void PoolingBlock(const size_t factor = 2, const std::string type = "max") diff --git a/models/yolo/yolo_impl.hpp b/models/yolo/yolo_impl.hpp index 4f1bda91..70c10904 100644 --- a/models/yolo/yolo_impl.hpp +++ b/models/yolo/yolo_impl.hpp @@ -66,8 +66,8 @@ YOLO::YOLO( } template< - typename OutputLayerType, - typename InitializationRuleType + typename OutputLayerType, + typename InitializationRuleType > YOLO::YOLO( const std::tuple inputShape, @@ -149,8 +149,8 @@ void YOLO< } template< - typename OutputLayerType, - typename InitializationRuleType + typename OutputLayerType, + typename InitializationRuleType > void YOLO< OutputLayerType, InitializationRuleType