-
Notifications
You must be signed in to change notification settings - Fork 41
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add basic definition of models, Needs to be trained and tested
Trained a lenet1 model Add All Weights
- Loading branch information
1 parent
e70d009
commit c4113eb
Showing
14 changed files
with
632 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -216,6 +216,8 @@ set(DIRS | |
utils/ | ||
dataloader/ | ||
tests/ | ||
models/ | ||
computer_vision/ | ||
) | ||
|
||
foreach(dir ${DIRS}) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
cmake_minimum_required(VERSION 3.1.0 FATAL_ERROR) | ||
project(computer_vision) | ||
|
||
add_subdirectory(object_classification/) | ||
|
||
# Add directory name to sources. | ||
set(DIR_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/) | ||
|
||
foreach(file ${SOURCES}) | ||
set(DIR_SRCS ${DIR_SRCS} ${CMAKE_CURRENT_SOURCE_DIR}/${file}) | ||
endforeach() | ||
|
||
set(DIRS ${DIRS} ${DIR_SRCS} PARENT_SCOPE) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
cmake_minimum_required(VERSION 3.1.0 FATAL_ERROR) | ||
project(object_classification) | ||
|
||
set(MODEL_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/) | ||
include_directories("${CMAKE_CURRENT_SOURCE_DIR}/../../") | ||
|
||
set(SOURCES | ||
object_classification.cpp | ||
) | ||
|
||
foreach(file ${SOURCES}) | ||
string( REPLACE ".cpp" "" name ${file}) | ||
add_executable(${name} ${MODEL_SOURCE_DIR}/${file}) | ||
target_link_libraries(${name} | ||
${COMPILER_SUPPORT_LIBRARIES} | ||
${ARMADILLO_LIBRARIES} | ||
${Boost_FILESYSTEM_LIBRARY} | ||
${Boost_UNIT_TEST_FRAMEWORK_LIBRARY} | ||
${Boost_SYSTEM_LIBRARY} | ||
${Boost_SERIALIZATION_LIBRARY} | ||
${MLPACK_LIBRARIES} | ||
) | ||
endforeach() |
57 changes: 57 additions & 0 deletions
57
computer_vision/object_classification/object_classification.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
/** | ||
* @file object_classification.hpp | ||
* @author Kartik Dutt | ||
* | ||
* Contains implementation of object classification suite. It can be used | ||
* to select object classification model, it's parameter dataset and | ||
* other training parameters. | ||
* | ||
* NOTE: This code needs to be adapted as this implementation doesn't support | ||
* Command Line Arguments. | ||
* | ||
* mlpack is free software; you may redistribute it and/or modify it under the | ||
* terms of the 3-clause BSD license. You should have received a copy of the | ||
* 3-clause BSD license along with mlpack. If not, see | ||
* http://www.opensource.org/licenses/BSD-3-Clause for more information. | ||
*/ | ||
#include <dataloader/dataloader.hpp> | ||
#include <models/lenet/lenet.hpp> | ||
#include <ensmallen.hpp> | ||
|
||
using namespace mlpack; | ||
using namespace mlpack::ann; | ||
using namespace arma; | ||
using namespace std; | ||
using namespace ens; | ||
|
||
int main() | ||
{ | ||
const int EPOCHS = 3; | ||
const double STEP_SIZE = 5e-3; | ||
const int BATCH_SIZE = 32; | ||
const double RATIO = 0.2; | ||
|
||
DataLoader<> dataloader("mnist", true, RATIO); | ||
|
||
constexpr size_t ver = 5; | ||
LeNet<mlpack::ann::NegativeLogLikelihood<>, | ||
mlpack::ann::RandomInitialization, ver> module1(1, 28, 28, 10); | ||
cout << "Training." << endl; | ||
|
||
SGD<AdamUpdate> optimizer(STEP_SIZE, BATCH_SIZE, | ||
EPOCHS * (ver / 2) * dataloader.TrainLabels().n_cols, | ||
1e-8, | ||
true, | ||
AdamUpdate(1e-8, 0.9, 0.999)); | ||
|
||
module1.GetModel().Train(dataloader.TrainFeatures(), | ||
dataloader.TrainLabels(), | ||
optimizer, | ||
ens::PrintLoss(), | ||
ens::ProgressBar(), | ||
ens::EarlyStopAtMinLoss()); | ||
|
||
module1.SaveModel("./../weights/lenet/lenet" + to_string(ver)+"_mnist.bin"); | ||
|
||
return 0; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
cmake_minimum_required(VERSION 3.1.0 FATAL_ERROR) | ||
project(models) | ||
|
||
add_subdirectory(lenet) | ||
|
||
# Add directory name to sources. | ||
set(DIR_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/) | ||
|
||
foreach(file ${SOURCES}) | ||
set(DIR_SRCS ${DIR_SRCS} ${CMAKE_CURRENT_SOURCE_DIR}/${file}) | ||
endforeach() | ||
|
||
# Append sources (with directory name) to list of all models sources (used at | ||
# the parent scope). | ||
set(DIRS ${DIRS} ${DIR_SRCS} PARENT_SCOPE) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
cmake_minimum_required(VERSION 3.1.0 FATAL_ERROR) | ||
project(lenet) | ||
|
||
set(DIR_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/) | ||
include_directories("${CMAKE_CURRENT_SOURCE_DIR}/../../") | ||
|
||
set(SOURCES | ||
lenet.hpp | ||
lenet_impl.hpp | ||
) | ||
|
||
foreach(file ${SOURCES}) | ||
set(DIR_SRCS ${DIR_SRCS} ${CMAKE_CURRENT_SOURCE_DIR}/${file}) | ||
endforeach() | ||
|
||
set(DIRS ${DIRS} ${DIR_SRCS} PARENT_SCOPE) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,210 @@ | ||
/** | ||
* @file lenet.hpp | ||
* @author Eugene Freyman | ||
* @author Daivik Nema | ||
* @author Kartik Dutt | ||
* | ||
* Definition of LeNet generally used for object detection. | ||
* | ||
* For more information, kindly refer to the following paper. | ||
* | ||
* @code | ||
* @article{LeCun1998, | ||
* author = {Yann LeCun, Leon Bottou, Yoshua Bengio, Pattrick Haffner}, | ||
* title = {Gradient Based Learning Applied to Document Recognizition}, | ||
* journal = {IEEE}, | ||
* year = {1998}, | ||
* url = {http://yann.lecun.com/exdb/publis/pdf/lecun-01a.pdf} | ||
* } | ||
* @endcode | ||
* | ||
* mlpack is free software; you may redistribute it and/or modify it under the | ||
* terms of the 3-clause BSD license. You should have received a copy of the | ||
* 3-clause BSD license along with mlpack. If not, see | ||
* http://www.opensource.org/licenses/BSD-3-Clause for more information. | ||
*/ | ||
|
||
#ifndef MODELS_LENET_HPP | ||
#define MODELS_LENET_HPP | ||
|
||
#include <mlpack/core.hpp> | ||
#include <mlpack/methods/ann/layer/layer.hpp> | ||
#include <mlpack/methods/ann/ffn.hpp> | ||
#include <mlpack/methods/ann/layer/layer_types.hpp> | ||
#include <mlpack/methods/ann/init_rules/random_init.hpp> | ||
|
||
namespace mlpack { | ||
namespace ann /** Artificial Neural Network. */{ | ||
|
||
/** | ||
* Definition of a LeNet CNN. | ||
* | ||
* @tparam OutputLayerType The output layer type used to evaluate the network. | ||
* @tparam InitializationRuleType Rule used to initialize the weight matrix. | ||
* @tparam leNetVer Version of LeNet. | ||
*/ | ||
template< | ||
typename OutputLayerType = NegativeLogLikelihood<>, | ||
typename InitializationRuleType = RandomInitialization, | ||
size_t leNetVer = 1 | ||
> | ||
class LeNet | ||
{ | ||
public: | ||
//! Create the LeNet object. | ||
LeNet(); | ||
|
||
/** | ||
* LeNet constructor intializes input shape and number of classes. | ||
* | ||
* @param inputChannels Number of input channels of the input image. | ||
* @param inputWidth Width of the input image. | ||
* @param inputHeight Height of the input image. | ||
* @param numClasses Optional number of classes to classify images into, | ||
* only to be specified if includeTop is true. | ||
* @param weights One of 'none', 'mnist'(pre-training on mnist) or path to weights. | ||
*/ | ||
LeNet(const size_t inputChannel, | ||
const size_t inputWidth, | ||
const size_t inputHeight, | ||
const size_t numClasses = 1000, | ||
const std::string& weights = "none"); | ||
|
||
/** | ||
* LeNet constructor intializes input shape and number of classes. | ||
* | ||
* @param inputShape A three-valued tuple indicating input shape. | ||
* First value is number of Channels (Channels-First). | ||
* Second value is input height. | ||
* Third value is input width.. | ||
* @param numClasses Optional number of classes to classify images into, | ||
* only to be specified if includeTop is true. | ||
* @param weights One of 'none', 'mnist'(pre-training on MNIST) or path to weights. | ||
*/ | ||
LeNet(const std::tuple<size_t, size_t, size_t> inputShape, | ||
const size_t numClasses = 1000, | ||
const std::string& weights = "none"); | ||
|
||
//! Get Layers of the model. | ||
FFN<OutputLayerType, InitializationRuleType>& GetModel() { return leNet; } | ||
|
||
// Returns the model as a sequential layer. | ||
Sequential<> AsSequential(); | ||
|
||
//! Load weights into the model. | ||
void LoadModel(const std::string& filePath); | ||
|
||
//! Save weights for the model. | ||
void SaveModel(const std::string& filePath); | ||
|
||
private: | ||
/** | ||
* Adds Convolution Block. | ||
* | ||
* @param inSize Number of input maps. | ||
* @param outSize Number of output maps. | ||
* @param kernelWidth Width of the filter/kernel. | ||
* @param kernelHeight Height of the filter/kernel. | ||
* @param strideWidth Stride of filter application in the x direction. | ||
* @param strideHeight Stride of filter application in the y direction. | ||
* @param padW Padding width of the input. | ||
* @param padH Padding height of the input. | ||
*/ | ||
void ConvolutionBlock(const size_t inSize, | ||
const size_t outSize, | ||
const size_t kernelWidth, | ||
const size_t kernelHeight, | ||
const size_t strideWidth = 1, | ||
const size_t strideHeight = 1, | ||
const size_t padW = 0, | ||
const size_t padH = 0) | ||
{ | ||
leNet.Add<Convolution<>>(inSize, outSize, kernelWidth, | ||
kernelHeight, strideWidth, strideHeight, padW, padH, inputWidth, | ||
inputHeight); | ||
leNet.Add<LeakyReLU<>>(); | ||
|
||
// Update inputWidth and input Height. | ||
inputWidth = ConvOutSize(inputWidth, kernelWidth, strideWidth, padW); | ||
inputHeight = ConvOutSize(inputHeight, kernelHeight, strideHeight, padH); | ||
return; | ||
} | ||
|
||
/** | ||
* Adds Pooling Block. | ||
* | ||
* @param kernelWidth Width of the filter/kernel. | ||
* @param kernelHeight Height of the filter/kernel. | ||
* @param strideWidth Stride of filter application in the x direction. | ||
* @param strideHeight Stride of filter application in the y direction. | ||
*/ | ||
void PoolingBlock(const size_t kernelWidth, | ||
const size_t kernelHeight, | ||
const size_t strideWidth = 1, | ||
const size_t strideHeight = 1) | ||
{ | ||
leNet.Add<MaxPooling<>>(kernelWidth, kernelHeight, | ||
strideWidth, strideHeight, true); | ||
// Update inputWidth and inputHeight. | ||
inputWidth = PoolOutSize(inputWidth, kernelWidth, strideWidth); | ||
inputHeight = PoolOutSize(inputHeight, kernelHeight, strideHeight); | ||
return; | ||
} | ||
|
||
/** | ||
* Return the convolution output size. | ||
* | ||
* @param size The size of the input (row or column). | ||
* @param k The size of the filter (width or height). | ||
* @param s The stride size (x or y direction). | ||
* @param padding The size of the padding (width or height) on one side. | ||
* @return The convolution output size. | ||
*/ | ||
size_t ConvOutSize(const size_t size, | ||
const size_t k, | ||
const size_t s, | ||
const size_t padding) | ||
{ | ||
return std::floor(size + 2 * padding - k) / s + 1; | ||
} | ||
|
||
/** | ||
* Return the convolution output size. | ||
* | ||
* @param size The size of the input (row or column). | ||
* @param k The size of the filter (width or height). | ||
* @param s The stride size (x or y direction). | ||
* @return The convolution output size. | ||
*/ | ||
size_t PoolOutSize(const size_t size, | ||
const size_t k, | ||
const size_t s) | ||
{ | ||
return std::floor(size - 1) / s + 1; | ||
} | ||
|
||
//! Locally stored LeNet Model. | ||
FFN<OutputLayerType, InitializationRuleType> leNet; | ||
|
||
//! Locally stored width of the image. | ||
size_t inputWidth; | ||
|
||
//! Locally stored height of the image. | ||
size_t inputHeight; | ||
|
||
//! Locally stored number of channels in the image. | ||
size_t inputChannel; | ||
|
||
//! Locally stored number of output classes. | ||
size_t numClasses; | ||
|
||
//! Locally stored type of pre-trained weights. | ||
std::string weights; | ||
}; // class LeNet | ||
|
||
} // namespace ann | ||
} // namespace mlpack | ||
|
||
#include "lenet_impl.hpp" // Include implementation. | ||
|
||
#endif |
Oops, something went wrong.