diff --git a/augmentation/augmentation.hpp b/augmentation/augmentation.hpp index 247013c6..22fd8e90 100644 --- a/augmentation/augmentation.hpp +++ b/augmentation/augmentation.hpp @@ -25,8 +25,9 @@ * augmentation.Transform(dataloader.TrainFeatures); * @endcode * - * @tparam DatasetX Datatype on which augmentation will be done. + * @tparam DatasetType Datatype on which augmentation will be done. */ +template class Augmentation { public: @@ -47,28 +48,53 @@ class Augmentation const double augmentationProbability); /** + * Applies augmentation to the passed dataset. + * + * @param dataset Dataset on which augmentation will be applied. + * @param datapointWidth Width of a single data point, + * since each column represents a separate data + * point. + * @param datapointHeight Height of a single data point. + * @param datapointDepth Depth of a single data point. For 2-dimensional + * data point, set it to 1. Defaults to 1. */ - template - void Transform(DatasetType& dataset); - - template - void ResizeTransform(DatasetType& dataset); - - template - void HorizontalFlipTransform(DatasetType &dataset); - - template - void VerticalFlipTransform(DatasetType& dataset); + void Transform(DatasetType& dataset, + const size_t datapointWidth, + const size_t datapointHeight, + const size_t datapointDepth = 1); + /** + * Applies resize transform to the entire dataset. + * + * @param dataset Dataset on which augmentation will be applied. + * @param datapointWidth Width of a single data point, + * since each column represents a separate data + * point. + * @param datapointHeight Height of a single data point. + * @param datapointDepth Depth of a single data point. For 2-dimensional + * data point, set it to 1. Defaults to 1. + * @param augmentation String containing the transform. 
+ */ + void ResizeTransform(DatasetType& dataset, + const size_t datapointWidth, + const size_t datapointHeight, + const size_t datapointDepth, + const std::string& augmentation); private: + /** + * Initializes augmentation map for the class. + */ + void InitializeAugmentationMap(); + /** * Function to determine if augmentation has Resize function. */ bool HasResizeParam() { + // Search in augmentation vector. return augmentations.size() <= 0 ? false : - augmentations[0].find("resize") != std::string::npos ; + augmentations[0].find("resize") != std::string::npos; } /** @@ -77,15 +103,16 @@ class Augmentation * @param outWidth Output width of resized data point. * @param outHeight Output height of resized data point. */ - void GetResizeParam(size_t& outWidth, size_t& outHeight) + void GetResizeParam(size_t& outWidth, + size_t& outHeight) { if (!HasResizeParam()) { return; } - outWidth = -1; - outHeight = -1; + outWidth = 0; + outHeight = 0; // Use regex to find one / two numbers. If only one provided // set output width equal to output height. @@ -121,8 +148,17 @@ class Augmentation //! Locally held value of augmentation probability. double augmentationProbability; + + //! Locally held map for mapping functions and strings. + std::unordered_map augmentationMap; + + // The dataloader class should have access to internal functions of + // the augmentation class. + template + friend class DataLoader; }; #include "augmentation_impl.hpp" // Include implementation. -#endif \ No newline at end of file +#endif diff --git a/augmentation/augmentation_impl.hpp b/augmentation/augmentation_impl.hpp index 0efbfb87..8a8b74d0 100644 --- a/augmentation/augmentation_impl.hpp +++ b/augmentation/augmentation_impl.hpp @@ -15,17 +15,20 @@ #ifndef MODELS_AUGMENTATION_IMPL_HPP #define MODELS_AUGMENTATION_IMPL_HPP -Augmentation::Augmentation() : +template +Augmentation::Augmentation() : augmentations(std::vector()), augmentationProbability(0.2) { // Nothing to do here. 
} -Augmentation::Augmentation(const std::vector& augmentations, - const double augmentationProbability) : - augmentations(augmentations), - augmentationProbability(augmentationProbability) +template +Augmentation::Augmentation( + const std::vector& augmentations, + const double augmentationProbability) : + augmentations(augmentations), + augmentationProbability(augmentationProbability) { // Sort the vector to place resize parameter to the front of the string. // This prevents constant look ups for resize. @@ -34,6 +37,63 @@ Augmentation::Augmentation(const std::vector& augmentations, { return str1.find("resize") != std::string::npos; }); + + // Fill augmentation map with supported augmentations other than resize. + InitializeAugmentationMap(); +} + +template +void Augmentation::Transform(DatasetType& dataset, + const size_t datapointWidth, + const size_t datapointHeight, + const size_t datapointDepth) +{ + size_t i = 0; + if (this->HasResizeParam()) + { + this->ResizeTransform(dataset); + i++; + } + + for (; i < augmentations.size(); i++) + { + if (augmentationMap.count(augmentations[i])) + { + augmentationMap[augmentations[i]](dataset, datapointWidth, + datapointHeight, datapointDepth, augmentations[i]); + } + } +} + +template +void Augmentation::ResizeTransform( + DatasetType& dataset, + const size_t datapointWidth, + const size_t datapointHeight, + const size_t datapointDepth, + const std::string& augmentation) +{ + size_t outputWidth = 0, outputHeight = 0; + + // Get output width and output height. + GetResizeParam(outputWidth, outputHeight); + + // We will use mlpack's bilinear interpolation layer to + // resize the input. + mlpack::ann::BilinearInterpolation resizeLayer( + datapointWidth, datapointHeight, outputWidth, outputHeight, + datapointDepth); + + // Not sure how to avoid a copy here. 
+ DatasetType output; + resizeLayer.Forward(dataset, output); + dataset = output; +} + +template +void Augmentation::InitializeAugmentationMap() +{ + // Fill the map here. } -#endif \ No newline at end of file +#endif diff --git a/dataloader/dataloader.hpp b/dataloader/dataloader.hpp index 3a83d100..1abc1859 100644 --- a/dataloader/dataloader.hpp +++ b/dataloader/dataloader.hpp @@ -14,6 +14,7 @@ #include #include #include +#include #include #include #include @@ -158,6 +159,9 @@ class DataLoader void LoadObjectDetectionDataset(const std::string& pathToAnnotations, const std::string& pathToImages, const std::vector& classes, + const std::vector& augmentation = + std::vector(), + const double augmentationProbability = 0.2, const bool absolutePath = false, const std::string& baseXMLTag = "annotation", const std::string& imageNameXMLTag = diff --git a/dataloader/dataloader_impl.hpp b/dataloader/dataloader_impl.hpp index 99a19787..7b1c2636 100644 --- a/dataloader/dataloader_impl.hpp +++ b/dataloader/dataloader_impl.hpp @@ -153,6 +153,8 @@ template< >::LoadObjectDetectionDataset(const std::string& pathToAnnotations, const std::string& pathToImages, const std::vector& classes, + const std::vector& augmentations, + const double augmentationProbability, const bool absolutePath, const std::string& baseXMLTag, const std::string& imageNameXMLTag, @@ -165,6 +167,8 @@ template< const std::string& x2XMLTag, const std::string& y2XMLTag) { + Augmentation augmentation(augmentations, augmentationProbability); + std::vector annotationsDirectory, imagesDirectory; // Fill the directory. @@ -209,26 +213,32 @@ template< continue; } - // Get the size of image to create image info required by mlpack::data::Load function. 
- boost::property_tree::ptree sizeInformation = annotation.get_child(sizeXMLTag); - size_t imageWidth = std::stoi(sizeInformation.get_child("width").data()); - size_t imageHeight = std::stoi(sizeInformation.get_child("height").data()); - size_t imageDepth = std::stoi(sizeInformation.get_child("depth").data()); + // Get the size of image to create image info required + // by mlpack::data::Load function. + boost::property_tree::ptree sizeInfo = annotation.get_child(sizeXMLTag); + size_t imageWidth = std::stoi(sizeInfo.get_child("width").data()); + size_t imageHeight = std::stoi(sizeInfo.get_child("height").data()); + size_t imageDepth = std::stoi(sizeInfo.get_child("depth").data()); mlpack::data::ImageInfo imageInfo(imageWidth, imageHeight, imageDepth); - // TODO: Resize the image here. - // Load the image. - // The image loaded here will be in column format i.e. Output will be matrix with the - // following shape {1, cols * rows * slices} in column major format. + // The image loaded here will be in column format i.e. Output will + // be matrix with the following shape {1, cols * rows * slices} in + // column major format. DatasetX image; mlpack::data::Load(pathToImages + imgName, image, imageInfo); + if (augmentation.HasResizeParam()) + { + augmentation.ResizeTransform(image, imageWidth, imageHeight, imageDepth, + augmentation.augmentations[0]); + } + // Iterate over all object in annotation. BOOST_FOREACH(boost::property_tree::ptree::value_type const& object, annotation) { - arma::uvec predictions(5); + arma::vec predictions(5); // Iterate over property of the object to get class label and // bounding box coordinates. if (object.first == objectXMLTag) @@ -249,6 +259,10 @@ template< std::stoi(coordinate.second.data()); } } + + // Add object to training set. 
+ trainFeatures.insert_cols(0, image); + trainLabels.insert_cols(0, predictions); } } } diff --git a/tests/augmentation_tests.cpp b/tests/augmentation_tests.cpp index 2d10a218..85b3f3aa 100644 --- a/tests/augmentation_tests.cpp +++ b/tests/augmentation_tests.cpp @@ -18,7 +18,9 @@ BOOST_AUTO_TEST_SUITE(AugmentationTest); BOOST_AUTO_TEST_CASE(REGEXTest) { - std::string s = " resize = { 19, 112 }, resize : 133,442, resize = [12 213]"; + // Some accepted formats. + std::string s = " resize = { 19, 112 }, \ + resize : 133, 442, resize = [12 213]"; boost::regex expr{"[0-9]+"}; boost::smatch what; boost::sregex_token_iterator iter(s.begin(), s.end(), expr, 0); diff --git a/tests/dataloader_tests.cpp b/tests/dataloader_tests.cpp index bd7543e9..53b24b7a 100644 --- a/tests/dataloader_tests.cpp +++ b/tests/dataloader_tests.cpp @@ -59,13 +59,32 @@ BOOST_AUTO_TEST_CASE(CSVDataLoaderTest) */ BOOST_AUTO_TEST_CASE(MNISTDataLoaderTest) { - /** + /* DataLoader<> dataloader("mnist", true, 0.80); + + // Check for correct dimensions. + BOOST_REQUIRE_EQUAL(dataloader.TrainFeatures().n_rows, 784); + BOOST_REQUIRE_EQUAL(dataloader.TestFeatures().n_rows, 784); + BOOST_REQUIRE_EQUAL(dataloader.ValidFeatures().n_rows, 784); + // Check for correct dimensions. - BOOST_REQUIRE_EQUAL(dataloader.TrainFeatures().n_cols, 784); - BOOST_REQUIRE_EQUAL(dataloader.TestFeatures().n_cols, 784); - BOOST_REQUIRE_EQUAL(dataloader.ValidFeatures().n_cols, 784); - BOOST_REQUIRE_EQUAL(dataloader.TrainFeatures().n_rows, 33600); + BOOST_REQUIRE_EQUAL(dataloader.TrainFeatures().n_cols, 8400); + BOOST_REQUIRE_EQUAL(dataloader.ValidFeatures().n_cols, 33600); + BOOST_REQUIRE_EQUAL(dataloader.TestFeatures().n_cols, 28000); + + // Check if we can access both features and labels using + // TrainSet tuple and ValidSet tuple. 
+ BOOST_REQUIRE_EQUAL(std::get<0>(dataloader.TrainSet()).n_cols, 8400); + BOOST_REQUIRE_EQUAL(std::get<1>(dataloader.TrainSet()).n_rows, 1); + BOOST_REQUIRE_EQUAL(std::get<0>(dataloader.ValidSet()).n_cols, 33600); + BOOST_REQUIRE_EQUAL(std::get<1>(dataloader.ValidSet()).n_rows, 1); + + // Clean up. + Utils::RemoveFile("./../data/mnist-dataset/mnist_all.csv"); + Utils::RemoveFile("./../data/mnist-dataset/mnist_all_centroids.csv"); + Utils::RemoveFile("./../data/mnist-dataset/mnist_train.csv"); + Utils::RemoveFile("./../data/mnist-dataset/mnist_test.csv"); + Utils::RemoveFile("./../data/mnist.tar.gz"); */ } @@ -76,7 +95,8 @@ BOOST_AUTO_TEST_CASE(ObjectDetectionDataLoader) { DataLoader<> dataloader; dataloader.LoadObjectDetectionDataset("./../data/annotations/", - "./../data/images/", {"person", "foot", "aeroplane", "head", "hand"}); + "./../data/images/", {"person", "foot", "aeroplane", "head", "hand"}, + {"resize 64, 64"}); } BOOST_AUTO_TEST_SUITE_END(); diff --git a/tests/utils_tests.cpp b/tests/utils_tests.cpp index aa04fec7..4a314496 100644 --- a/tests/utils_tests.cpp +++ b/tests/utils_tests.cpp @@ -76,15 +76,19 @@ BOOST_AUTO_TEST_CASE(RemoveFileTest) BOOST_AUTO_TEST_CASE(ExtractFilesTest) { - Utils::DownloadFile("/datasets/mnist.tar.gz", "./../data/mnist.tar.gz", "", - false, true, "www.mlpack.org", true, "./../data/"); + std::vector vec; - BOOST_REQUIRE(Utils::PathExists("./../data/mnist_all.csv")); - BOOST_REQUIRE(Utils::PathExists("./../data/mnist.tar.gz")); + Utils::DownloadFile("/datasets/USCensus1990.tar.gz", + "./../data/USCensus1990.tar.gz", "", false, true, + "www.mlpack.org", true, "./../data/"); + + BOOST_REQUIRE(Utils::PathExists("./../data/USCensus1990.csv")); + BOOST_REQUIRE(Utils::PathExists("./../data/USCensus1990_centroids.csv")); // Clean up. 
- Utils::RemoveFile("./../data/mnist_all.csv"); - Utils::RemoveFile("./../data/mnist_all_centroids.csv"); + Utils::RemoveFile("./../data/USCensus1990.csv"); + Utils::RemoveFile("./../data/USCensus1990_centroids.csv"); + Utils::RemoveFile("./../data/USCensus1990.tar.gz"); } BOOST_AUTO_TEST_SUITE_END(); diff --git a/utils/utils.hpp b/utils/utils.hpp index 6709d7dc..7ddb6d61 100644 --- a/utils/utils.hpp +++ b/utils/utils.hpp @@ -65,9 +65,8 @@ class Utils std::replace(pathForExtractionTemp.begin(), pathForExtractionTemp.end(), '/', '\\'); - command = "tar --force-local -xvzf " + - boost::filesystem::current_path().string() + "\\" + - pathToArchiveTemp; + command = "tar --force-local -xvzf " + pathToArchiveTemp + " -C " + + pathForExtractionTemp; #else command = command + boost::filesystem::current_path().string() + "/" + pathToArchive + " -C " + boost::filesystem::current_path().string() + @@ -292,9 +291,8 @@ class Utils } else { - mlpack::Log::Warn << "The " << path << "Doesn't exist." << std::endl; + mlpack::Log::Warn << "The " << path << " doesn't exist." << std::endl; } } }; - #endif