Skip to content

Commit

Permalink
Complete LoadObjectDetection Function
Browse files Browse the repository at this point in the history
  • Loading branch information
kartikdutt18 committed Jun 1, 2020
1 parent f0def62 commit 02cc8cd
Show file tree
Hide file tree
Showing 7 changed files with 175 additions and 41 deletions.
68 changes: 52 additions & 16 deletions augmentation/augmentation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,9 @@
* augmentation.Transform(dataloader.TrainFeatures);
* @endcode
*
* @tparam DatasetX Datatype on which augmentation will be done.
* @tparam DatasetType Datatype on which augmentation will be done.
*/
template<typename DatasetType = arma::mat>
class Augmentation
{
public:
Expand All @@ -47,28 +48,53 @@ class Augmentation
const double augmentationProbability);

/**
* Applies augmentation to the passed dataset.
*
* @param dataset Dataset on which augmentation will be applied.
* @param datapointWidth Width of a single data point i.e.
* Since each column represents a seperate data
* point.
* @param datapointHeight Height of a single data point.
* @param datapointDepth Depth of a single data point. For 2-dimensional
* data point, set it to 1. Defaults to 1.
*/
template<typename DatasetType = arma::mat>
void Transform(DatasetType& dataset);

template<typename DatasetType = arma::mat>
void ResizeTransform(DatasetType& dataset);

template <typename DatasetType = arma::mat>
void HorizontalFlipTransform(DatasetType &dataset);

template<typename DatasetType = arma::mat>
void VerticalFlipTransform(DatasetType& dataset);
void Transform(DatasetType& dataset,
const size_t datapointWidth,
const size_t datapointHeight,
const size_t datapointDepth = 1);

/**
* Applies resize transform to the entire dataset.
*
* @param dataset Dataset on which augmentation will be applied.
* @param datapointWidth Width of a single data point i.e.
* Since each column represents a seperate data
* point.
* @param datapointHeight Height of a single data point.
* @param datapointDepth Depth of a single data point. For 2-dimensional
* data point, set it to 1. Defaults to 1.
* @param augmentation String containing the transform.
*/
void ResizeTransform(DatasetType& dataset,
const size_t datapointWidth,
const size_t datapointHeight,
const size_t datapointDepth,
const std::string& augmentation);

private:
/**
* Initializes augmentation map for the class.
*/
void InitializeAugmentationMap();

/**
* Function to determine if augmentation has Resize function.
*/
bool HasResizeParam()
{
// Search in augmentation vector.
return augmentations.size() <= 0 ? false :
augmentations[0].find("resize") != std::string::npos ;
augmentations[0].find("resize") != std::string::npos;
}

/**
Expand All @@ -77,15 +103,16 @@ class Augmentation
* @param outWidth Output width of resized data point.
* @param outHeight Output height of resized data point.
*/
void GetResizeParam(size_t& outWidth, size_t& outHeight)
void GetResizeParam(size_t& outWidth,
size_t& outHeight)
{
if (!HasResizeParam())
{
return;
}

outWidth = -1;
outHeight = -1;
outWidth = 0;
outHeight = 0;

// Use regex to find one / two numbers. If only one provided
// set output width equal to output height.
Expand Down Expand Up @@ -121,6 +148,15 @@ class Augmentation

//! Locally held value of augmentation probability.
double augmentationProbability;

//! Locally help map for mapping functions and strings.
std::unordered_map<std::string, void(*)(DatasetType&,
size_t, size_t, size_t, std::string&)> augmentationMap;

// The dataloader class should have access to internal functions of
// the dataloader.
template<typename DatasetX, typename DatasetY, class ScalerType>
friend class DataLoader;
};

#include "augmentation_impl.hpp" // Include implementation.
Expand Down
70 changes: 65 additions & 5 deletions augmentation/augmentation_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,20 @@
#ifndef MODELS_AUGMENTATION_IMPL_HPP
#define MODELS_AUGMENTATION_IMPL_HPP

Augmentation::Augmentation() :
template<typename DatasetType>
Augmentation<DatasetType>::Augmentation() :
augmentations(std::vector<std::string>()),
augmentationProbability(0.2)
{
// Nothing to do here.
}

Augmentation::Augmentation(const std::vector<std::string>& augmentations,
const double augmentationProbability) :
augmentations(augmentations),
augmentationProbability(augmentationProbability)
template<typename DatasetType>
Augmentation<DatasetType>::Augmentation(
const std::vector<std::string>& augmentations,
const double augmentationProbability) :
augmentations(augmentations),
augmentationProbability(augmentationProbability)
{
// Sort the vector to place resize parameter to the front of the string.
// This prevents constant look ups for resize.
Expand All @@ -34,6 +37,63 @@ Augmentation::Augmentation(const std::vector<std::string>& augmentations,
{
return str1.find("resize") != std::string::npos;
});

// Fill augmentation map with supported augmentations other than resize.
InitializeAugmentationMap();
}

template<typename DatasetType>
void Augmentation<DatasetType>::Transform(DatasetType& dataset,
const size_t datapointWidth,
const size_t datapointHeight,
const size_t datapointDepth)
{
size_t i = 0;
if (this->HasResizeParam())
{
this->ResizeTransform(dataset);
i++;
}

for (; i < augmentations.size(); i++)
{
if (augmentationMap.count(augmentations[i]))
{
augmentationMap[augmentations[i]](dataset, datapointWidth,
datapointHeight, datapointDepth, augmentations[i]);
}
}
}

template<typename DatasetType>
void Augmentation<DatasetType>::ResizeTransform(
DatasetType& dataset,
const size_t datapointWidth,
const size_t datapointHeight,
const size_t datapointDepth,
const std::string& augmentation)
{
size_t outputWidth = 0, outputHeight = 0;

// Get output width and output height.
GetResizeParam(outputWidth, outputHeight);

// We will use mlpack's bilinear interpolation layer to
// resize the input.
mlpack::ann::BilinearInterpolation<DatasetType, DatasetType> resizeLayer(
datapointWidth, datapointHeight, outputWidth, outputHeight, datapointDepth);

// Not sure how to avoid a copy here.
DatasetType output;
resizeLayer.Forward(dataset, output);
dataset = output;
}

template<typename DatasetType>
void Augmentation<DatasetType>::InitializeAugmentationMap()
{
// Fill the map here.
}


#endif
4 changes: 4 additions & 0 deletions dataloader/dataloader.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include <mlpack/core/math/shuffle_data.hpp>
#include <mlpack/core/data/split_data.hpp>
#include <boost/property_tree/ptree.hpp>
#include <augmentation/augmentation.hpp>
#include <dataloader/datasets.hpp>
#include <mlpack/prereqs.hpp>
#include <boost/foreach.hpp>
Expand Down Expand Up @@ -158,6 +159,9 @@ class DataLoader
void LoadObjectDetectionDataset(const std::string& pathToAnnotations,
const std::string& pathToImages,
const std::vector<std::string>& classes,
const std::vector<std::string>& augmentation =
std::vector<std::string>(),
const double augmentationProbability = 0.2,
const bool absolutePath = false,
const std::string& baseXMLTag = "annotation",
const std::string& imageNameXMLTag =
Expand Down
18 changes: 15 additions & 3 deletions dataloader/dataloader_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,8 @@ template<
>::LoadObjectDetectionDataset(const std::string& pathToAnnotations,
const std::string& pathToImages,
const std::vector<std::string>& classes,
const std::vector<std::string>& augmentations,
const double augmentationProbability,
const bool absolutePath,
const std::string& baseXMLTag,
const std::string& imageNameXMLTag,
Expand All @@ -165,6 +167,8 @@ template<
const std::string& x2XMLTag,
const std::string& y2XMLTag)
{
Augmentation<DatasetX> augmentation(augmentations, augmentationProbability);

std::vector<boost::filesystem::path> annotationsDirectory, imagesDirectory;

// Fill the directory.
Expand Down Expand Up @@ -216,19 +220,23 @@ template<
size_t imageDepth = std::stoi(sizeInformation.get_child("depth").data());
mlpack::data::ImageInfo imageInfo(imageWidth, imageHeight, imageDepth);

// TODO: Resize the image here.

// Load the image.
// The image loaded here will be in column format i.e. Output will be matrix with the
// following shape {1, cols * rows * slices} in column major format.
DatasetX image;
mlpack::data::Load(pathToImages + imgName, image, imageInfo);

if (augmentation.HasResizeParam())
{
augmentation.ResizeTransform(image, imageWidth, imageHeight, imageDepth,
augmentation.augmentations[0]);
}

// Iterate over all object in annotation.
BOOST_FOREACH(boost::property_tree::ptree::value_type const& object,
annotation)
{
arma::uvec predictions(5);
arma::vec predictions(5);
// Iterate over property of the object to get class label and
// bounding box coordinates.
if (object.first == objectXMLTag)
Expand All @@ -249,6 +257,10 @@ template<
std::stoi(coordinate.second.data());
}
}

// Add object to training set.
trainFeatures.insert_cols(0, image);
trainLabels.insert_cols(0, predictions);
}
}
}
Expand Down
32 changes: 26 additions & 6 deletions tests/dataloader_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,32 @@ BOOST_AUTO_TEST_CASE(CSVDataLoaderTest)
*/
BOOST_AUTO_TEST_CASE(MNISTDataLoaderTest)
{
/**
/*
DataLoader<> dataloader("mnist", true, 0.80);
// Check for correct dimensions.
BOOST_REQUIRE_EQUAL(dataloader.TrainFeatures().n_rows, 784);
BOOST_REQUIRE_EQUAL(dataloader.TestFeatures().n_rows, 784);
BOOST_REQUIRE_EQUAL(dataloader.ValidFeatures().n_rows, 784);
// Check for correct dimensions.
BOOST_REQUIRE_EQUAL(dataloader.TrainFeatures().n_cols, 784);
BOOST_REQUIRE_EQUAL(dataloader.TestFeatures().n_cols, 784);
BOOST_REQUIRE_EQUAL(dataloader.ValidFeatures().n_cols, 784);
BOOST_REQUIRE_EQUAL(dataloader.TrainFeatures().n_rows, 33600);
BOOST_REQUIRE_EQUAL(dataloader.TrainFeatures().n_cols, 8400);
BOOST_REQUIRE_EQUAL(dataloader.ValidFeatures().n_cols, 33600);
BOOST_REQUIRE_EQUAL(dataloader.TestFeatures().n_cols, 28000);
// Check if we can access both features and labels using
// TrainSet tuple and ValidSet tuple.
BOOST_REQUIRE_EQUAL(std::get<0>(dataloader.TrainSet()).n_cols, 8400);
BOOST_REQUIRE_EQUAL(std::get<1>(dataloader.TrainSet()).n_rows, 1);
BOOST_REQUIRE_EQUAL(std::get<0>(dataloader.ValidSet()).n_cols, 33600);
BOOST_REQUIRE_EQUAL(std::get<1>(dataloader.ValidSet()).n_rows, 1);
// Clean up.
Utils::RemoveFile("./../data/mnist-dataset/mnist_all.csv");
Utils::RemoveFile("./../data/mnist-dataset/mnist_all_centroids.csv");
Utils::RemoveFile("./../data/mnist-dataset/mnist_train.csv");
Utils::RemoveFile("./../data/mnist-dataset/mnist_test.csv");
Utils::RemoveFile("./../data/mnist.tar.gz");
*/
}

Expand All @@ -76,7 +95,8 @@ BOOST_AUTO_TEST_CASE(ObjectDetectionDataLoader)
{
DataLoader<> dataloader;
dataloader.LoadObjectDetectionDataset("./../data/annotations/",
"./../data/images/", {"person", "foot", "aeroplane", "head", "hand"});
"./../data/images/", {"person", "foot", "aeroplane", "head", "hand"},
{"resize 64, 64"});
}

BOOST_AUTO_TEST_SUITE_END();
16 changes: 10 additions & 6 deletions tests/utils_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,15 +76,19 @@ BOOST_AUTO_TEST_CASE(RemoveFileTest)

BOOST_AUTO_TEST_CASE(ExtractFilesTest)
{
Utils::DownloadFile("/datasets/mnist.tar.gz", "./../data/mnist.tar.gz", "",
false, true, "www.mlpack.org", true, "./../data/");
std::vector<boost::filesystem::path> vec;

BOOST_REQUIRE(Utils::PathExists("./../data/mnist_all.csv"));
BOOST_REQUIRE(Utils::PathExists("./../data/mnist.tar.gz"));
Utils::DownloadFile("/datasets/USCensus1990.tar.gz",
"./../data/USCensus1990.tar.gz", "", false, true,
"www.mlpack.org", true, "./../data/");

BOOST_REQUIRE(Utils::PathExists("./../data/USCensus1990.csv"));
BOOST_REQUIRE(Utils::PathExists("./../data/USCensus1990_centroids.csv"));

// Clean up.
Utils::RemoveFile("./../data/mnist_all.csv");
Utils::RemoveFile("./../data/mnist_all_centroids.csv");
Utils::RemoveFile("./../data/USCensus1990.csv");
Utils::RemoveFile("./../data/USCensus1990_centroids.csv");
Utils::RemoveFile("./../data/USCensus1990.tar.gz");
}

BOOST_AUTO_TEST_SUITE_END();
8 changes: 3 additions & 5 deletions utils/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,8 @@ class Utils
std::replace(pathForExtractionTemp.begin(), pathForExtractionTemp.end(),
'/', '\\');

command = "tar --force-local -xvzf " +
boost::filesystem::current_path().string() + "\\" +
pathToArchiveTemp;
command = "tar --force-local -xvzf " + pathToArchiveTemp + " -C " +
pathForExtractionTemp;
#else
command = command + boost::filesystem::current_path().string() + "/" +
pathToArchive + " -C " + boost::filesystem::current_path().string() +
Expand Down Expand Up @@ -292,9 +291,8 @@ class Utils
}
else
{
mlpack::Log::Warn << "The " << path << "Doesn't exist." << std::endl;
mlpack::Log::Warn << "The " << path << " doesn't exist." << std::endl;
}
}
};

#endif

0 comments on commit 02cc8cd

Please sign in to comment.