-
Notifications
You must be signed in to change notification settings - Fork 41
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #20 from kartikdutt18/DarknetModel
Add Darknet model.
- Loading branch information
Showing
15 changed files
with
845 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,5 +13,6 @@ data/* | |
*.jpg | ||
*.png | ||
*.txt | ||
*.bin | ||
.travis/configs.hpp | ||
Testing/* |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
cmake_minimum_required(VERSION 3.1.0 FATAL_ERROR) | ||
project(ensmallen_utils) | ||
|
||
set(DIR_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/) | ||
|
||
set(SOURCES | ||
print_metric.hpp | ||
periodic_save.hpp) | ||
|
||
foreach(file ${SOURCES}) | ||
set(DIR_SRCS ${DIR_SRCS} ${CMAKE_CURRENT_SOURCE_DIR}/${file}) | ||
endforeach() | ||
|
||
# Append sources (with directory name) to list of all models sources (used at | ||
# the parent scope). | ||
set(DIRS ${DIRS} ${DIR_SRCS} PARENT_SCOPE) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
/** | ||
* @file utils.hpp | ||
* @author Kartik Dutt | ||
* | ||
* Definition of Periodic Save utility functions. | ||
* | ||
* mlpack is free software; you may redistribute it and/or modify it under the | ||
* terms of the 3-clause BSD license. You should have received a copy of the | ||
* 3-clause BSD license along with mlpack. If not, see | ||
* http://www.opensource.org/licenses/BSD-3-Clause for more information. | ||
*/ | ||
|
||
#ifndef ENSMALLEN_CALLBACKS_PERIODIC_SAVE_HPP | ||
#define ENSMALLEN_CALLBACKS_PERIODIC_SAVE_HPP | ||
|
||
#include <ensmallen.hpp> | ||
#include <mlpack/core.hpp> | ||
|
||
namespace ens { | ||
|
||
/** | ||
* Saves model being trained periodically. | ||
* | ||
* @tparam ANNType Type of model which will be used for evaluating metric. | ||
*/ | ||
template<typename AnnType> | ||
class PeriodicSave | ||
{ | ||
public: | ||
/** | ||
* Constructor for PeriodicSave class. | ||
* | ||
* @param network Network type which will be saved periodically. | ||
* @param filePath Base path / folder where weights will be saved. | ||
* @param modelPrefix Weights will be stored as | ||
* modelPrefix_epoch_loss.bin. | ||
* @param period Period after which the model will be saved. | ||
* @param silent Boolean to determine whether or not to print saving | ||
* of model. | ||
* @param output Outputstream where output will be directed. | ||
*/ | ||
PeriodicSave(AnnType& network, | ||
const std::string filePath = "./", | ||
const std::string modelPrefix = "model", | ||
const size_t period = 1, | ||
const bool silent = false, | ||
std::ostream& output = arma::get_cout_stream()) : | ||
network(network), | ||
filePath(filePath), | ||
modelPrefix(modelPrefix), | ||
period(period), | ||
silent(silent), | ||
output(output) | ||
{ | ||
// Nothing to do here. | ||
} | ||
|
||
template<typename OptimizerType, typename FunctionType, typename MatType> | ||
bool EndEpoch(OptimizerType& /* optimizer */, | ||
FunctionType& /* function */, | ||
const MatType& /* coordinates */, | ||
const size_t epoch, | ||
const double objective) | ||
{ | ||
if (epoch % period == 0) | ||
{ | ||
std::string objectiveString = std::to_string(objective); | ||
std::replace(objectiveString.begin(), objectiveString.end(), '.', '_'); | ||
std::string modelName = modelPrefix + "_" + std::to_string(epoch) + "_" + | ||
objectiveString; | ||
mlpack::data::Save(filePath + modelName + ".bin", modelPrefix, network); | ||
if (!silent) | ||
output << "Model saved as " << modelName << std::endl; | ||
} | ||
|
||
return false; | ||
} | ||
|
||
private: | ||
// Reference to the model which will be used for evaluated using the metric. | ||
AnnType& network; | ||
|
||
// Locally held string that depicts path for saving the model. | ||
std::string filePath; | ||
|
||
// Locally held string that depicts the prefix name of model being trained. | ||
std::string modelPrefix; | ||
|
||
// Period to save the model. | ||
size_t period; | ||
|
||
// Locally held boolean to determine whether to print success / failure output | ||
// when model is saved. | ||
bool silent; | ||
|
||
// The output stream that all data is to be sent to; example: std::cout. | ||
std::ostream& output; | ||
}; | ||
|
||
} // namespace ens | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
/** | ||
* @file utils.hpp | ||
* @author Kartik Dutt | ||
* | ||
* Definition of PrintMetric class. | ||
* | ||
* mlpack is free software; you may redistribute it and/or modify it under the | ||
* terms of the 3-clause BSD license. You should have received a copy of the | ||
* 3-clause BSD license along with mlpack. If not, see | ||
* http://www.opensource.org/licenses/BSD-3-Clause for more information. | ||
*/ | ||
|
||
#ifndef ENSMALLEN_CALLBACKS_PRINT_METRIC_HPP | ||
#define ENSMALLEN_CALLBACKS_PRINT_METRIC_HPP | ||
|
||
#include <ensmallen.hpp> | ||
#include <functional> | ||
|
||
namespace ens { | ||
|
||
/** | ||
* Prints metric on training / validation set. | ||
* | ||
* @tparam ANNType Type of model which will be used for evaluating metric. | ||
* @tparam MetricType Metric class which must have static `Evaluate` function | ||
* that will be called at the end of the epoch. | ||
* @tparam InputType Arma type of dataset features. | ||
* @tparam OutputType Arma type of dataset labels. | ||
*/ | ||
template<typename AnnType, | ||
class MetricType, | ||
typename InputType = arma::mat, | ||
typename OutputType = arma::mat | ||
> | ||
class PrintMetric | ||
{ | ||
public: | ||
/** | ||
* Constructor for PrintMetric class. | ||
* @param network Network type which will be saved periodically. | ||
* @param features Input features on which model will be evaluated. | ||
* @param responses Ground truth label for the mdoel. | ||
* @param metricName Metric name which will be printed after each epoch. | ||
* @param trainData Boolean to determine whether dataset corresponds to | ||
* training data or validation data. | ||
* @param output Outputstream where output will be directed. | ||
*/ | ||
PrintMetric(AnnType &network, | ||
const InputType &features, | ||
const OutputType &responses, | ||
const std::string metricName = "metric", | ||
const bool trainData = false, | ||
std::ostream &output = arma::get_cout_stream()) : | ||
network(network), | ||
features(features), | ||
responses(responses), | ||
metricName(metricName), | ||
trainData(trainData), | ||
output(output) | ||
{ | ||
// Nothing to do here. | ||
} | ||
|
||
template<typename OptimizerType, typename FunctionType, typename MatType> | ||
bool EndEpoch(OptimizerType& /* optimizer */, | ||
FunctionType& /* function */, | ||
const MatType& /* coordinates */, | ||
const size_t /* epoch */, | ||
const double /* objective */) | ||
{ | ||
OutputType predictions; | ||
network.Predict(features, predictions); | ||
const double localObjective = MetricType::Evaluate(predictions, responses); | ||
if (!std::isnan(localObjective)) | ||
{ | ||
std::string outputString = (trainData == true) ? "Train " : "Validation "; | ||
outputString = outputString + metricName + " : " + | ||
std::to_string(localObjective); | ||
output << outputString << std::endl; | ||
} | ||
return false; | ||
} | ||
|
||
private: | ||
// Reference to the model which will be used for evaluated using the metric. | ||
AnnType& network; | ||
|
||
// Dataset which will be used for evaluating the metric. | ||
InputType features; | ||
|
||
// Dataset labels / predictions that will be used for evaluating the dataset. | ||
OutputType responses; | ||
|
||
// Locally held string that depicts the name of the metric. | ||
std::string metricName; | ||
|
||
// Locally held boolean to determin whether evaluation is done on train data | ||
// or validation data. | ||
bool trainData; | ||
|
||
// The output stream that all data is to be sent to; example: std::cout. | ||
std::ostream& output; | ||
}; | ||
|
||
} // namespace ens | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
cmake_minimum_required(VERSION 3.1.0 FATAL_ERROR) | ||
project(models) | ||
|
||
add_subdirectory(darknet) | ||
|
||
# Add directory name to sources. | ||
set(DIR_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/) | ||
|
||
foreach(file ${SOURCES}) | ||
set(DIR_SRCS ${DIR_SRCS} ${CMAKE_CURRENT_SOURCE_DIR}/${file}) | ||
endforeach() | ||
|
||
set(DIR_SRCS ${DIR_SRCS} ${CMAKE_CURRENT_SOURCE_DIR}/layer.hpp) | ||
# Append sources (with directory name) to list of all models sources (used at | ||
# the parent scope). | ||
set(DIRS ${DIRS} ${DIR_SRCS} PARENT_SCOPE) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
cmake_minimum_required(VERSION 3.1.0 FATAL_ERROR) | ||
project(darknet) | ||
|
||
set(DIR_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/) | ||
include_directories("${CMAKE_CURRENT_SOURCE_DIR}/../../") | ||
|
||
set(SOURCES | ||
darknet.hpp | ||
darknet_impl.hpp | ||
) | ||
|
||
foreach(file ${SOURCES}) | ||
set(DIR_SRCS ${DIR_SRCS} ${CMAKE_CURRENT_SOURCE_DIR}/${file}) | ||
endforeach() | ||
|
||
set(DIRS ${DIRS} ${DIR_SRCS} PARENT_SCOPE) |
Oops, something went wrong.