Skip to content

Commit

Permalink
Add tests for augmentation
Browse files Browse the repository at this point in the history
  • Loading branch information
kartikdutt18 committed Jun 1, 2020
1 parent 2727851 commit 7b3c433
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 15 deletions.
20 changes: 15 additions & 5 deletions augmentation/augmentation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,16 @@ class Augmentation

/**
* Function to determine if augmentation has Resize function.
* @param augmentation Optional argument to check if a string has
* resize substring.
*/
bool HasResizeParam()
bool HasResizeParam(const std::string& augmentation = "")
{
if (augmentation.length())
{
return augmentation.find("resize") != std::string::npos;
}

// Search in augmentation vector.
return augmentations.size() <= 0 ? false :
augmentations[0].find("resize") != std::string::npos;
Expand All @@ -102,9 +109,12 @@ class Augmentation
*
* @param outWidth Output width of resized data point.
* @param outHeight Output height of resized data point.
* @param augmentation String from which output width and height
* are extracted.
*/
void GetResizeParam(size_t& outWidth,
size_t& outHeight)
size_t& outHeight,
const std::string& augmentation)
{
if (!HasResizeParam())
{
Expand All @@ -119,15 +129,15 @@ class Augmentation
boost::regex regex{"[0-9]+"};

// Create an iterator to find matches.
boost::sregex_token_iterator matches(augmentations[0].begin(),
augmentations[0].end(), regex, 0), end;
boost::sregex_token_iterator matches(augmentation.begin(),
augmentation.end(), regex, 0), end;

size_t matchesCount = std::distance(matches, end);

if (matchesCount == 0)
{
mlpack::Log::Fatal << "Invalid size / shape in " <<
augmentations[0] << std::endl;
augmentation << std::endl;
}

if (matchesCount == 1)
Expand Down
10 changes: 8 additions & 2 deletions augmentation/augmentation_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ void Augmentation<DatasetType>::Transform(DatasetType& dataset,
size_t i = 0;
if (this->HasResizeParam())
{
this->ResizeTransform(dataset);
this->ResizeTransform(dataset, datapointWidth, datapointHeight,
datapointDepth, augmentations[0]);
i++;
}

Expand All @@ -73,10 +74,15 @@ void Augmentation<DatasetType>::ResizeTransform(
const size_t datapointDepth,
const std::string& augmentation)
{
if (!this->HasResizeParam(augmentation))
{
return;
}

size_t outputWidth = 0, outputHeight = 0;

// Get output width and output height.
GetResizeParam(outputWidth, outputHeight);
GetResizeParam(outputWidth, outputHeight, augmentation);

// We will use mlpack's bilinear interpolation layer to
// resize the input.
Expand Down
43 changes: 35 additions & 8 deletions tests/augmentation_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,46 @@
#define BOOST_TEST_DYN_LINK
#include <boost/regex.hpp>
#include <boost/test/unit_test.hpp>
#include <augmentation/augmentation.hpp>
using namespace boost::unit_test;

BOOST_AUTO_TEST_SUITE(AugmentationTest);

BOOST_AUTO_TEST_CASE(REGEXTest)
BOOST_AUTO_TEST_CASE(ResizeAugmentationTest)
{
// Some accepted formats.
std::string s = " resize = { 19, 112 }, \
resize : 133, 442, resize = [12 213]";
boost::regex expr{"[0-9]+"};
boost::smatch what;
boost::sregex_token_iterator iter(s.begin(), s.end(), expr, 0);
boost::sregex_token_iterator end;
Augmentation<> augmentation(std::vector<std::string>(1, "resize (5, 4)"), 0.2);

// Test on a square matrix.
arma::mat input;
size_t inputWidth = 2;
size_t inputHeight = 2;
size_t depth = 1;
input.zeros(inputWidth * inputHeight * depth, 2);

// Resize function called.
augmentation.Transform(input, inputWidth, inputHeight, depth);

// Check correctness of input.
BOOST_REQUIRE_EQUAL(input.n_cols, 2);
BOOST_REQUIRE_EQUAL(input.n_rows, 5 * 4);

// Test on rectangular matrix.
inputWidth = 5;
inputHeight = 7;
depth = 1;
input.zeros(inputWidth * inputHeight * depth, 2);

// Rectangular input to sqaure output.
std::vector<std::string> augmentationVector = {"horizontal-flip",
"resize : 8"};
Augmentation<> augmentation2(augmentationVector, 0.2);

// Resize function called.
augmentation2.Transform(input, inputWidth, inputHeight, depth);

// Check correctness of input.
BOOST_REQUIRE_EQUAL(input.n_cols, 2);
BOOST_REQUIRE_EQUAL(input.n_rows, 8 * 8);
}

BOOST_AUTO_TEST_SUITE_END();

0 comments on commit 7b3c433

Please sign in to comment.