From e51aaa97db5f0b7f07db540e427c9dc4e1351792 Mon Sep 17 00:00:00 2001 From: kartikdutt18 Date: Sat, 30 May 2020 15:50:33 +0530 Subject: [PATCH] Add basic definition of augmentation class --- CMakeLists.txt | 1 + augmentation/CMakeLists.txt | 19 +++++ augmentation/augmentation.hpp | 130 ++++++++++++++++++++++++++++++++++ tests/CMakeLists.txt | 2 + tests/augmentation_tests.cpp | 28 ++++++++ 5 files changed, 180 insertions(+) create mode 100644 augmentation/CMakeLists.txt create mode 100644 augmentation/augmentation.hpp create mode 100644 tests/augmentation_tests.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 8cb103c2..d5a32184 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -164,6 +164,7 @@ find_package(Boost 1.49 COMPONENTS filesystem system + regex program_options serialization unit_test_framework diff --git a/augmentation/CMakeLists.txt b/augmentation/CMakeLists.txt new file mode 100644 index 00000000..40510d03 --- /dev/null +++ b/augmentation/CMakeLists.txt @@ -0,0 +1,19 @@ +cmake_minimum_required(VERSION 3.1.0 FATAL_ERROR) +project(augmentation) + +option(DEBUG "DEBUG" OFF) + +set(DIR_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/) +include_directories("${CMAKE_CURRENT_SOURCE_DIR}/../../") + +set(SOURCES + augmentation.hpp +) + +foreach(file ${SOURCES}) + set(DIR_SRCS ${DIR_SRCS} ${CMAKE_CURRENT_SOURCE_DIR}/${file}) +endforeach() + +# Append sources (with directory name) to list of all models sources (used at +# the parent scope). +set(DIRS ${DIRS} ${DIR_SRCS} PARENT_SCOPE) diff --git a/augmentation/augmentation.hpp b/augmentation/augmentation.hpp new file mode 100644 index 00000000..78819c1a --- /dev/null +++ b/augmentation/augmentation.hpp @@ -0,0 +1,130 @@ +/** + * @file augmentation.hpp + * @author Kartik Dutt + * + * Definition of Augmentation class for augmenting data. + * + * mlpack is free software; you may redistribute it and/or modify it under the + * terms of the 3-clause BSD license. 
You should have received a copy of the + * 3-clause BSD license along with mlpack. If not, see + * http://www.opensource.org/licenses/BSD-3-Clause for more information. + */ + +#include +#include + +#ifndef MODELS_AUGMENTATION_HPP +#define MODELS_AUGMENTATION_HPP + +/** + * Augmentation class used to perform augmentations / transform the data. + * For the list of supported augmentations, take a look at our wiki page. + * + * @code + * Augmentation augmentation({"horizontal-flip", "resize = (224, 224)"}, 0.2); + * augmentation.Transform(dataloader.TrainFeatures); + * @endcode + * + * @note Augmentation itself is not a template; the transform methods below are declared as templates taking a DatasetType&. + */ +class Augmentation +{ + public: + //! Create the augmentation class object. + Augmentation(); + + /** + * Constructor for augmentation class. + * + * @param augmentation List of strings containing one of the supported + * augmentation. + * @param augmentationProbability Probability of applying augmentation on + * the dataset. + * NOTE : This doesn't apply to augmentations + * such as resize. + * @param batches (NOTE(review): not present in the declaration below - confirm) Boolean to determine if input is a single data point or + * a batch. Defaults to true. + * NOTE : If true, each data point must be represented as a + * separate column. + */ + Augmentation(const std::vector& augmentation, + const double augmentationProbability); + + /** Declared only: no definition of Transform is part of this patch. TODO(review): document its contract once it is implemented. + */ + template + void Transform(DatasetType& dataset); + + template + void ResizeTransform(DatasetType& dataset); + + template + void HorizontalFlipTransform(DatasetType &dataset); + + template + void VerticalFlipTransform(DatasetType& dataset); + + + private: + /** + * Returns true when the first entry of augmentations contains the substring "resize"; returns false when augmentations is empty. Entries after the first are not checked. + */ + bool HasResizeParam() + { + return augmentations.size() <= 0 ? false : + augmentations[0].find("resize") != std::string::npos ; + } + + /** + * Sets size of output width and output height of the new data. + * Leaves both outputs untouched when HasResizeParam() is false; otherwise parses the digit runs in augmentations[0]. NOTE(review): both outputs are first assigned -1, which converts to the largest size_t value - confirm this sentinel is intended. + * @param outWidth Output width of resized data point. 
+ * @param outHeight Output height of resized data point. + */ + void GetResizeParam(size_t& outWidth, size_t& outHeight) + { + if (!HasResizeParam()) + { + return; + } + + outWidth = -1; + outHeight = -1; + + // Use regex to find one / two numbers. If only one provided + // set output width equal to output height. NOTE(review): with zero matches this relies on mlpack::Log::Fatal stopping execution; otherwise the final else-branch would dereference an end iterator - confirm. + boost::regex regex{"[0-9]+"}; + + // Create an iterator to find matches. + boost::sregex_token_iterator matches(augmentations[0].begin(), + augmentations[0].end(), regex, 0), end; + + size_t matchesCount = std::distance(matches, end); + + if (matchesCount == 0) + { + mlpack::Log::Fatal << "Invalid size / shape in " << + augmentations[0] << std::endl; + } + + if (matchesCount == 1) + { + outWidth = std::stoi(*matches); + outHeight = outWidth; + } + else + { + outWidth = std::stoi(*matches); + matches++; + outHeight = std::stoi(*matches); + } + } + + //! Locally held augmentations / transforms that need to be applied. + std::vector augmentations; + + //! Locally held value of augmentation probability. + double augmentationProbability; +}; + +#endif \ No newline at end of file diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 4414aa31..ded2245a 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -7,6 +7,7 @@ set(MODEL_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/) include_directories("${CMAKE_CURRENT_SOURCE_DIR}/../") add_executable(models_test + augmentation_tests.cpp dataloader_tests.cpp utils_tests.cpp ) @@ -19,6 +20,7 @@ target_link_libraries(models_test ${Boost_UNIT_TEST_FRAMEWORK_LIBRARY} ${Boost_SYSTEM_LIBRARY} ${Boost_SERIALIZATION_LIBRARY} + ${Boost_REGEX_LIBRARY} ${MLPACK_LIBRARIES} ) diff --git a/tests/augmentation_tests.cpp b/tests/augmentation_tests.cpp new file mode 100644 index 00000000..2d10a218 --- /dev/null +++ b/tests/augmentation_tests.cpp @@ -0,0 +1,28 @@ +/** + * @file augmentation_tests.cpp + * @author Kartik Dutt + * + * Tests for various functionalities of the augmentation class. 
+ * + * mlpack is free software; you may redistribute it and/or modify it under the + * terms of the 3-clause BSD license. You should have received a copy of the + * 3-clause BSD license along with mlpack. If not, see + * http://www.opensource.org/licenses/BSD-3-Clause for more information. + */ +#define BOOST_TEST_DYN_LINK +#include +#include +using namespace boost::unit_test; + +BOOST_AUTO_TEST_SUITE(AugmentationTest); + +BOOST_AUTO_TEST_CASE(REGEXTest) // Builds a token iterator over the digit runs in s; no assertions are made yet. +{ + std::string s = " resize = { 19, 112 }, resize : 133,442, resize = [12 213]"; + boost::regex expr{"[0-9]+"}; + boost::smatch what; // Not used by the statements below. + boost::sregex_token_iterator iter(s.begin(), s.end(), expr, 0); + boost::sregex_token_iterator end; +} + +BOOST_AUTO_TEST_SUITE_END();