Skip to content

Commit

Permalink
Merge pull request #384 from rcurtin/callbacks-use-results
Browse files Browse the repository at this point in the history
Fix some callbacks that ignored return values
  • Loading branch information
rcurtin authored Nov 26, 2023
2 parents cbc90d8 + d361b86 commit 7ade3e9
Show file tree
Hide file tree
Showing 12 changed files with 360 additions and 203 deletions.
8 changes: 4 additions & 4 deletions .appveyor.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,6 @@ environment:
BLAS_LIBRARY_DLL: "%APPVEYOR_BUILD_FOLDER%/OpenBLAS.0.2.14.1/lib/native/lib/x64/libopenblas.dll"

matrix:
- APPVEYOR_BUILD_WORKER_IMAGE: Visual Studio 2015
VSVER: Visual Studio 14 2015 Win64
MSBUILD: C:\Program Files (x86)\MSBuild\14.0\bin\MSBuild.exe

- APPVEYOR_BUILD_WORKER_IMAGE: Visual Studio 2017
VSVER: Visual Studio 15 2017 Win64
MSBUILD: C:\Program Files (x86)\Microsoft Visual Studio\2017\Community\MSBuild\15.0\Bin\MSBuild.exe
Expand All @@ -16,6 +12,10 @@ environment:
VSVER: Visual Studio 16 2019
MSBUILD: C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\MSBuild\Current\Bin\MSBuild.exe

- APPVEYOR_BUILD_WORKER_IMAGE: Visual Studio 2022
VSVER: Visual Studio 17 2022
MSBUILD: C:\Program Files\Microsoft Visual Studio\2022\Community\MSBuild\Current\Bin\MSBuild.exe

configuration: Release

install:
Expand Down
227 changes: 152 additions & 75 deletions include/ensmallen_bits/callbacks/callbacks.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,7 @@ class Callback
typename MatType>
static typename std::enable_if<
callbacks::traits::HasBeginOptimizationSignature<
// Check for boolean return values anyway, for older ensmallen callbacks.
// (The return value is ignored.)
CallbackType, OptimizerType, FunctionType, MatType>::hasBool,
CallbackType, OptimizerType, FunctionType, MatType>::value,
void>::type
BeginOptimizationFunction(CallbackType& callback,
OptimizerType& optimizer,
Expand All @@ -85,25 +83,8 @@ class Callback
typename FunctionType,
typename MatType>
static typename std::enable_if<
callbacks::traits::HasBeginOptimizationSignature<
CallbackType, OptimizerType, FunctionType, MatType>::hasVoid,
void>::type
BeginOptimizationFunction(CallbackType& callback,
OptimizerType& optimizer,
FunctionType& function,
MatType& coordinates)
{
const_cast<CallbackType&>(callback).BeginOptimization(optimizer, function,
coordinates);
}

template<typename CallbackType,
typename OptimizerType,
typename FunctionType,
typename MatType>
static typename std::enable_if<
callbacks::traits::HasBeginOptimizationSignature<
CallbackType, OptimizerType, FunctionType, MatType>::hasNone,
!callbacks::traits::HasBeginOptimizationSignature<
CallbackType, OptimizerType, FunctionType, MatType>::value,
void>::type
BeginOptimizationFunction(CallbackType& /* callback */,
OptimizerType& /* optimizer */,
Expand Down Expand Up @@ -175,7 +156,8 @@ class Callback
typename OptimizerType,
typename FunctionType,
typename MatType>
static typename std::enable_if<!callbacks::traits::HasEndOptimizationSignature<
static typename std::enable_if<
!callbacks::traits::HasEndOptimizationSignature<
CallbackType, OptimizerType, FunctionType, MatType>::value,
void>::type
EndOptimizationFunction(CallbackType& /* callback */,
Expand Down Expand Up @@ -234,24 +216,42 @@ class Callback
typename FunctionType,
typename MatType>
static typename std::enable_if<callbacks::traits::HasEvaluateSignature<
CallbackType, OptimizerType, FunctionType, MatType>::value,
CallbackType, OptimizerType, FunctionType, MatType>::hasBool,
bool>::type
EvaluateFunction(CallbackType& callback,
OptimizerType& optimizer,
FunctionType& function,
const MatType& coordinates,
const double objective)
{
return (const_cast<CallbackType&>(callback).Evaluate(
optimizer, function, coordinates, objective), false);
return const_cast<CallbackType&>(callback).Evaluate(optimizer, function,
coordinates, objective);
}

template<typename CallbackType,
typename OptimizerType,
typename FunctionType,
typename MatType>
static typename std::enable_if<!callbacks::traits::HasEvaluateSignature<
CallbackType, OptimizerType, FunctionType, MatType>::value,
static typename std::enable_if<callbacks::traits::HasEvaluateSignature<
CallbackType, OptimizerType, FunctionType, MatType>::hasVoid,
bool>::type
EvaluateFunction(CallbackType& callback,
OptimizerType& optimizer,
FunctionType& function,
const MatType& coordinates,
const double objective)
{
const_cast<CallbackType&>(callback).Evaluate(optimizer, function,
coordinates, objective);
return false;
}

template<typename CallbackType,
typename OptimizerType,
typename FunctionType,
typename MatType>
static typename std::enable_if<callbacks::traits::HasEvaluateSignature<
CallbackType, OptimizerType, FunctionType, MatType>::hasNone,
bool>::type
EvaluateFunction(CallbackType& /* callback */,
OptimizerType& /* optimizer */,
Expand Down Expand Up @@ -304,7 +304,7 @@ class Callback
typename MatType>
static typename std::enable_if<
callbacks::traits::HasEvaluateConstraintSignature<
CallbackType, OptimizerType, FunctionType, MatType>::value,
CallbackType, OptimizerType, FunctionType, MatType>::hasBool,
bool>::type
EvaluateConstraintFunction(CallbackType& callback,
OptimizerType& optimizer,
Expand All @@ -313,17 +313,37 @@ class Callback
const size_t constraint,
const double constraintValue)
{
return (const_cast<CallbackType&>(callback).EvaluateConstraint(
optimizer, function, coordinates, constraint, constraintValue), false);
return const_cast<CallbackType&>(callback).EvaluateConstraint(
optimizer, function, coordinates, constraint, constraintValue);
}

template<typename CallbackType,
typename OptimizerType,
typename FunctionType,
typename MatType>
static typename std::enable_if<
!callbacks::traits::HasEvaluateConstraintSignature<
CallbackType, OptimizerType, FunctionType, MatType>::value,
callbacks::traits::HasEvaluateConstraintSignature<
CallbackType, OptimizerType, FunctionType, MatType>::hasVoid,
bool>::type
EvaluateConstraintFunction(CallbackType& callback,
OptimizerType& optimizer,
FunctionType& function,
const MatType& coordinates,
const size_t constraint,
const double constraintValue)
{
const_cast<CallbackType&>(callback).EvaluateConstraint(
optimizer, function, coordinates, constraint, constraintValue);
return false;
}

template<typename CallbackType,
typename OptimizerType,
typename FunctionType,
typename MatType>
static typename std::enable_if<
callbacks::traits::HasEvaluateConstraintSignature<
CallbackType, OptimizerType, FunctionType, MatType>::hasNone,
bool>::type
EvaluateConstraintFunction(CallbackType& /* callback */,
OptimizerType& /* optimizer */,
Expand Down Expand Up @@ -380,25 +400,44 @@ class Callback
typename MatType,
typename GradType>
static typename std::enable_if<callbacks::traits::HasGradientSignature<
CallbackType, OptimizerType, FunctionType, MatType, GradType>::value,
CallbackType, OptimizerType, FunctionType, MatType, GradType>::hasBool,
bool>::type
GradientFunction(CallbackType& callback,
OptimizerType& optimizer,
FunctionType& function,
const MatType& coordinates,
GradType& gradient)
{
return (const_cast<CallbackType&>(callback).Gradient(
optimizer, function, coordinates, gradient), false);
return const_cast<CallbackType&>(callback).Gradient(optimizer, function,
coordinates, gradient);
}

template<typename CallbackType,
typename OptimizerType,
typename FunctionType,
typename MatType,
typename GradType>
static typename std::enable_if<!callbacks::traits::HasGradientSignature<
CallbackType, OptimizerType, FunctionType, MatType, GradType>::value,
static typename std::enable_if<callbacks::traits::HasGradientSignature<
CallbackType, OptimizerType, FunctionType, MatType, GradType>::hasVoid,
bool>::type
GradientFunction(CallbackType& callback,
OptimizerType& optimizer,
FunctionType& function,
const MatType& coordinates,
GradType& gradient)
{
const_cast<CallbackType&>(callback).Gradient(
optimizer, function, coordinates, gradient);
return false;
}

template<typename CallbackType,
typename OptimizerType,
typename FunctionType,
typename MatType,
typename GradType>
static typename std::enable_if<callbacks::traits::HasGradientSignature<
CallbackType, OptimizerType, FunctionType, MatType, GradType>::hasNone,
bool>::type
GradientFunction(CallbackType& /* callback */,
OptimizerType& /* optimizer */,
Expand Down Expand Up @@ -451,7 +490,27 @@ class Callback
typename GradType>
static typename std::enable_if<
callbacks::traits::HasGradientConstraintSignature<
CallbackType, OptimizerType, FunctionType, MatType, GradType>::value,
CallbackType, OptimizerType, FunctionType, MatType, GradType>::hasBool,
bool>::type
GradientConstraintFunction(CallbackType& callback,
OptimizerType& optimizer,
FunctionType& function,
const MatType& coordinates,
const size_t constraint,
GradType& gradient)
{
return const_cast<CallbackType&>(callback).GradientConstraint(optimizer,
function, coordinates, constraint, gradient);
}

template<typename CallbackType,
typename OptimizerType,
typename FunctionType,
typename MatType,
typename GradType>
static typename std::enable_if<
callbacks::traits::HasGradientConstraintSignature<
CallbackType, OptimizerType, FunctionType, MatType, GradType>::hasVoid,
bool>::type
GradientConstraintFunction(CallbackType& callback,
OptimizerType& optimizer,
Expand All @@ -460,8 +519,9 @@ class Callback
const size_t constraint,
GradType& gradient)
{
return (const_cast<CallbackType&>(callback).GradientConstraint(
optimizer, function, coordinates, constraint, gradient), false);
const_cast<CallbackType&>(callback).GradientConstraint(
optimizer, function, coordinates, constraint, gradient);
return false;
}

template<typename CallbackType,
Expand All @@ -470,8 +530,8 @@ class Callback
typename MatType,
typename GradType>
static typename std::enable_if<
!callbacks::traits::HasGradientConstraintSignature<
CallbackType, OptimizerType, FunctionType, MatType, GradType>::value,
callbacks::traits::HasGradientConstraintSignature<
CallbackType, OptimizerType, FunctionType, MatType, GradType>::hasNone,
bool>::type
GradientConstraintFunction(CallbackType& /* callback */,
OptimizerType& /* optimizer */,
Expand Down Expand Up @@ -563,24 +623,42 @@ class Callback
typename FunctionType,
typename MatType>
static typename std::enable_if<callbacks::traits::HasBeginEpochSignature<
CallbackType, OptimizerType, FunctionType, MatType>::value, bool>::type
CallbackType, OptimizerType, FunctionType, MatType>::hasBool, bool>::type
BeginEpochFunction(CallbackType& callback,
OptimizerType& optimizer,
FunctionType& function,
const MatType& coordinates,
const size_t epoch,
const double objective)
{
return const_cast<CallbackType&>(callback).BeginEpoch(
optimizer, function, coordinates, epoch, objective);
}

template<typename CallbackType,
typename OptimizerType,
typename FunctionType,
typename MatType>
static typename std::enable_if<callbacks::traits::HasBeginEpochSignature<
CallbackType, OptimizerType, FunctionType, MatType>::hasVoid, bool>::type
BeginEpochFunction(CallbackType& callback,
OptimizerType& optimizer,
FunctionType& function,
const MatType& coordinates,
const size_t epoch,
const double objective)
{
return (const_cast<CallbackType&>(callback).BeginEpoch(
optimizer, function, coordinates, epoch, objective), false);
const_cast<CallbackType&>(callback).BeginEpoch(
optimizer, function, coordinates, epoch, objective);
return false;
}

template<typename CallbackType,
typename OptimizerType,
typename FunctionType,
typename MatType>
static typename std::enable_if<!callbacks::traits::HasBeginEpochSignature<
CallbackType, OptimizerType, FunctionType, MatType>::value, bool>::type
static typename std::enable_if<callbacks::traits::HasBeginEpochSignature<
CallbackType, OptimizerType, FunctionType, MatType>::hasNone, bool>::type
BeginEpochFunction(CallbackType& /* callback */,
OptimizerType& /* optimizer */,
FunctionType& /* function */,
Expand Down Expand Up @@ -768,6 +846,32 @@ class Callback
MatType& /* coordinates */)
{ return false; }

/**
* Iterate over the callbacks and invoke the StepTaken() callback if it
* exists.
*
* @param optimizer The optimizer used to update the function.
* @param function Function to optimize.
* @param coordinates Starting point.
* @param callbacks The callbacks container.
*/
template<typename OptimizerType,
typename FunctionType,
typename MatType,
typename... CallbackTypes>
static bool StepTaken(OptimizerType& optimizer,
FunctionType& function,
MatType& coordinates,
CallbackTypes&... callbacks)
{
// This will return immediately once a callback returns true.
bool result = false;
(void)std::initializer_list<bool>{ result =
result || Callback::StepTakenFunction(callbacks, optimizer,
function, coordinates)... };
return result;
}

/**
* Invoke the GenerationalStepTaken() callback if it exists.
* Specialization for MultiObjective case.
Expand Down Expand Up @@ -819,7 +923,6 @@ class Callback
{
const_cast<CallbackType&>(callback).GenerationalStepTaken(
optimizer, function, coordinates, objectives, frontIndices);

return false;
}

Expand All @@ -841,32 +944,6 @@ class Callback
IndicesType& /* frontIndices */)
{ return false; }

/**
* Iterate over the callbacks and invoke the StepTaken() callback if it
* exists.
*
* @param optimizer The optimizer used to update the function.
* @param function Function to optimize.
* @param coordinates Starting point.
* @param callbacks The callbacks container.
*/
template<typename OptimizerType,
typename FunctionType,
typename MatType,
typename... CallbackTypes>
static bool StepTaken(OptimizerType& optimizer,
FunctionType& function,
MatType& coordinates,
CallbackTypes&... callbacks)
{
// This will return immediately once a callback returns true.
bool result = false;
(void)std::initializer_list<bool>{ result =
result || Callback::StepTakenFunction(callbacks, optimizer,
function, coordinates)... };
return result;
}

/**
* Iterate over the callbacks and invoke the GenerationalStepTaken() callback if it
* exists.
Expand Down
Loading

0 comments on commit 7ade3e9

Please sign in to comment.