diff --git a/augmentation/augmentation.hpp b/augmentation/augmentation.hpp index 22fd8e90..eb67d36c 100644 --- a/augmentation/augmentation.hpp +++ b/augmentation/augmentation.hpp @@ -89,9 +89,16 @@ class Augmentation /** * Function to determine if augmentation has Resize function. + * @param augmentation Optional argument to check if a string has + * resize substring. */ - bool HasResizeParam() + bool HasResizeParam(const std::string& augmentation = "") { + if (augmentation.length()) + { + return augmentation.find("resize") != std::string::npos; + } + // Search in augmentation vector. return augmentations.size() <= 0 ? false : augmentations[0].find("resize") != std::string::npos; @@ -102,9 +109,12 @@ class Augmentation * * @param outWidth Output width of resized data point. * @param outHeight Output height of resized data point. + * @param augmentation String from which output width and height + * are extracted. */ void GetResizeParam(size_t& outWidth, - size_t& outHeight) + size_t& outHeight, + const std::string& augmentation) { if (!HasResizeParam()) { @@ -119,15 +129,15 @@ class Augmentation boost::regex regex{"[0-9]+"}; // Create an iterator to find matches. - boost::sregex_token_iterator matches(augmentations[0].begin(), - augmentations[0].end(), regex, 0), end; + boost::sregex_token_iterator matches(augmentation.begin(), + augmentation.end(), regex, 0), end; size_t matchesCount = std::distance(matches, end); if (matchesCount == 0) { mlpack::Log::Fatal << "Invalid size / shape in " << - augmentations[0] << std::endl; + augmentation << std::endl; } if (matchesCount == 1) diff --git a/augmentation/augmentation_impl.hpp b/augmentation/augmentation_impl.hpp index 8a8b74d0..0c6d1ee8 100644 --- a/augmentation/augmentation_impl.hpp +++ b/augmentation/augmentation_impl.hpp @@ -51,7 +51,8 @@ void Augmentation::Transform(DatasetType& dataset, size_t i = 0; if (this->HasResizeParam()) { - this->ResizeTransform(dataset); + this->ResizeTransform(dataset, datapointWidth, datapointHeight, + datapointDepth, augmentations[0]); i++; } @@ -73,10 +74,15 @@ void Augmentation::ResizeTransform( const size_t datapointDepth, const std::string& augmentation) { + if (!this->HasResizeParam(augmentation)) + { + return; + } + size_t outputWidth = 0, outputHeight = 0; // Get output width and output height. - GetResizeParam(outputWidth, outputHeight); + GetResizeParam(outputWidth, outputHeight, augmentation); // We will use mlpack's bilinear interpolation layer to // resize the input. diff --git a/tests/augmentation_tests.cpp b/tests/augmentation_tests.cpp index 85b3f3aa..37325eed 100644 --- a/tests/augmentation_tests.cpp +++ b/tests/augmentation_tests.cpp @@ -12,19 +12,46 @@ #define BOOST_TEST_DYN_LINK #include #include +#include using namespace boost::unit_test; BOOST_AUTO_TEST_SUITE(AugmentationTest); -BOOST_AUTO_TEST_CASE(REGEXTest) +BOOST_AUTO_TEST_CASE(ResizeAugmentationTest) { - // Some accepted formats. - std::string s = " resize = { 19, 112 }, \ - resize : 133, 442, resize = [12 213]"; - boost::regex expr{"[0-9]+"}; - boost::smatch what; - boost::sregex_token_iterator iter(s.begin(), s.end(), expr, 0); - boost::sregex_token_iterator end; + Augmentation<> augmentation(std::vector(1, "resize (5, 4)"), 0.2); + + // Test on a square matrix. + arma::mat input; + size_t inputWidth = 2; + size_t inputHeight = 2; + size_t depth = 1; + input.zeros(inputWidth * inputHeight * depth, 2); + + // Resize function called. + augmentation.Transform(input, inputWidth, inputHeight, depth); + + // Check correctness of input. + BOOST_REQUIRE_EQUAL(input.n_cols, 2); + BOOST_REQUIRE_EQUAL(input.n_rows, 5 * 4); + + // Test on rectangular matrix. + inputWidth = 5; + inputHeight = 7; + depth = 1; + input.zeros(inputWidth * inputHeight * depth, 2); + + // Rectangular input to sqaure output. + std::vector augmentationVector = {"horizontal-flip", + "resize : 8"}; + Augmentation<> augmentation2(augmentationVector, 0.2); + + // Resize function called. + augmentation2.Transform(input, inputWidth, inputHeight, depth); + + // Check correctness of input. + BOOST_REQUIRE_EQUAL(input.n_cols, 2); + BOOST_REQUIRE_EQUAL(input.n_rows, 8 * 8); } BOOST_AUTO_TEST_SUITE_END();