diff --git a/tests/dataloader_tests.cpp b/tests/dataloader_tests.cpp index 8a6f322c..912cb1ab 100644 --- a/tests/dataloader_tests.cpp +++ b/tests/dataloader_tests.cpp @@ -76,7 +76,7 @@ BOOST_AUTO_TEST_CASE(MNISTDataLoaderTest) BOOST_REQUIRE_EQUAL(std::get<0>(dataloader.TrainSet()).n_cols, 8400); BOOST_REQUIRE_EQUAL(std::get<1>(dataloader.TrainSet()).n_rows, 1); BOOST_REQUIRE_EQUAL(std::get<0>(dataloader.ValidSet()).n_cols, 33600); - BOOST_REQUIRE_EQUAL(std::get<1>(dataloader.TrainSet()).n_rows, 1); + BOOST_REQUIRE_EQUAL(std::get<1>(dataloader.ValidSet()).n_rows, 1); // Clean up. Utils::RemoveFile("./../data/mnist-dataset/mnist_all.csv");