Skip to content

Commit

Permalink
Add basic definition of augmentation class
Browse files Browse the repository at this point in the history
  • Loading branch information
kartikdutt18 committed May 30, 2020
1 parent 6ae432e commit e51aaa9
Show file tree
Hide file tree
Showing 5 changed files with 180 additions and 0 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ find_package(Boost 1.49
COMPONENTS
filesystem
system
regex
program_options
serialization
unit_test_framework
Expand Down
19 changes: 19 additions & 0 deletions augmentation/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
cmake_minimum_required(VERSION 3.1.0 FATAL_ERROR)
project(augmentation)

option(DEBUG "DEBUG" OFF)

set(DIR_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/)
include_directories("${CMAKE_CURRENT_SOURCE_DIR}/../../")

set(SOURCES
augmentation.hpp
)

foreach(file ${SOURCES})
set(DIR_SRCS ${DIR_SRCS} ${CMAKE_CURRENT_SOURCE_DIR}/${file})
endforeach()

# Append sources (with directory name) to list of all models sources (used at
# the parent scope).
set(DIRS ${DIRS} ${DIR_SRCS} PARENT_SCOPE)
130 changes: 130 additions & 0 deletions augmentation/augmentation.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
/**
* @file augmentation.hpp
* @author Kartik Dutt
*
* Definition of Augmentation class for augmenting data.
*
* mlpack is free software; you may redistribute it and/or modify it under the
* terms of the 3-clause BSD license. You should have received a copy of the
* 3-clause BSD license along with mlpack. If not, see
* http://www.opensource.org/licenses/BSD-3-Clause for more information.
*/

#include <mlpack/methods/ann/layer/bilinear_interpolation.hpp>
#include <boost/regex.hpp>

#ifndef MODELS_AUGMENTATION_HPP
#define MODELS_AUGMENTATION_HPP

/**
* Augmentation class used to perform augmentations / transform the data.
* For the list of supported augmentation, take a look at our wiki page.
*
* @code
* Augmentation<> augmentation({"horizontal-flip", "resize = (224, 224)"}, 0.2);
* augmentation.Transform(dataloader.TrainFeatures);
* @endcode
*
* @tparam DatasetX Datatype on which augmentation will be done.
*/
class Augmentation
{
public:
//! Create the augmenation class object.
Augmentation();

/**
* Constructor for augmentation class.
*
* @param augmentation List of strings containing one of the supported
* augmentation.
* @param augmentationProbability Probability of applying augmentation on
* the dataset.
* NOTE : This doesn't apply to augmentations
* such as resize.
* @param batches Boolean to determine if input is a single data point or
* a batch. Defaults to true.
* NOTE : If true, each data point must be represented as a
* seperate column.
*/
Augmentation(const std::vector<std::string>& augmentation,
const double augmentationProbability);

/**
*/
template<typename DatasetType = arma::mat>
void Transform(DatasetType& dataset);

template<typename DatasetType = arma::mat>
void ResizeTransform(DatasetType& dataset);

template <typename DatasetType = arma::mat>
void HorizontalFlipTransform(DatasetType &dataset);

template<typename DatasetType = arma::mat>
void VerticalFlipTransform(DatasetType& dataset);


private:
/**
* Function to determine if augmentation has Resize function.
*/
bool HasResizeParam()
{
return augmentations.size() <= 0 ? false :
augmentations[0].find("resize") != std::string::npos ;
}

/**
* Sets size of output width and output height of the new data.
*
* @param outWidth Output width of resized data point.
* @param outHeight Output height of resized data point.
*/
void GetResizeParam(size_t& outWidth, size_t& outHeight)
{
if (!HasResizeParam())
{
return;
}

outWidth = -1;
outHeight = -1;

// Use regex to find one / two numbers. If only one provided
// set output width equal to output height.
boost::regex regex{"[0-9]+"};

// Create an iterator to find matches.
boost::sregex_token_iterator matches(augmentations[0].begin(),
augmentations[0].end(), regex, 0), end;

size_t matchesCount = std::distance(matches, end);

if (matchesCount == 0)
{
mlpack::Log::Fatal << "Invalid size / shape in " <<
augmentations[0] << std::endl;
}

if (matchesCount == 1)
{
outWidth = std::stoi(*matches);
outHeight = outWidth;
}
else
{
outWidth = std::stoi(*matches);
matches++;
outHeight = std::stoi(*matches);
}
}

//! Locally held augmentations / transforms that need to be applied.
std::vector<std::string> augmentations;

//! Locally held value of augmentation probability.
double augmentationProbability;
};

#endif
2 changes: 2 additions & 0 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ set(MODEL_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/)
include_directories("${CMAKE_CURRENT_SOURCE_DIR}/../")

add_executable(models_test
augmentation_tests.cpp
dataloader_tests.cpp
utils_tests.cpp
)
Expand All @@ -19,6 +20,7 @@ target_link_libraries(models_test
${Boost_UNIT_TEST_FRAMEWORK_LIBRARY}
${Boost_SYSTEM_LIBRARY}
${Boost_SERIALIZATION_LIBRARY}
${Boost_REGEX_LIBRARY}
${MLPACK_LIBRARIES}
)

Expand Down
28 changes: 28 additions & 0 deletions tests/augmentation_tests.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/**
* @file augmentation.cpp
* @author Kartik Dutt
*
* Tests for various functionalities of utils.
*
* mlpack is free software; you may redistribute it and/or modify it under the
* terms of the 3-clause BSD license. You should have received a copy of the
* 3-clause BSD license along with mlpack. If not, see
* http://www.opensource.org/licenses/BSD-3-Clause for more information.
*/
#define BOOST_TEST_DYN_LINK
#include <boost/regex.hpp>
#include <boost/test/unit_test.hpp>
using namespace boost::unit_test;

BOOST_AUTO_TEST_SUITE(AugmentationTest);

BOOST_AUTO_TEST_CASE(REGEXTest)
{
std::string s = " resize = { 19, 112 }, resize : 133,442, resize = [12 213]";
boost::regex expr{"[0-9]+"};
boost::smatch what;
boost::sregex_token_iterator iter(s.begin(), s.end(), expr, 0);
boost::sregex_token_iterator end;
}

BOOST_AUTO_TEST_SUITE_END();

0 comments on commit e51aaa9

Please sign in to comment.