-
Notifications
You must be signed in to change notification settings - Fork 41
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add basic definition of augmentation class
- Loading branch information
1 parent
6ae432e
commit e51aaa9
Showing
5 changed files
with
180 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
cmake_minimum_required(VERSION 3.1.0 FATAL_ERROR) | ||
project(augmentation) | ||
|
||
option(DEBUG "DEBUG" OFF) | ||
|
||
set(DIR_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/) | ||
include_directories("${CMAKE_CURRENT_SOURCE_DIR}/../../") | ||
|
||
set(SOURCES | ||
augmentation.hpp | ||
) | ||
|
||
foreach(file ${SOURCES}) | ||
set(DIR_SRCS ${DIR_SRCS} ${CMAKE_CURRENT_SOURCE_DIR}/${file}) | ||
endforeach() | ||
|
||
# Append sources (with directory name) to list of all models sources (used at | ||
# the parent scope). | ||
set(DIRS ${DIRS} ${DIR_SRCS} PARENT_SCOPE) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,130 @@ | ||
/** | ||
* @file augmentation.hpp | ||
* @author Kartik Dutt | ||
* | ||
* Definition of Augmentation class for augmenting data. | ||
* | ||
* mlpack is free software; you may redistribute it and/or modify it under the | ||
* terms of the 3-clause BSD license. You should have received a copy of the | ||
* 3-clause BSD license along with mlpack. If not, see | ||
* http://www.opensource.org/licenses/BSD-3-Clause for more information. | ||
*/ | ||
|
||
#include <mlpack/methods/ann/layer/bilinear_interpolation.hpp> | ||
#include <boost/regex.hpp> | ||
|
||
#ifndef MODELS_AUGMENTATION_HPP | ||
#define MODELS_AUGMENTATION_HPP | ||
|
||
/** | ||
* Augmentation class used to perform augmentations / transform the data. | ||
* For the list of supported augmentation, take a look at our wiki page. | ||
* | ||
* @code | ||
* Augmentation<> augmentation({"horizontal-flip", "resize = (224, 224)"}, 0.2); | ||
* augmentation.Transform(dataloader.TrainFeatures); | ||
* @endcode | ||
* | ||
* @tparam DatasetX Datatype on which augmentation will be done. | ||
*/ | ||
class Augmentation | ||
{ | ||
public: | ||
//! Create the augmenation class object. | ||
Augmentation(); | ||
|
||
/** | ||
* Constructor for augmentation class. | ||
* | ||
* @param augmentation List of strings containing one of the supported | ||
* augmentation. | ||
* @param augmentationProbability Probability of applying augmentation on | ||
* the dataset. | ||
* NOTE : This doesn't apply to augmentations | ||
* such as resize. | ||
* @param batches Boolean to determine if input is a single data point or | ||
* a batch. Defaults to true. | ||
* NOTE : If true, each data point must be represented as a | ||
* seperate column. | ||
*/ | ||
Augmentation(const std::vector<std::string>& augmentation, | ||
const double augmentationProbability); | ||
|
||
/** | ||
*/ | ||
template<typename DatasetType = arma::mat> | ||
void Transform(DatasetType& dataset); | ||
|
||
template<typename DatasetType = arma::mat> | ||
void ResizeTransform(DatasetType& dataset); | ||
|
||
template <typename DatasetType = arma::mat> | ||
void HorizontalFlipTransform(DatasetType &dataset); | ||
|
||
template<typename DatasetType = arma::mat> | ||
void VerticalFlipTransform(DatasetType& dataset); | ||
|
||
|
||
private: | ||
/** | ||
* Function to determine if augmentation has Resize function. | ||
*/ | ||
bool HasResizeParam() | ||
{ | ||
return augmentations.size() <= 0 ? false : | ||
augmentations[0].find("resize") != std::string::npos ; | ||
} | ||
|
||
/** | ||
* Sets size of output width and output height of the new data. | ||
* | ||
* @param outWidth Output width of resized data point. | ||
* @param outHeight Output height of resized data point. | ||
*/ | ||
void GetResizeParam(size_t& outWidth, size_t& outHeight) | ||
{ | ||
if (!HasResizeParam()) | ||
{ | ||
return; | ||
} | ||
|
||
outWidth = -1; | ||
outHeight = -1; | ||
|
||
// Use regex to find one / two numbers. If only one provided | ||
// set output width equal to output height. | ||
boost::regex regex{"[0-9]+"}; | ||
|
||
// Create an iterator to find matches. | ||
boost::sregex_token_iterator matches(augmentations[0].begin(), | ||
augmentations[0].end(), regex, 0), end; | ||
|
||
size_t matchesCount = std::distance(matches, end); | ||
|
||
if (matchesCount == 0) | ||
{ | ||
mlpack::Log::Fatal << "Invalid size / shape in " << | ||
augmentations[0] << std::endl; | ||
} | ||
|
||
if (matchesCount == 1) | ||
{ | ||
outWidth = std::stoi(*matches); | ||
outHeight = outWidth; | ||
} | ||
else | ||
{ | ||
outWidth = std::stoi(*matches); | ||
matches++; | ||
outHeight = std::stoi(*matches); | ||
} | ||
} | ||
|
||
//! Locally held augmentations / transforms that need to be applied. | ||
std::vector<std::string> augmentations; | ||
|
||
//! Locally held value of augmentation probability. | ||
double augmentationProbability; | ||
}; | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
/** | ||
* @file augmentation.cpp | ||
* @author Kartik Dutt | ||
* | ||
* Tests for various functionalities of utils. | ||
* | ||
* mlpack is free software; you may redistribute it and/or modify it under the | ||
* terms of the 3-clause BSD license. You should have received a copy of the | ||
* 3-clause BSD license along with mlpack. If not, see | ||
* http://www.opensource.org/licenses/BSD-3-Clause for more information. | ||
*/ | ||
#define BOOST_TEST_DYN_LINK | ||
#include <boost/regex.hpp> | ||
#include <boost/test/unit_test.hpp> | ||
using namespace boost::unit_test; | ||
|
||
BOOST_AUTO_TEST_SUITE(AugmentationTest); | ||
|
||
BOOST_AUTO_TEST_CASE(REGEXTest) | ||
{ | ||
std::string s = " resize = { 19, 112 }, resize : 133,442, resize = [12 213]"; | ||
boost::regex expr{"[0-9]+"}; | ||
boost::smatch what; | ||
boost::sregex_token_iterator iter(s.begin(), s.end(), expr, 0); | ||
boost::sregex_token_iterator end; | ||
} | ||
|
||
BOOST_AUTO_TEST_SUITE_END(); |