Skip to content

Commit

Permalink
Merge pull request #652 from SciML/Vaibhavdixit02-patch-4
Browse files Browse the repository at this point in the history
[WIP]Stats and State
  • Loading branch information
Vaibhavdixit02 authored Jan 5, 2024
2 parents 80d8465 + abfa548 commit 880fc54
Show file tree
Hide file tree
Showing 45 changed files with 210 additions and 126 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Optimization"
uuid = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
version = "3.20.2"
version = "3.21.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down Expand Up @@ -57,7 +57,7 @@ Printf = "1.9"
ProgressLogging = "0.1"
Reexport = "1.2"
ReverseDiff = "1.14"
SciMLBase = "2.11"
SciMLBase = "2.16.3"
SparseArrays = "1.9, 1.10"
SparseDiffTools = "2.14"
SymbolicIndexingInterface = "0.3"
Expand Down
4 changes: 2 additions & 2 deletions lib/OptimizationBBO/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "OptimizationBBO"
uuid = "3e6eede4-6085-4f62-9a71-46d9bc1eb92b"
authors = ["Vaibhav Dixit <[email protected]> and contributors"]
version = "0.1.5"
version = "0.2.0"

[deps]
BlackBoxOptim = "a134a8b2-14d6-55f6-9291-3336d3ab0209"
Expand All @@ -10,7 +10,7 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69"

[compat]
BlackBoxOptim = "0.6"
Optimization = "3.15"
Optimization = "3.21"
Reexport = "1.2"
julia = "1"

Expand Down
13 changes: 9 additions & 4 deletions lib/OptimizationBBO/src/OptimizationBBO.jl
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,10 @@ function SciMLBase.__solve(cache::Optimization.OptimizationCache{
if cache.callback === Optimization.DEFAULT_CALLBACK
cb_call = false
else
cb_call = cache.callback(decompose_trace(trace, cache.progress), x...)
n_steps = BlackBoxOptim.num_steps(trace)
curr_u = decompose_trace(trace, cache.progress)
opt_state = Optimization.OptimizationState(iteration = n_steps, u = curr_u, objective = x[1], solver_state = trace)
cb_call = cache.callback(opt_state, x...)
end

if !(cb_call isa Bool)
Expand Down Expand Up @@ -175,11 +178,13 @@ function SciMLBase.__solve(cache::Optimization.OptimizationCache{
t1 = time()

opt_ret = Symbol(opt_res.stop_reason)

stats = Optimization.OptimizationStats(; iterations = opt_res.iterations, time = t1 - t0, fevals = opt_res.f_calls)
SciMLBase.build_solution(cache, cache.opt,
BlackBoxOptim.best_candidate(opt_res),
BlackBoxOptim.best_fitness(opt_res); original = opt_res,
retcode = opt_ret, solve_time = t1 - t0)
BlackBoxOptim.best_fitness(opt_res);
original = opt_res,
retcode = opt_ret,
stats = stats)
end

end
4 changes: 2 additions & 2 deletions lib/OptimizationBBO/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ using Test
@test 10 * sol.objective < l1

fitness_progress_history = []
function cb(best_candidate, fitness)
push!(fitness_progress_history, [best_candidate, fitness])
function cb(state, fitness)
push!(fitness_progress_history, [state.u, fitness])
return false
end
sol = solve(prob, BBO_adaptive_de_rand_1_bin_radiuslimited(), callback = cb)
Expand Down
4 changes: 2 additions & 2 deletions lib/OptimizationCMAEvolutionStrategy/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "OptimizationCMAEvolutionStrategy"
uuid = "bd407f91-200f-4536-9381-e4ba712f53f8"
authors = ["Vaibhav Dixit <[email protected]> and contributors"]
version = "0.1.4"
version = "0.2.0"

[deps]
CMAEvolutionStrategy = "8d3b24bd-414e-49e0-94fb-163cc3a3e411"
Expand All @@ -11,7 +11,7 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
[compat]
julia = "1"
CMAEvolutionStrategy = "0.2"
Optimization = "3.15"
Optimization = "3.21"
Reexport = "1.2"

[extras]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ export CMAEvolutionStrategyOpt
struct CMAEvolutionStrategyOpt end

SciMLBase.allowsbounds(::CMAEvolutionStrategyOpt) = true
SciMLBase.allowscallback(::CMAEvolutionStrategyOpt) = false #looks like `logger` kwarg can be used to pass it, so should be implemented
SciMLBase.supports_opt_cache_interface(opt::CMAEvolutionStrategyOpt) = true

function __map_optimizer_args(prob::OptimizationCache, opt::CMAEvolutionStrategyOpt;
Expand All @@ -23,7 +22,7 @@ function __map_optimizer_args(prob::OptimizationCache, opt::CMAEvolutionStrategy
end

mapped_args = (; lower = prob.lb,
upper = prob.ub)
upper = prob.ub, logger = CMAEvolutionStrategy.BasicLogger(prob.u0; verbosity = 0, callback = callback))

if !isnothing(maxiters)
mapped_args = (; mapped_args..., maxiter = maxiters)
Expand Down Expand Up @@ -74,12 +73,18 @@ function SciMLBase.__solve(cache::OptimizationCache{

cur, state = iterate(cache.data)

function _cb(trace)
cb_call = cache.callback(decompose_trace(trace).metadata["x"], trace.value...)
function _cb(opt, y, fvals, perm)
curr_u = opt.logger.xbest[end]
opt_state = Optimization.OptimizationState(; iteration = length(opt.logger.fmedian),
u = curr_u,
objective = opt.logger.fbest[end],
solver_state = opt.logger)

cb_call = cache.callback(opt_state, x...)
if !(cb_call isa Bool)
error("The callback should return a boolean `halt` for whether to stop the optimization process.")
end
cur, state = iterate(data, state)
cur, state = iterate(cache.data, state)
cb_call
end

Expand All @@ -100,11 +105,12 @@ function SciMLBase.__solve(cache::OptimizationCache{
t1 = time()

opt_ret = opt_res.stop.reason

stats = Optimization.OptimizationStats(; iterations = length(opt_res.logger.fmedian), time = t1 - t0, fevals = length(opt_res.logger.fmedian))
SciMLBase.build_solution(cache, cache.opt,
opt_res.logger.xbest[end],
opt_res.logger.fbest[end]; original = opt_res,
retcode = opt_ret, solve_time = t1 - t0)
retcode = opt_ret,
stats = stats)
end

end
8 changes: 8 additions & 0 deletions lib/OptimizationCMAEvolutionStrategy/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,12 @@ using Test
prob = OptimizationProblem(f, x0, _p, lb = [-1.0, -1.0], ub = [0.8, 0.8])
sol = solve(prob, CMAEvolutionStrategyOpt())
@test 10 * sol.objective < l1

function cb(state, args...)
if state.iteration %10 == 0
println(state.u)
end
return false
end
sol = solve(prob, CMAEvolutionStrategyOpt(), callback = cb, maxiters = 100)
end
4 changes: 2 additions & 2 deletions lib/OptimizationEvolutionary/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "OptimizationEvolutionary"
uuid = "cb963754-43f6-435e-8d4b-99009ff27753"
authors = ["Vaibhav Dixit <[email protected]> and contributors"]
version = "0.1.3"
version = "0.2.0"

[deps]
Evolutionary = "86b6b26d-c046-49b6-aa0b-5f0f74682bd6"
Expand All @@ -11,7 +11,7 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
[compat]
julia = "1"
Evolutionary = "0.11"
Optimization = "3.15"
Optimization = "3.21"
Reexport = "1.2"

[extras]
Expand Down
13 changes: 10 additions & 3 deletions lib/OptimizationEvolutionary/src/OptimizationEvolutionary.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,12 @@ function SciMLBase.__solve(cache::OptimizationCache{
cur, state = iterate(cache.data)

function _cb(trace)
cb_call = cache.callback(decompose_trace(trace).metadata["x"], trace.value...)
curr_u = decompose_trace(trace).metadata["x"][end]
opt_state = Optimization.OptimizationState(; iteration = decompose_trace(trace).iteration,
u = curr_u,
objective = x[1],
solver_state = trace)
cb_call = cache.callback(opt_state, trace.value...)
if !(cb_call isa Bool)
error("The callback should return a boolean `halt` for whether to stop the optimization process.")
end
Expand Down Expand Up @@ -127,11 +132,13 @@ function SciMLBase.__solve(cache::OptimizationCache{
end
t1 = time()
opt_ret = Symbol(Evolutionary.converged(opt_res))

stats = Optimization.OptimizationStats(; iterations = opt_res.iterations
, time = t1 - t0, fevals = opt_res.f_calls)
SciMLBase.build_solution(cache, cache.opt,
Evolutionary.minimizer(opt_res),
Evolutionary.minimum(opt_res); original = opt_res,
retcode = opt_ret, solve_time = t1 - t0)
retcode = opt_ret,
stats = stats)
end

end
8 changes: 8 additions & 0 deletions lib/OptimizationEvolutionary/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,12 @@ Random.seed!(1234)
res = zeros(1)
cons_circ(res, sol.u, nothing)
@test sol.objective < l1

function cb(state, args...)
if state.iteration %10 == 0
println(state.u)
end
return false
end
sol = solve(prob, CMAES= 40, λ = 100), callback = cb, maxiters = 100)
end
4 changes: 2 additions & 2 deletions lib/OptimizationFlux/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "OptimizationFlux"
uuid = "253f991c-a7b2-45f8-8852-8b9a9df78a86"
authors = ["Vaibhav Dixit <[email protected]> and contributors"]
version = "0.1.5"
version = "0.2.0"

[deps]
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Expand All @@ -15,7 +15,7 @@ julia = "1"
Flux = "0.13, 0.14"
ProgressLogging = "0.1"
Reexport = "1.2"
Optimization = "3.15"
Optimization = "3.21"

[extras]
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Expand Down
18 changes: 14 additions & 4 deletions lib/OptimizationFlux/src/OptimizationFlux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ function SciMLBase.__solve(cache::OptimizationCache{
P,
C,
}
local i
if cache.data != Optimization.DEFAULT_DATA
maxiters = length(cache.data)
data = cache.data
Expand All @@ -65,7 +66,11 @@ function SciMLBase.__solve(cache::OptimizationCache{
for (i, d) in enumerate(data)
cache.f.grad(G, θ, d...)
x = cache.f(θ, cache.p, d...)
cb_call = cache.callback(θ, x...)
opt_state = Optimization.OptimizationState(; iteration = i,
u = θ,
objective = x[1],
solver_state = opt)
cb_call = cache.callback(opt_state, x...)
if !(cb_call isa Bool)
error("The callback should return a boolean `halt` for whether to stop the optimization process. Please see the sciml_train documentation for information.")
elseif cb_call
Expand All @@ -84,7 +89,11 @@ function SciMLBase.__solve(cache::OptimizationCache{
opt = min_opt
x = min_err
θ = min_θ
cache.callback(θ, x...)
opt_state = Optimization.OptimizationState(; iteration = i,
u = θ,
objective = x[1],
solver_state = opt)
cache.callback(opt_state, x...)
break
end
end
Expand All @@ -93,8 +102,9 @@ function SciMLBase.__solve(cache::OptimizationCache{
end

t1 = time()

SciMLBase.build_solution(cache, opt, θ, x[1], solve_time = t1 - t0)
stats = Optimization.OptimizationStats(; iterations = maxiters,
time = t1 - t0, fevals = maxiters, gevals = maxiters)
SciMLBase.build_solution(cache, opt, θ, x[1], stats = stats)
# here should be build_solution to create the output message
end

Expand Down
8 changes: 8 additions & 0 deletions lib/OptimizationFlux/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,12 @@ using Test
sol = Optimization.solve!(cache)
@test sol.u[2.0] atol=1e-3
end

function cb(state, args...)
if state.iteration % 10 == 0
println(state.u)
end
return false
end
sol = solve(prob, Flux.Adam(0.1), callback = cb, maxiters = 100, progress = false)
end
4 changes: 2 additions & 2 deletions lib/OptimizationGCMAES/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "OptimizationGCMAES"
uuid = "6f0a0517-dbc2-4a7a-8a20-99ae7f27e911"
authors = ["Vaibhav Dixit <[email protected]> and contributors"]
version = "0.1.6"
version = "0.2.0"

[deps]
GCMAES = "4aa9d100-eb0f-11e8-15f1-25748831eb3b"
Expand All @@ -10,7 +10,7 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69"

[compat]
julia = "1"
Optimization = "3.15"
Optimization = "3.21"
GCMAES = "0.1"
Reexport = "1.2"

Expand Down
8 changes: 5 additions & 3 deletions lib/OptimizationGCMAES/src/OptimizationGCMAES.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ function __map_optimizer_args(cache::OptimizationCache, opt::GCMAESOpt;
end

if !(isnothing(maxtime))
@warn "common maxtime is currently not used by $(opt)"
mapped_args = (; mapped_args..., maxtime = maxtime)
end

if !isnothing(abstol)
Expand Down Expand Up @@ -114,10 +114,12 @@ function SciMLBase.__solve(cache::OptimizationCache{
cache.ub; opt_args...)
end
t1 = time()

stats = Optimization.OptimizationStats(; iterations = maxiters === nothing ? 0 : maxiters,
time = t1 - t0)
SciMLBase.build_solution(cache, cache.opt,
opt_xmin, opt_fmin; retcode = Symbol(Bool(opt_ret)),
solve_time = t1 - t0)
stats = stats
)
end

end
4 changes: 2 additions & 2 deletions lib/OptimizationMOI/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "OptimizationMOI"
uuid = "fd9f6733-72f4-499f-8506-86b2bdd0dea1"
authors = ["Vaibhav Dixit <[email protected]> and contributors"]
version = "0.2.0"
version = "0.3.0"

[deps]
Ipopt_jll = "9cc047cb-c261-5740-88fc-0cf96f7bdcc7"
Expand All @@ -18,7 +18,7 @@ Ipopt_jll = "=300.1400.400"
Juniper = "0.9"
MathOptInterface = "1"
ModelingToolkit = "8.74"
Optimization = "3.15"
Optimization = "3.21"
Reexport = "1.2"
SymbolicIndexingInterface = "0.3"
Symbolics = "5"
Expand Down
4 changes: 0 additions & 4 deletions lib/OptimizationMOI/src/OptimizationMOI.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,6 @@ function SciMLBase.allowsconstraints(opt::Union{MOI.AbstractOptimizer,
MOI.OptimizerWithAttributes})
true
end
# function SciMLBase.allowscallback(opt::Union{MOI.AbstractOptimizer,
# MOI.OptimizerWithAttributes})
# false
# end

function _create_new_optimizer(opt::MOI.OptimizerWithAttributes)
return _create_new_optimizer(MOI.instantiate(opt, with_bridge_type = Float64))
Expand Down
4 changes: 3 additions & 1 deletion lib/OptimizationMOI/src/moi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -161,12 +161,14 @@ function SciMLBase.__solve(cache::MOIOptimizationCache)
minimum = NaN
opt_ret = SciMLBase.ReturnCode.Default
end
stats = Optimization.OptimizationStats()
return SciMLBase.build_solution(cache,
cache.opt,
minimizer,
minimum;
original = opt_setup,
retcode = opt_ret)
retcode = opt_ret,
stats = stats)
end

function get_moi_function(expr)
Expand Down
Loading

0 comments on commit 880fc54

Please sign in to comment.