diff --git a/dataloader/dataloader.hpp b/dataloader/dataloader.hpp index 4b3249b2..5e707ab4 100644 --- a/dataloader/dataloader.hpp +++ b/dataloader/dataloader.hpp @@ -131,16 +131,44 @@ class DataLoader * 4. Each object tag should contain name tag i.e. class of the object. * 5. Each object tag should contain bndbox tag containing xmin, ymin, xmax, ymax. * - * NOTE : Labels are assigned using lexicographically. Set verbose to 1 to print labels + * NOTE : Labels are assigned using classes vector. Set verbose to 1 to print labels * and their corresponding class. * * @param pathToAnnotations Path to the folder containg xml type annotation files. * @param pathToImages Path to folder containing images corresponding to annotations. + * @param classes Vector of strings containing list of classes. Labels are assigned + * according to this vector. * @param absolutePath Boolean to determine if absolute path is used. Defaults to false. + * @param baseXMLTag XML tag name which wraps around the annotation file. + * @param imageNameXMLTag XML tag name which holds the value of image filename. + * @param objectXMLTag XML tag name which holds details of bounding box i.e. class and + * coordinates of bounding box. + * @param bndboxXMLTag XML tag name which holds coordinates of bounding box. + * @param classNameXMLTag XML tag name inside objectXMLTag which holds the name of the + * class of bounding box. + * @param x1XMLTag XML tag name inside bndboxXMLTag which hold value of lower most + * x coordinate of bounding box. + * @param y1XMLTag XML tag name inside bndboxXMLTag which hold value of lower most + * y coordinate of bounding box. + * @param x2XMLTag XML tag name inside bndboxXMLTag which hold value of upper most + * x coordinate of bounding box. + * @param y2XMLTag XML tag name inside bndboxXMLTag which hold value of upper most + * y coordinate of bounding box. */ void LoadObjectDetectionDataset(const std::string& pathToAnnotations, const std::string& pathToImages, - const bool absolutePath = false); + const std::vector& classes, + const bool absolutePath = false, + const std::string& baseXMLTag = "annotation", + const std::string& imageNameXMLTag = + "filename", + const std::string& objectXMLTag = "object", + const std::string& bndboxXMLTag = "bndbox", + const std::string& classNameXMLTag = "name", + const std::string& x1XMLTag = "xmin", + const std::string& y1XMLTag = "ymin", + const std::string& x2XMLTag = "xmax", + const std::string& y2XMLTag = "ymax"); //! Get the training dataset features. DatasetX TrainFeatures() const { return trainFeatures; } diff --git a/dataloader/dataloader_impl.hpp b/dataloader/dataloader_impl.hpp index e423fc34..ccde5602 100644 --- a/dataloader/dataloader_impl.hpp +++ b/dataloader/dataloader_impl.hpp @@ -152,7 +152,17 @@ template< DatasetX, DatasetY, ScalerType >::LoadObjectDetectionDataset(const std::string& pathToAnnotations, const std::string& pathToImages, - const bool absolutePath) + const std::vector& classes, + const bool absolutePath, + const std::string& baseXMLTag, + const std::string& imageNameXMLTag, + const std::string& objectXMLTag, + const std::string& bndboxXMLTag, + const std::string& classNameXMLTag, + const std::string& x1XMLTag, + const std::string& y1XMLTag, + const std::string& x2XMLTag, + const std::string& y2XMLTag) { std::vector annotationsDirectory, imagesDirectory; @@ -160,19 +170,59 @@ template< Utils::ListDir(pathToAnnotations, annotationsDirectory, absolutePath); Utils::ListDir(pathToImages, imagesDirectory, absolutePath); + // Create a map for labels and corresponding class name. + // This provides faster access to class labels. + std::unordered_map classMap; + for (size_t i = 0; i < classes.size(); i++) + { + classMap.insert(std::make_pair(classes[i], i)); + } + // Read the xml file. for (boost::filesystem::path annotationFile : annotationsDirectory) { // Read the xml file. boost::property_tree::ptree annotation; - std::cout << annotationFile.string() << std::endl; boost::property_tree::read_xml(annotationFile.string(), annotation); + // Map to insert values in a column vector. + std::unordered_map indexMap; + indexMap.insert(std::make_pair(classNameXMLTag, 0)); + indexMap.insert(std::make_pair(x1XMLTag, 1)); + indexMap.insert(std::make_pair(y1XMLTag, 2)); + indexMap.insert(std::make_pair(x2XMLTag, 3)); + indexMap.insert(std::make_pair(y2XMLTag, 4)); + // Read properties inside annotation file. - BOOST_FOREACH (boost::property_tree::ptree::value_type const& object, - annotation.get_child("annotation.object")) + BOOST_FOREACH(boost::property_tree::ptree::value_type const& object, + annotation.get_child(baseXMLTag)) { - std::cout << object.first << std::endl; + // Column vector to temporarily store details of bounding box. + if (object.first == objectXMLTag) + { + arma::uvec predictions(5); + + // Iterate over property of the object to get class label and + // bounding box coordinates. + if (classMap.count(object.second.get_child(classNameXMLTag).data())) + { + predictions(indexMap[classNameXMLTag]) = classMap[ + object.second.get_child(classNameXMLTag).data()]; + boost::property_tree::ptree const &boundingBox = + object.second.get_child(bndboxXMLTag); + + BOOST_FOREACH(boost::property_tree::ptree::value_type + const& coordinate, boundingBox) + { + if (indexMap.count(coordinate.first)) + { + predictions(indexMap[coordinate.first]) = + std::stoi(coordinate.second.data()); + } + } + // predictions.print(); + } + } } } } diff --git a/tests/dataloader_tests.cpp b/tests/dataloader_tests.cpp index 86c06630..01859f45 100644 --- a/tests/dataloader_tests.cpp +++ b/tests/dataloader_tests.cpp @@ -75,7 +75,8 @@ BOOST_AUTO_TEST_CASE(MNISTDataLoaderTest) BOOST_AUTO_TEST_CASE(ObjectDetectionDataLoader) { DataLoader<> dataloader; - dataloader.LoadObjectDetectionDataset("./../data/annotations/", "./../data"); + dataloader.LoadObjectDetectionDataset("./../data/annotations/", "./../data", + {"person", "foot", "aeroplane", "head", "hand"}); } BOOST_AUTO_TEST_SUITE_END();