From 93acea2056eb4333e3c1a0a883a4b9be4333e649 Mon Sep 17 00:00:00 2001 From: kartikdutt18 Date: Thu, 4 Jun 2020 10:01:22 +0530 Subject: [PATCH] Build Fixed, Nice way to avoid hidden files --- dataloader/dataloader.hpp | 6 ++++-- dataloader/dataloader_impl.hpp | 29 ++++++++++++++++++----------- tests/dataloader_tests.cpp | 3 ++- utils/utils.hpp | 7 +++++++ 4 files changed, 31 insertions(+), 14 deletions(-) diff --git a/dataloader/dataloader.hpp b/dataloader/dataloader.hpp index b0cae9ad..b8a40977 100644 --- a/dataloader/dataloader.hpp +++ b/dataloader/dataloader.hpp @@ -137,9 +137,10 @@ class DataLoader * * @param pathToAnnotations Path to the folder containing XML type annotation files. * @param pathToImages Path to folder containing images corresponding to annotations. - * @param validRatio Ratio of dataset that will be used for validation. * @param classes Vector of strings containing list of classes. Labels are assigned * according to this vector. + * @param validRatio Ratio of dataset that will be used for validation. + * @param shuffle Boolean to determine whether the dataset is shuffled. * @param augmentation Vector strings of augmentations supported by mlpack. * @param augmentationProbability Probability of applying augmentation to a particular cell. * @param absolutePath Boolean to determine if absolute path is used. Defaults to false. @@ -161,8 +162,9 @@ class DataLoader */ void LoadObjectDetectionDataset(const std::string& pathToAnnotations, const std::string& pathToImages, - const double validRatio, const std::vector& classes, + const double validRatio = 0.2, + const bool shuffle = true, const std::vector& augmentation = std::vector(), const double augmentationProbability = 0.2, diff --git a/dataloader/dataloader_impl.hpp b/dataloader/dataloader_impl.hpp index 9f1e020a..dd40a6ff 100644 --- a/dataloader/dataloader_impl.hpp +++ b/dataloader/dataloader_impl.hpp @@ -72,8 +72,8 @@ template< } LoadObjectDetectionDataset(datasetMap[dataset].trainingAnnotationPath, - datasetMap[dataset].trainingImagesPath, validRatio, - datasetMap[dataset].classes, augmentations, augmentationProbability); + datasetMap[dataset].trainingImagesPath, datasetMap[dataset].classes, + validRatio, shuffle,augmentations, augmentationProbability); // Load testing data if any. Most object detection dataset // have private evaluation servers. @@ -180,8 +180,9 @@ template< DatasetX, DatasetY, ScalerType >::LoadObjectDetectionDataset(const std::string& pathToAnnotations, const std::string& pathToImages, - const double validRatio, const std::vector& classes, + const double validRatio, + const bool shuffle, const std::vector& augmentations, const double augmentationProbability, const bool absolutePath, @@ -212,7 +213,6 @@ template< for (size_t i = 0; i < classes.size(); i++) classMap.insert(std::make_pair(classes[i], i)); - // Map to insert values in a column vector. std::unordered_map indexMap; indexMap.insert(std::make_pair(classNameXMLTag, 0)); @@ -223,13 +223,14 @@ template< // Keep track of files loaded. size_t totalFiles = annotationsDirectory.size(), loadedFiles = 0; + size_t imageWidth = 0, imageHeight = 0, imageDepth = 0; // Read the XML file. for (boost::filesystem::path annotationFile : annotationsDirectory) { if (annotationFile.string().length() <= 3 || annotationFile.string().substr( - annotationFile.string().length() - 3) != "xml") + annotationFile.string().length() - 3) != "xml") { continue; } @@ -260,9 +261,9 @@ template< // Get the size of image to create image info required // by mlpack::data::Load function. boost::property_tree::ptree sizeInfo = annotation.get_child(sizeXMLTag); - size_t imageWidth = std::stoi(sizeInfo.get_child("width").data()); - size_t imageHeight = std::stoi(sizeInfo.get_child("height").data()); - size_t imageDepth = std::stoi(sizeInfo.get_child("depth").data()); + imageWidth = std::stoi(sizeInfo.get_child("width").data()); + imageHeight = std::stoi(sizeInfo.get_child("height").data()); + imageDepth = std::stoi(sizeInfo.get_child("depth").data()); mlpack::data::ImageInfo imageInfo(imageWidth, imageHeight, imageDepth); // Load the image. @@ -276,6 +277,8 @@ template< { augmentation.ResizeTransform(image, imageWidth, imageHeight, imageDepth, augmentation.augmentations[0]); + augmentation.GetResizeParam(imageWidth, imageHeight, + augmentation.augmentations[0]); } // Iterate over all object in annotation. @@ -312,9 +315,13 @@ template< } } - // Add data split and augmentation here. - trainFeatures = dataset; - trainLabels = labels; + // Add data split here. + trainFeatures = std::move(dataset); + trainLabels = std::move(labels); + + // Augment the training data. + augmentation.Transform(trainFeatures, imageWidth, imageHeight, + imageDepth); } template< diff --git a/tests/dataloader_tests.cpp b/tests/dataloader_tests.cpp index 35ec8cad..0ca4a803 100644 --- a/tests/dataloader_tests.cpp +++ b/tests/dataloader_tests.cpp @@ -99,6 +99,7 @@ BOOST_AUTO_TEST_CASE(ObjectDetectionDataLoader) std::string annotaionPath = "Annotations/"; std::string imagesPath = "Images/"; double validRatio = 0.2; + bool shuffle = true; // Classes in the dataset. std::vector classes = {"background", "aeroplane", "bicycle", @@ -109,7 +110,7 @@ BOOST_AUTO_TEST_CASE(ObjectDetectionDataLoader) // Resize the image to 64 x 64. std::vector augmentation = {"resize (64, 64)"}; dataloader.LoadObjectDetectionDataset(basePath + annotaionPath, - basePath + imagesPath, validRatio, classes, augmentation); + basePath + imagesPath, classes, validRatio, shuffle, augmentation); // There are total 15 objects in images. BOOST_REQUIRE_EQUAL(dataloader.TrainLabels().n_cols, 15); diff --git a/utils/utils.hpp b/utils/utils.hpp index 161125df..ce91aff2 100644 --- a/utils/utils.hpp +++ b/utils/utils.hpp @@ -315,6 +315,13 @@ class Utils boost::filesystem::directory_iterator(), std::back_inserter(pathVector)); + // Remove hidden files. + pathVector.erase(std::remove_if(pathVector.begin(), pathVector.end(), + [](boost::filesystem::path curPath) + { + return curPath.filename().string()[0] == '.'; + }), pathVector.end()); + // Sort the path vector. std::sort(pathVector.begin(), pathVector.end()); }