Skip to content

Commit

Permalink
Add tests for dataloader
Browse files Browse the repository at this point in the history
  • Loading branch information
kartikdutt18 committed Jun 1, 2020
1 parent 7b3c433 commit 015e197
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 25 deletions.
21 changes: 0 additions & 21 deletions data/pascal-voc-classes.txt

This file was deleted.

6 changes: 6 additions & 0 deletions dataloader/dataloader_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,12 @@ template<
// Read the xml file.
for (boost::filesystem::path annotationFile : annotationsDirectory)
{
if (annotationFile.string().length() <= 3 ||
annotationFile.string().substr(
annotationFile.string().length() - 3) != "xml")
{
continue;
}
// Read the xml file.
boost::property_tree::ptree xmlFile;
boost::property_tree::read_xml(annotationFile.string(), xmlFile);
Expand Down
3 changes: 2 additions & 1 deletion tests/augmentation_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ BOOST_AUTO_TEST_SUITE(AugmentationTest);

BOOST_AUTO_TEST_CASE(ResizeAugmentationTest)
{
Augmentation<> augmentation(std::vector<std::string>(1, "resize (5, 4)"), 0.2);
Augmentation<> augmentation(std::vector<std::string>(1,
"resize (5, 4)"), 0.2);

// Test on a square matrix.
arma::mat input;
Expand Down
30 changes: 27 additions & 3 deletions tests/dataloader_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,33 @@ BOOST_AUTO_TEST_CASE(MNISTDataLoaderTest)
BOOST_AUTO_TEST_CASE(ObjectDetectionDataLoader)
{
DataLoader<> dataloader;
dataloader.LoadObjectDetectionDataset("./../data/annotations/",
"./../data/images/", {"person", "foot", "aeroplane", "head", "hand"},
{"resize 64, 64"});
Utils::ExtractFiles("./../data/PASCAL-VOC-Test.zip", "./../data/");

// Set paths for dataset.
std::string basePath = "./../data/PASCAL-VOC-Test/";
std::string annotaionPath = "Annotations/";
std::string imagesPath = "Images/";

// Classes in the dataset.
std::vector<std::string> classes = {"background", "aeroplane", "bicycle",
"bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow",
"diningtable", "dog", "horse", "motorbike", "person", "pottedplant",
"sheep", "sofa", "train", "tvmonitor"};

// Resize the image to 64 x 64.
std::vector<std::string> augmentation = {"resize (64, 64)"};
dataloader.LoadObjectDetectionDataset(basePath + annotaionPath,
basePath + imagesPath, classes, augmentation);

// There are total 15 objects in images.
BOOST_REQUIRE_EQUAL(dataloader.TrainLabels().n_cols, 15);
// They correspond to class name, x1, y1, x2, y2.
BOOST_REQUIRE_EQUAL(dataloader.TrainLabels().n_rows, 5);

// Rows will be equal to shape image depth * image width * image height.
BOOST_REQUIRE_EQUAL(dataloader.TrainFeatures().n_rows, 64 * 64 * 3);
// There are total 15 objects in the images.
BOOST_REQUIRE_EQUAL(dataloader.TrainFeatures().n_cols, 15);
}

BOOST_AUTO_TEST_SUITE_END();

0 comments on commit 015e197

Please sign in to comment.