Skip to content

Commit

Permalink
Change variable to test ratio and add more comments in tests
Browse files Browse the repository at this point in the history
  • Loading branch information
kartikdutt18 committed May 31, 2020
1 parent 6ae432e commit 040decb
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 11 deletions.
8 changes: 4 additions & 4 deletions dataloader/dataloader.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,14 +63,14 @@ class DataLoader
*
* @param dataset Path or name of dataset.
* @param shuffle Whether or not to shuffle the data.
* @param ratio Ratio for train-test split.
* @param testRatio Ratio of dataset to be used for validation set.
* @param useScaler Use feature scaler for pre-processing the dataset.
* @param augmentation Adds augmentation to training data only.
* @param augmentationProbability Probability of applying augmentation on dataset.
*/
DataLoader(const std::string& dataset,
const bool shuffle,
const double ratio = 0.75,
const double testRatio = 0.25,
const bool useScaler = true,
const std::vector<std::string> augmentation =
std::vector<std::string>(),
Expand All @@ -85,7 +85,7 @@ class DataLoader
* Note: This option sets augmentation to NULL, sets ratio to 1, and the
* scaler will be used only to transform the test data.
* @param shuffle Boolean to determine whether or not to shuffle the data.
* @param ratio Ratio for train-test split.
* @param testRatio Ratio of dataset to be used for validation set.
* @param useScaler Fits the scaler on training data and transforms dataset.
* @param dropHeader Drops the first row from CSV.
* @param startInputFeatures First Index which will be fed into the model as input.
Expand All @@ -106,7 +106,7 @@ class DataLoader
void LoadCSV(const std::string& datasetPath,
const bool loadTrainData = true,
const bool shuffle = true,
const double ratio = 0.75,
const double testRatio = 0.25,
const bool useScaler = false,
const bool dropHeader = false,
const int startInputFeatures = -1,
Expand Down
12 changes: 6 additions & 6 deletions dataloader/dataloader_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ template<
DatasetX, DatasetY, ScalerType
>::DataLoader(const std::string& dataset,
const bool shuffle,
const double ratio,
const double testRatio,
const bool useScaler,
const std::vector<std::string> augmentation,
const double augmentationProbability)
Expand All @@ -49,14 +49,14 @@ template<

if (datasetMap[dataset].loadCSV)
{
LoadCSV(datasetMap[dataset].trainPath, true, shuffle, ratio, useScaler,
datasetMap[dataset].dropHeader,
LoadCSV(datasetMap[dataset].trainPath, true, shuffle, testRatio,
useScaler, datasetMap[dataset].dropHeader,
datasetMap[dataset].startTrainingInputFeatures,
datasetMap[dataset].endTrainingInputFeatures,
datasetMap[dataset].endTrainingPredictionFeatures,
datasetMap[dataset].endTrainingPredictionFeatures);

LoadCSV(datasetMap[dataset].testPath, false, false, ratio, useScaler,
LoadCSV(datasetMap[dataset].testPath, false, false, testRatio, useScaler,
datasetMap[dataset].dropHeader,
datasetMap[dataset].startTestingInputFeatures,
datasetMap[dataset].endTestingInputFeatures);
Expand Down Expand Up @@ -85,7 +85,7 @@ template<
>::LoadCSV(const std::string& datasetPath,
const bool loadTrainData,
const bool shuffle,
const double ratio,
const double testRatio,
const bool useScaler,
const bool dropHeader,
const int startInputFeatures,
Expand All @@ -104,7 +104,7 @@ template<
if (loadTrainData)
{
arma::mat trainDataset, validDataset;
data::Split(dataset, trainDataset, validDataset, ratio, shuffle);
data::Split(dataset, trainDataset, validDataset, testRatio, shuffle);

trainFeatures = trainDataset.rows(WrapIndex(startInputFeatures,
trainDataset.n_rows), WrapIndex(endInputFeatures,
Expand Down
4 changes: 3 additions & 1 deletion tests/dataloader_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,13 @@ BOOST_AUTO_TEST_CASE(MNISTDataLoaderTest)
BOOST_REQUIRE_EQUAL(dataloader.TestFeatures().n_rows, 784);
BOOST_REQUIRE_EQUAL(dataloader.ValidFeatures().n_rows, 784);


// Check for correct dimensions.
BOOST_REQUIRE_EQUAL(dataloader.TrainFeatures().n_cols, 8400);
BOOST_REQUIRE_EQUAL(dataloader.ValidFeatures().n_cols, 33600);
BOOST_REQUIRE_EQUAL(dataloader.TestFeatures().n_cols, 28000);

// Check if we can access both features and labels using
// TrainSet tuple and ValidSet tuple.
BOOST_REQUIRE_EQUAL(std::get<0>(dataloader.TrainSet()).n_cols, 8400);
BOOST_REQUIRE_EQUAL(std::get<1>(dataloader.TrainSet()).n_rows, 1);
BOOST_REQUIRE_EQUAL(std::get<0>(dataloader.ValidSet()).n_cols, 33600);
Expand Down

0 comments on commit 040decb

Please sign in to comment.