diff --git a/dataloader/dataloader.hpp b/dataloader/dataloader.hpp
index 37aea9c8..fe7b0b2c 100644
--- a/dataloader/dataloader.hpp
+++ b/dataloader/dataloader.hpp
@@ -63,14 +63,14 @@ class DataLoader
    *
    * @param datasetPath Path or name of dataset.
    * @param shuffle whether or not to shuffle the data.
-   * @param ratio Ratio for train-test split.
+   * @param testRatio Ratio of dataset to be used for validation set.
    * @param useScaler Use feature scaler for pre-processing the dataset.
    * @param augmentation Adds augmentation to training data only.
    * @param augmentationProbability Probability of applying augmentation on dataset.
    */
   DataLoader(const std::string& dataset,
              const bool shuffle,
-             const double ratio = 0.75,
+             const double testRatio = 0.25,
              const bool useScaler = true,
              const std::vector<std::string> augmentation =
                  std::vector<std::string>(),
@@ -85,7 +85,7 @@ class DataLoader
    * Note: This option augmentation to NULL, set ratio to 1 and
    * scaler will be used to only transform the test data.
    * @param shuffle Boolean to determine whether or not to shuffle the data.
-   * @param ratio Ratio for train-test split.
+   * @param testRatio Ratio of dataset to be used for validation set.
    * @param useScaler Fits the scaler on training data and transforms dataset.
    * @param dropHeader Drops the first row from CSV.
    * @param startInputFeatures First Index which will be fed into the model as input.
@@ -106,7 +106,7 @@
   void LoadCSV(const std::string& datasetPath,
                const bool loadTrainData = true,
                const bool shuffle = true,
-               const double ratio = 0.75,
+               const double testRatio = 0.25,
                const bool useScaler = false,
                const bool dropHeader = false,
                const int startInputFeatures = -1,
diff --git a/dataloader/dataloader_impl.hpp b/dataloader/dataloader_impl.hpp
index 75b841c2..b514fee9 100644
--- a/dataloader/dataloader_impl.hpp
+++ b/dataloader/dataloader_impl.hpp
@@ -36,7 +36,7 @@ template<
   DatasetX, DatasetY, ScalerType
 >::DataLoader(const std::string& dataset,
               const bool shuffle,
-              const double ratio,
+              const double testRatio,
               const bool useScaler,
               const std::vector<std::string> augmentation,
               const double augmentationProbability)
@@ -49,14 +49,14 @@ template<
   if (datasetMap[dataset].loadCSV)
   {
-    LoadCSV(datasetMap[dataset].trainPath, true, shuffle, ratio, useScaler,
-        datasetMap[dataset].dropHeader,
+    LoadCSV(datasetMap[dataset].trainPath, true, shuffle, testRatio,
+        useScaler, datasetMap[dataset].dropHeader,
         datasetMap[dataset].startTrainingInputFeatures,
         datasetMap[dataset].endTrainingInputFeatures,
         datasetMap[dataset].endTrainingPredictionFeatures,
         datasetMap[dataset].endTrainingPredictionFeatures);

-    LoadCSV(datasetMap[dataset].testPath, false, false, ratio, useScaler,
+    LoadCSV(datasetMap[dataset].testPath, false, false, testRatio, useScaler,
         datasetMap[dataset].dropHeader,
         datasetMap[dataset].startTestingInputFeatures,
         datasetMap[dataset].endTestingInputFeatures);

@@ -85,7 +85,7 @@ template<
 >::LoadCSV(const std::string& datasetPath,
            const bool loadTrainData,
            const bool shuffle,
-           const double ratio,
+           const double testRatio,
            const bool useScaler,
            const bool dropHeader,
            const int startInputFeatures,
@@ -104,7 +104,7 @@ template<
   if (loadTrainData)
   {
     arma::mat trainDataset, validDataset;
-    data::Split(dataset, trainDataset, validDataset, ratio, shuffle);
+    data::Split(dataset, trainDataset, validDataset, testRatio, shuffle);

     trainFeatures = trainDataset.rows(WrapIndex(startInputFeatures,
         trainDataset.n_rows), WrapIndex(endInputFeatures,
diff --git a/tests/dataloader_tests.cpp b/tests/dataloader_tests.cpp
index 00c35c5b..8a6f322c 100644
--- a/tests/dataloader_tests.cpp
+++ b/tests/dataloader_tests.cpp
@@ -66,11 +66,13 @@ BOOST_AUTO_TEST_CASE(MNISTDataLoaderTest)
   BOOST_REQUIRE_EQUAL(dataloader.TestFeatures().n_rows, 784);
   BOOST_REQUIRE_EQUAL(dataloader.ValidFeatures().n_rows, 784);

-
+  // Check for correct dimensions.
   BOOST_REQUIRE_EQUAL(dataloader.TrainFeatures().n_cols, 8400);
   BOOST_REQUIRE_EQUAL(dataloader.ValidFeatures().n_cols, 33600);
   BOOST_REQUIRE_EQUAL(dataloader.TestFeatures().n_cols, 28000);

+  // Check if we can access both features and labels using
+  // TrainSet tuple and ValidSet tuple.
   BOOST_REQUIRE_EQUAL(std::get<0>(dataloader.TrainSet()).n_cols, 8400);
   BOOST_REQUIRE_EQUAL(std::get<1>(dataloader.TrainSet()).n_rows, 1);
   BOOST_REQUIRE_EQUAL(std::get<0>(dataloader.ValidSet()).n_cols, 33600);
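
For context, a minimal usage sketch of the renamed parameter follows. The accessors TrainFeatures(), ValidFeatures(), and TrainSet() appear in the test hunk above; the default template arguments of DataLoader<>, the "mnist" dataset key, and the include path are assumptions, not shown in this diff.

    #include <tuple>

    #include <dataloader/dataloader.hpp>  // Assumed include path.

    int main()
    {
      // Hold out 25% of the training CSV as a validation set. Before this
      // change the third argument was ratio (default 0.75), the fraction
      // kept for training; testRatio (default 0.25) is the fraction split
      // off for validation.
      DataLoader<> dataloader("mnist", /* shuffle */ true, /* testRatio */ 0.25);

      // Features and labels are available separately or bundled as a tuple,
      // as exercised by the test assertions above.
      arma::mat trainX = dataloader.TrainFeatures();
      arma::mat trainY = std::get<1>(dataloader.TrainSet());
      arma::mat validX = dataloader.ValidFeatures();

      return 0;
    }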