[WIP] Make auglag generic and reusable with all solvers #833

Open

wants to merge 2 commits into base: master

1 change: 1 addition & 0 deletions src/Optimization.jl
@@ -24,6 +24,7 @@ include("utils.jl")
include("state.jl")
include("lbfgsb.jl")
include("sophia.jl")
include("auglag.jl")

export solve

179 changes: 179 additions & 0 deletions src/auglag.jl
@@ -0,0 +1,179 @@
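# Generic augmented Lagrangian wrapper: `inner` is the gradient-based solver used for
# the unconstrained subproblems. The remaining fields control the outer loop: τ is the
# required reduction factor in constraint violation before the penalty ρ is grown by γ,
# λmin/λmax and μmin/μmax clamp the equality and inequality multipliers, and ϵ is the
# convergence tolerance on the constraint violation.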
@kwdef struct AugLag
inner
τ = 0.5
γ = 10.0
λmin = -1e20
λmax = 1e20
μmin = 0.0
μmax = 1e20
ϵ = 1e-8
end

SciMLBase.supports_opt_cache_interface(::AugLag) = true
SciMLBase.allowsbounds(::AugLag) = true
SciMLBase.requiresgradient(::AugLag) = true
SciMLBase.allowsconstraints(::AugLag) = true
SciMLBase.requiresconsjac(::AugLag) = true

function __map_optimizer_args(cache::Optimization.OptimizationCache, opt::AugLag;
callback = nothing,
maxiters::Union{Number, Nothing} = nothing,
maxtime::Union{Number, Nothing} = nothing,
abstol::Union{Number, Nothing} = nothing,
reltol::Union{Number, Nothing} = nothing,
verbose::Bool = false,
kwargs...)
if !isnothing(abstol)
@warn "common abstol is currently not used by $(opt)"
end
if !isnothing(maxtime)
@warn "common abstol is currently not used by $(opt)"
end

mapped_args = (;)

if cache.lb !== nothing && cache.ub !== nothing
mapped_args = (; mapped_args..., lb = cache.lb, ub = cache.ub)
end

if !isnothing(maxiters)
mapped_args = (; mapped_args..., maxiter = maxiters)
end

if !isnothing(reltol)
mapped_args = (; mapped_args..., pgtol = reltol)
end

return mapped_args
end

function SciMLBase.__solve(cache::OptimizationCache{
F,
RC,
LB,
UB,
LC,
UC,
S,
O,
D,
P,
C
}) where {
F,
RC,
LB,
UB,
LC,
UC,
S,
O <:
AugLag,
D,
P,
C
}
maxiters = Optimization._check_and_convert_maxiters(cache.solver_args.maxiters)

local x

solver_kwargs = __map_optimizer_args(cache, cache.opt; maxiters, cache.solver_args...)

if !isnothing(cache.f.cons)
eq_inds = [cache.lcons[i] == cache.ucons[i] for i in eachindex(cache.lcons)]
ineq_inds = (!).(eq_inds)

τ = cache.opt.τ
γ = cache.opt.γ
λmin = cache.opt.λmin
λmax = cache.opt.λmax
μmin = cache.opt.μmin
μmax = cache.opt.μmax
ϵ = cache.opt.ϵ

λ = zeros(eltype(cache.u0), sum(eq_inds))
μ = zeros(eltype(cache.u0), sum(ineq_inds))

cons_tmp = zeros(eltype(cache.u0), length(cache.lcons))
cache.f.cons(cons_tmp, cache.u0)
ρ = max(1e-6, min(10, 2 * (abs(cache.f(cache.u0, iterate(cache.p)[1]))) / norm(cons_tmp)))

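# Augmented Lagrangian objective:
#   f(θ) + Σᵢ [λᵢ hᵢ(θ) + (ρ/2) hᵢ(θ)²] + (1/2ρ) Σⱼ max(0, μⱼ + ρ gⱼ(θ))²
# where hᵢ(θ) = cᵢ(θ) - lconsᵢ are the equality residuals and
# gⱼ(θ) = cⱼ(θ) - uconsⱼ are the inequality residuals.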
_loss = function (θ, p = cache.p)
x = cache.f(θ, p)
cons_tmp .= zero(eltype(θ))
cache.f.cons(cons_tmp, θ)
cons_tmp[eq_inds] .= cons_tmp[eq_inds] .- cache.lcons[eq_inds]
cons_tmp[ineq_inds] .= cons_tmp[ineq_inds] .- cache.ucons[ineq_inds]
opt_state = Optimization.OptimizationState(u = θ, objective = x[1])
if cache.callback(opt_state, x...)
error("Optimization halted by callback.")
end
return x[1] + sum(@. λ * cons_tmp[eq_inds] + ρ / 2 * (cons_tmp[eq_inds] .^ 2)) +
1 / (2 * ρ) * sum((max.(Ref(0.0), μ .+ (ρ .* cons_tmp[ineq_inds]))) .^ 2)
end

prev_eqcons = zero(λ)
θ = cache.u0
β = max.(cons_tmp[ineq_inds], Ref(0.0))
prevβ = zero(β)
eqidxs = [eq_inds[i] > 0 ? i : nothing for i in eachindex(eq_inds)]
ineqidxs = [ineq_inds[i] > 0 ? i : nothing for i in eachindex(ineq_inds)]
eqidxs = eqidxs[eqidxs .!= nothing]
ineqidxs = ineqidxs[ineqidxs .!= nothing]
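# Gradient of the augmented Lagrangian: the objective gradient plus the
# constraint-Jacobian terms for the equality and inequality multipliers.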
function aug_grad(G, θ, p)
cache.f.grad(G, θ, p)
if !isnothing(cache.f.cons_jac_prototype)
J = Float64.(cache.f.cons_jac_prototype)
else
J = zeros((length(cache.lcons), length(θ)))
end
cache.f.cons_j(J, θ)
__tmp = zero(cons_tmp)
cache.f.cons(__tmp, θ)
__tmp[eq_inds] .= __tmp[eq_inds] .- cache.lcons[eq_inds]
__tmp[ineq_inds] .= __tmp[ineq_inds] .- cache.ucons[ineq_inds]
G .+= sum(
λ[i] .* J[idx, :] + ρ * (__tmp[idx] .* J[idx, :])
for (i, idx) in enumerate(eqidxs);
init = zero(G)) #should be jvp
G .+= sum(
1 / ρ * (max.(Ref(0.0), μ[i] .+ (ρ .* __tmp[idx])) .* J[idx, :])
for (i, idx) in enumerate(ineqidxs);
init = zero(G)) #should be jvp
end

opt_ret = ReturnCode.MaxIters
n = length(cache.u0)

augprob = OptimizationProblem(OptimizationFunction(_loss; grad = aug_grad), cache.u0, cache.p)

solver_kwargs = Base.structdiff(solver_kwargs, (; lb = nothing, ub = nothing))

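# Outer augmented Lagrangian loop: solve the unconstrained subproblem with the inner
# solver, update the multipliers λ and μ, and increase the penalty ρ by γ whenever the
# constraint violation has not dropped by the factor τ.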
for i in 1:(maxiters ÷ 10)
prev_eqcons .= cons_tmp[eq_inds] .- cache.lcons[eq_inds]
prevβ .= copy(β)
res = solve(augprob, cache.opt.inner, maxiters = maxiters ÷ 10)
θ = res.u
cons_tmp .= 0.0
cache.f.cons(cons_tmp, θ)
λ = max.(min.(λmax, λ .+ ρ * (cons_tmp[eq_inds] .- cache.lcons[eq_inds])), λmin)
β = max.(cons_tmp[ineq_inds], -1 .* μ ./ ρ)
μ = min.(μmax, max.(μ .+ ρ * cons_tmp[ineq_inds], μmin))
if max(norm(cons_tmp[eq_inds] .- cache.lcons[eq_inds], Inf), norm(β, Inf)) >
τ * max(norm(prev_eqcons, Inf), norm(prevβ, Inf))
ρ = γ * ρ
end
if norm(
(cons_tmp[eq_inds] .- cache.lcons[eq_inds]) ./ cons_tmp[eq_inds], Inf) <
ϵ && norm(β, Inf) < ϵ
opt_ret = ReturnCode.Success
break
end
end
stats = Optimization.OptimizationStats(; iterations = maxiters,
time = 0.0, fevals = maxiters, gevals = maxiters)
return SciMLBase.build_solution(
cache, cache.opt, θ, x,
stats = stats, retcode = opt_ret)
end
end
2 changes: 1 addition & 1 deletion src/lbfgsb.jl
@@ -45,7 +45,7 @@ function __map_optimizer_args(cache::Optimization.OptimizationCache, opt::LBFGS;
@warn "common abstol is currently not used by $(opt)"
end
if !isnothing(maxtime)
@warn "common abstol is currently not used by $(opt)"
@warn "common maxtime is currently not used by $(opt)"
end

mapped_args = (;)
28 changes: 28 additions & 0 deletions test/lbfgsb.jl
@@ -25,3 +25,31 @@ prob = OptimizationProblem(optf, x0, lcons = [1.0, -Inf],
ub = [1.0, 1.0])
@time res = solve(prob, Optimization.LBFGS(), maxiters = 100)
@test res.retcode == SciMLBase.ReturnCode.Success

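# New test: fit a degree-4 polynomial to sin(x) by mini-batch least squares, with a
# constraint coupling the first and fifth coefficients; exercises both LBFGS and the
# new AugLag path with Adam as the inner solver.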
using MLUtils, OptimizationOptimisers

x0 = -pi:0.001:pi
y0 = sin.(x0)
data = MLUtils.DataLoader((x0, y0), batchsize = 100)
function loss(coeffs, data)
ypred = [evalpoly(data[1][i], coeffs) for i in eachindex(data[1])]
return sum(abs2, ypred .- data[2])
end

function cons1(res, coeffs, p = nothing)
res[1] = coeffs[1] * coeffs[5] - 1
return nothing
end

optf = OptimizationFunction(loss, AutoSparseForwardDiff(), cons = cons1)
callback = (st, l) -> (@show l; return false)

prob = OptimizationProblem(optf, rand(5), (x0, y0), lcons = [-0.5], ucons = [0.5], lb = [-10.0, -10.0, -10.0, -10.0, -10.0], ub = [10.0, 10.0, 10.0, 10.0, 10.0])
opt1 = solve(prob, Optimization.LBFGS(), maxiters = 1000, callback = callback)

prob = OptimizationProblem(optf, rand(5), data, lcons = [0.0], ucons = [0.0], lb = [-10.0, -10.0, -10.0, -10.0, -10.0], ub = [10.0, 10.0, 10.0, 10.0, 10.0])
opt = solve(prob, Optimization.AugLag(; inner = Adam()), maxiters = 500, callback = callback)

optf1 = OptimizationFunction(loss, AutoSparseForwardDiff())
prob1 = OptimizationProblem(optf1, rand(5), data)
sol1 = solve(prob1, OptimizationOptimisers.Adam(), maxiters = 1000, callback = callback)