Skip to content

Commit

Permalink
Merge pull request #50 from rcurtin/grid_search_fixes
Browse files Browse the repository at this point in the history
Fix and include GridSearch by default
  • Loading branch information
rcurtin authored Nov 16, 2018
2 parents 21b8de7 + 6de21f4 commit 36eb6af
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 22 deletions.
3 changes: 1 addition & 2 deletions include/ensmallen.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@

#include "ensmallen_bits/problems/problems.hpp" // TODO: should move to another place

// TODO: remove mlpack bits from each of these files
#include "ensmallen_bits/ada_delta/ada_delta.hpp"
#include "ensmallen_bits/ada_grad/ada_grad.hpp"
#include "ensmallen_bits/adam/adam.hpp"
Expand All @@ -74,7 +73,7 @@

#include "ensmallen_bits/fw/frank_wolfe.hpp"
#include "ensmallen_bits/gradient_descent/gradient_descent.hpp"
// #include "ensmallen_bits/grid_search/grid_search.hpp"
#include "ensmallen_bits/grid_search/grid_search.hpp"
#include "ensmallen_bits/iqn/iqn.hpp"
#include "ensmallen_bits/katyusha/katyusha.hpp"
#include "ensmallen_bits/lbfgs/lbfgs.hpp"
Expand Down
22 changes: 2 additions & 20 deletions include/ensmallen_bits/grid_search/grid_search_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,31 +24,13 @@ double GridSearch::Optimize(
const std::vector<bool>& categoricalDimensions,
const arma::Row<size_t>& numCategories)
{
if (categoricalDimensions.size() != iterate.n_rows)
{
std::ostringstream oss;
oss << "GridSearch::Optimize(): expected information about "
<< iterate.n_rows << " dimensions in categoricalDimensions, "
<< "but got " << categoricalDimensions.size();
throw std::invalid_argument(oss.str());
}

if (numCategories.n_elem != iterate.n_rows)
{
std::ostringstream oss;
oss << "GridSearch::Optimize(): expected numCategories to have length "
<< "equal to number of dimensions (" << iterate.n_rows << ") but it has"
<< " length " << numCategories.n_elem;
throw std::invalid_argument(oss.str());
}

for (size_t i = 0; i < categoricalDimensions.size(); ++i)
{
if (categoricalDimensions[i])
if (!categoricalDimensions[i])
{
std::ostringstream oss;
oss << "GridSearch::Optimize(): the dimension " << i
<< "is not categorical" << std::endl;
<< " is not categorical" << std::endl;
throw std::invalid_argument(oss.str());
}
}
Expand Down
1 change: 1 addition & 0 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ set(ENSMALLEN_TESTS_SOURCES
frankwolfe_test.cpp
function_test.cpp
gradient_descent_test.cpp
grid_search_test.cpp
iqn_test.cpp
katyusha_test.cpp
lbfgs_test.cpp
Expand Down
70 changes: 70 additions & 0 deletions tests/grid_search_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
/**
* @file grid_search_test.cpp
* @author Ryan Curtin
*
* Test file for the GridSearch optimizer.
*
* ensmallen is free software; you may redistribute it and/or modify it under
* the terms of the 3-clause BSD license. You should have received a copy of
* the 3-clause BSD license along with ensmallen. If not, see
* http://www.opensource.org/licenses/BSD-3-Clause for more information.
*/

#include <ensmallen.hpp>
#include "catch.hpp"
#include "test_function_tools.hpp"

using namespace ens;
using namespace ens::test;

// An implementation of a simple categorical function. The parameters can be
// understood as x = [c1 c2 c3]. When c1 = 0, c2 = 2, and c3 = 1, the value of
// f(x) is 0. In any other case, the value of f(x) is 10. Therefore, the
// optimum is found at [0, 2, 1].
class SimpleCategoricalFunction
{
public:
// Return the objective function f(x) as described above.
double Evaluate(const arma::mat& x)
{
if (size_t(x[0]) == 0 &&
size_t(x[1]) == 2 &&
size_t(x[2]) == 1)
return 0.0;
else
return 10.0;
}
};

TEST_CASE("GridSearchTest", "[GridSearchTest]")
{
// Create and optimize the categorical function with the GridSearch
// optimizer. We must also create a std::vector<bool> that holds the types
// of each dimension, and an arma::Row<size_t> that holds the number of
// categories in each dimension.
SimpleCategoricalFunction c;

// We have three categorical dimensions only.
std::vector<bool> categoricalDimensions;
categoricalDimensions.push_back(true);
categoricalDimensions.push_back(true);
categoricalDimensions.push_back(true);

// The first category can take 5 values; the second can take 3; the third can
// take 12.
arma::Row<size_t> numCategories("5 3 12");

// The initial point for our optimization will be to set all categories to 0.
arma::mat params("0 0 0");

// Now create the GridSearch optimizer with default parameters, and run the
// optimization.
// The GridSearch type can be replaced with any ensmallen optimizer that
// is able to handle categorical functions.
GridSearch gs;
gs.Optimize(c, params, categoricalDimensions, numCategories);

REQUIRE(params[0] == 0);
REQUIRE(params[1] == 2);
REQUIRE(params[2] == 1);
}

0 comments on commit 36eb6af

Please sign in to comment.