From 1681da6486c66cea62842924c0170b543c158df4 Mon Sep 17 00:00:00 2001 From: Vaibhav Dixit Date: Wed, 23 Aug 2023 15:16:17 +0530 Subject: [PATCH 1/2] Add constraints support to ReverseDiff and Zygote --- ext/OptimizationReversediffExt.jl | 96 ++++++++++++++++++++++++++----- ext/OptimizationZygoteExt.jl | 83 +++++++++++++++++++++++--- test/ADtests.jl | 37 +++++++++++- 3 files changed, 193 insertions(+), 23 deletions(-) diff --git a/ext/OptimizationReversediffExt.jl b/ext/OptimizationReversediffExt.jl index 666224420..ce6459563 100644 --- a/ext/OptimizationReversediffExt.jl +++ b/ext/OptimizationReversediffExt.jl @@ -9,13 +9,12 @@ isdefined(Base, :get_extension) ? (using ReverseDiff, ReverseDiff.ForwardDiff) : function Optimization.instantiate_function(f, x, adtype::AutoReverseDiff, p = SciMLBase.NullParameters(), num_cons = 0) - num_cons != 0 && error("AutoReverseDiff does not currently support constraints") _f = (θ, args...) -> first(f.f(θ, p, args...)) if f.grad === nothing - grad = (res, θ, args...) -> ReverseDiff.gradient!(res, x -> _f(x, args...), θ, - ReverseDiff.GradientConfig(θ)) + cfg = ReverseDiff.GradientConfig(x) + grad = (res, θ, args...) -> ReverseDiff.gradient!(res, x -> _f(x, args...), θ, cfg) else grad = (G, θ, args...) -> f.grad(G, θ, p, args...) end @@ -41,22 +40,57 @@ function Optimization.instantiate_function(f, x, adtype::AutoReverseDiff, hv = f.hv end - return OptimizationFunction{false}(f, adtype; grad = grad, hess = hess, hv = hv, - cons = nothing, cons_j = nothing, cons_h = nothing, + if f.cons === nothing + cons = nothing + else + cons = (res, θ) -> f.cons(res, θ, p) + cons_oop = (x) -> (_res = zeros(eltype(x), num_cons); cons(_res, x); _res) + end + + if cons !== nothing && f.cons_j === nothing + cjconfig = ReverseDiff.JacobianConfig(x) + cons_j = function (J, θ) + ReverseDiff.jacobian!(J, cons_oop, θ, cjconfig) + end + else + cons_j = (J, θ) -> f.cons_j(J, θ, p) + end + + if cons !== nothing && f.cons_h === nothing + fncs = [(x) -> cons_oop(x)[i] for i in 1:num_cons] + cons_h = function (res, θ) + for i in 1:num_cons + res[i] .= ForwardDiff.jacobian(θ) do θ + ReverseDiff.gradient(fncs[i], θ) + end + end + end + else + cons_h = (res, θ) -> f.cons_h(res, θ, p) + end + + if f.lag_h === nothing + lag_h = nothing # Consider implementing this + else + lag_h = (res, θ, σ, μ) -> f.lag_h(res, θ, σ, μ, p) + end + return OptimizationFunction{true}(f.f, adtype; grad = grad, hess = hess, hv = hv, + cons = cons, cons_j = cons_j, cons_h = cons_h, hess_prototype = f.hess_prototype, - cons_jac_prototype = nothing, - cons_hess_prototype = nothing) + cons_jac_prototype = f.cons_jac_prototype, + cons_hess_prototype = f.cons_hess_prototype, + lag_h, f.lag_hess_prototype) end function Optimization.instantiate_function(f, cache::Optimization.ReInitCache, adtype::AutoReverseDiff, num_cons = 0) - num_cons != 0 && error("AutoReverseDiff does not currently support constraints") _f = (θ, args...) -> first(f.f(θ, cache.p, args...)) if f.grad === nothing + cfg = ReverseDiff.GradientConfig(cache.u0) grad = (res, θ, args...) -> ReverseDiff.gradient!(res, x -> _f(x, args...), θ, - ReverseDiff.GradientConfig(θ)) + ) else grad = (G, θ, args...) -> f.grad(G, θ, cache.p, args...) 
end @@ -82,11 +116,47 @@ function Optimization.instantiate_function(f, cache::Optimization.ReInitCache, hv = f.hv end - return OptimizationFunction{false}(f, adtype; grad = grad, hess = hess, hv = hv, - cons = nothing, cons_j = nothing, cons_h = nothing, + if f.cons === nothing + cons = nothing + else + cons = (res, θ) -> f.cons(res, θ, cache.p) + cons_oop = (x) -> (_res = zeros(eltype(x), num_cons); cons(_res, x); _res) + end + + if cons !== nothing && f.cons_j === nothing + cjconfig = ReverseDiff.JacobianConfig(cache.u0) + cons_j = function (J, θ) + ReverseDiff.jacobian!(J, cons_oop, θ, cjconfig) + end + else + cons_j = (J, θ) -> f.cons_j(J, θ, cache.p) + end + + if cons !== nothing && f.cons_h === nothing + fncs = [(x) -> cons_oop(x)[i] for i in 1:num_cons] + cons_h = function (res, θ) + for i in 1:num_cons + res[i] .= ForwardDiff.jacobian(θ) do θ + ReverseDiff.gradient(fncs[i], θ) + end + end + end + else + cons_h = (res, θ) -> f.cons_h(res, θ, cache.p) + end + + if f.lag_h === nothing + lag_h = nothing # Consider implementing this + else + lag_h = (res, θ, σ, μ) -> f.lag_h(res, θ, σ, μ, cache.p) + end + + return OptimizationFunction{true}(f.f, adtype; grad = grad, hess = hess, hv = hv, + cons = cons, cons_j = cons_j, cons_h = cons_h, hess_prototype = f.hess_prototype, - cons_jac_prototype = nothing, - cons_hess_prototype = nothing) + cons_jac_prototype = f.cons_jac_prototype, + cons_hess_prototype = f.cons_hess_prototype, + lag_h, f.lag_hess_prototype) end end diff --git a/ext/OptimizationZygoteExt.jl b/ext/OptimizationZygoteExt.jl index 82a1ed460..769127142 100644 --- a/ext/OptimizationZygoteExt.jl +++ b/ext/OptimizationZygoteExt.jl @@ -8,7 +8,6 @@ isdefined(Base, :get_extension) ? (using Zygote, Zygote.ForwardDiff) : function Optimization.instantiate_function(f, x, adtype::AutoZygote, p, num_cons = 0) - num_cons != 0 && error("AutoZygote does not currently support constraints") _f = (θ, args...) 
-> f(θ, p, args...)[1] if f.grad === nothing @@ -40,11 +39,44 @@ function Optimization.instantiate_function(f, x, adtype::AutoZygote, p, hv = f.hv end - return OptimizationFunction{false}(f, adtype; grad = grad, hess = hess, hv = hv, - cons = nothing, cons_j = nothing, cons_h = nothing, + if f.cons === nothing + cons = nothing + else + cons = (res, θ) -> f.cons(res, θ, p) + cons_oop = (x) -> (_res = Zygote.Buffer(x, num_cons); cons(_res, x); copy(_res)) + end + + if cons !== nothing && f.cons_j === nothing + cons_j = function (J, θ) + J .= first(Zygote.jacobian(cons_oop, θ)) + end + else + cons_j = (J, θ) -> f.cons_j(J, θ, p) + end + + if cons !== nothing && f.cons_h === nothing + fncs = [(x) -> cons_oop(x)[i] for i in 1:num_cons] + cons_h = function (res, θ) + for i in 1:num_cons + res[i] .= Zygote.hessian(fncs[i], θ) + end + end + else + cons_h = (res, θ) -> f.cons_h(res, θ, p) + end + + if f.lag_h === nothing + lag_h = nothing # Consider implementing this + else + lag_h = (res, θ, σ, μ) -> f.lag_h(res, θ, σ, μ, p) + end + + return OptimizationFunction{true}(f.f, adtype; grad = grad, hess = hess, hv = hv, + cons = cons, cons_j = cons_j, cons_h = cons_h, hess_prototype = f.hess_prototype, - cons_jac_prototype = nothing, - cons_hess_prototype = nothing) + cons_jac_prototype = f.cons_jac_prototype, + cons_hess_prototype = f.cons_hess_prototype, + lag_h, f.lag_hess_prototype) end function Optimization.instantiate_function(f, cache::Optimization.ReInitCache, @@ -81,11 +113,44 @@ function Optimization.instantiate_function(f, cache::Optimization.ReInitCache, hv = f.hv end - return OptimizationFunction{false}(f, adtype; grad = grad, hess = hess, hv = hv, - cons = nothing, cons_j = nothing, cons_h = nothing, + if f.cons === nothing + cons = nothing + else + cons = (res, θ) -> f.cons(res, θ, p) + cons_oop = (x) -> (_res = zeros(eltype(x), num_cons); cons(_res, x); _res) + end + + if cons !== nothing && f.cons_j === nothing + cons_j = function (J, θ) + J .= Zygote.jacobian(cons_oop, θ) + end + else + cons_j = (J, θ) -> f.cons_j(J, θ, cache.p) + end + + if cons !== nothing && f.cons_h === nothing + fncs = [(x) -> cons_oop(x)[i] for i in 1:num_cons] + cons_h = function (res, θ) + for i in 1:num_cons + res[i] .= Zygote.hessian(fncs[i], θ) + end + end + else + cons_h = (res, θ) -> f.cons_h(res, θ, cache.p) + end + + if f.lag_h === nothing + lag_h = nothing # Consider implementing this + else + lag_h = (res, θ, σ, μ) -> f.lag_h(res, θ, σ, μ, cache.p) + end + + return OptimizationFunction{true}(f.f, adtype; grad = grad, hess = hess, hv = hv, + cons = cons, cons_j = cons_j, cons_h = cons_h, hess_prototype = f.hess_prototype, - cons_jac_prototype = nothing, - cons_hess_prototype = nothing) + cons_jac_prototype = f.cons_jac_prototype, + cons_hess_prototype = f.cons_hess_prototype, + lag_h, f.lag_hess_prototype) end end diff --git a/test/ADtests.jl b/test/ADtests.jl index 65363c653..4bd7fdb5c 100644 --- a/test/ADtests.jl +++ b/test/ADtests.jl @@ -99,7 +99,42 @@ optprob.cons_j(J, [5.0, 3.0]) @test all(isapprox(J, [10.0 6.0; -0.149013 -0.958924]; rtol = 1e-3)) H3 = [Array{Float64}(undef, 2, 2), Array{Float64}(undef, 2, 2)] optprob.cons_h(H3, x0) -@test H3 == [[2.0 0.0; 0.0 2.0], [-0.0 1.0; 1.0 0.0]] +H3 == [[2.0 0.0; 0.0 2.0], [-0.0 1.0; 1.0 0.0]] + +optf = OptimizationFunction(rosenbrock, Optimization.AutoReverseDiff(), cons = con2_c) +optprob = Optimization.instantiate_function(optf, x0, Optimization.AutoReverseDiff(), + nothing, 2) +optprob.grad(G2, x0) +@test G1 == G2 +optprob.hess(H2, x0) +@test H1 == H2 
+res = Array{Float64}(undef, 2) +optprob.cons(res, x0) +@test res == [0.0, 0.0] +J = Array{Float64}(undef, 2, 2) +optprob.cons_j(J, [5.0, 3.0]) +@test all(isapprox(J, [10.0 6.0; -0.149013 -0.958924]; rtol = 1e-3)) +H3 = [Array{Float64}(undef, 2, 2), Array{Float64}(undef, 2, 2)] +optprob.cons_h(H3, x0) +H3 == [[2.0 0.0; 0.0 2.0], [-0.0 1.0; 1.0 0.0]] + +optf = OptimizationFunction(rosenbrock, Optimization.AutoZygote(), cons = con2_c) +optprob = Optimization.instantiate_function(optf, x0, Optimization.AutoZygote(), + nothing, 2) +optprob.grad(G2, x0) +@test G1 == G2 +optprob.hess(H2, x0) +@test H1 == H2 +res = Array{Float64}(undef, 2) +optprob.cons(res, x0) +@test res == [0.0, 0.0] +J = Array{Float64}(undef, 2, 2) +optprob.cons_j(J, [5.0, 3.0]) +@test all(isapprox(J, [10.0 6.0; -0.149013 -0.958924]; rtol = 1e-3)) +H3 = [Array{Float64}(undef, 2, 2), Array{Float64}(undef, 2, 2)] +optprob.cons_h(H3, x0) +H3 == [[2.0 0.0; 0.0 2.0], [-0.0 1.0; 1.0 0.0]] + optf = OptimizationFunction(rosenbrock, Optimization.AutoModelingToolkit(true, true), cons = con2_c) From e62eedf5dc7a1e43ea65ee0e2e6d2d5d07ab54ac Mon Sep 17 00:00:00 2001 From: Vaibhav Dixit Date: Wed, 23 Aug 2023 16:13:50 +0530 Subject: [PATCH 2/2] format --- docs/src/optimization_packages/optimisers.md | 96 ++++++++++---------- ext/OptimizationEnzymeExt.jl | 9 +- ext/OptimizationReversediffExt.jl | 5 +- ext/OptimizationZygoteExt.jl | 1 - test/ADtests.jl | 1 - 5 files changed, 52 insertions(+), 60 deletions(-) diff --git a/docs/src/optimization_packages/optimisers.md b/docs/src/optimization_packages/optimisers.md index ecb88d2b3..0fe35b9af 100644 --- a/docs/src/optimization_packages/optimisers.md +++ b/docs/src/optimization_packages/optimisers.md @@ -12,14 +12,13 @@ Pkg.add("OptimizationOptimisers"); In addition to the optimisation algorithms provided by the Optimisers.jl package this subpackage also provides the Sophia optimisation algorithm. - ## Local Unconstrained Optimizers - - Sophia: Based on the recent paper https://arxiv.org/abs/2305.14342. It incorporates second order information - in the form of the diagonal of the Hessian matrix hence avoiding the need to compute the complete hessian. It has been shown to converge faster than other first order methods such as Adam and SGD. - + - `Sophia`: Based on the recent paper https://arxiv.org/abs/2305.14342. It incorporates second order information + in the form of the diagonal of the Hessian matrix hence avoiding the need to compute the complete hessian. It has been shown to converge faster than other first order methods such as Adam and SGD. + + `solve(problem, Sophia(; η, βs, ϵ, λ, k, ρ))` - + + `η` is the learning rate + `βs` are the decay of momentums + `ϵ` is the epsilon value @@ -27,7 +26,7 @@ also provides the Sophia optimisation algorithm. + `k` is the number of iterations to re-compute the diagonal of the Hessian matrix + `ρ` is the momentum + Defaults: - + * `η = 0.001` * `βs = (0.9, 0.999)` * `ϵ = 1e-8` @@ -36,139 +35,138 @@ also provides the Sophia optimisation algorithm. 
* `ρ = 0.04` - [`Optimisers.Descent`](https://fluxml.ai/Optimisers.jl/dev/api/#Optimisers.Descent): **Classic gradient descent optimizer with learning rate** - + + `solve(problem, Descent(η))` - + + `η` is the learning rate + Defaults: - + * `η = 0.1` - - [`Optimisers.Momentum`](https://fluxml.ai/Optimisers.jl/dev/api/#Optimisers.Momentum): **Classic gradient descent optimizer with learning rate and momentum** - + + `solve(problem, Momentum(η, ρ))` - + + `η` is the learning rate + `ρ` is the momentum + Defaults: - + * `η = 0.01` * `ρ = 0.9` - [`Optimisers.Nesterov`](https://fluxml.ai/Optimisers.jl/dev/api/#Optimisers.Nesterov): **Gradient descent optimizer with learning rate and Nesterov momentum** - + + `solve(problem, Nesterov(η, ρ))` - + + `η` is the learning rate + `ρ` is the Nesterov momentum + Defaults: - + * `η = 0.01` * `ρ = 0.9` - [`Optimisers.RMSProp`](https://fluxml.ai/Optimisers.jl/dev/api/#Optimisers.RMSProp): **RMSProp optimizer** - + + `solve(problem, RMSProp(η, ρ))` - + + `η` is the learning rate + `ρ` is the momentum + Defaults: - + * `η = 0.001` * `ρ = 0.9` - [`Optimisers.Adam`](https://fluxml.ai/Optimisers.jl/dev/api/#Optimisers.Adam): **Adam optimizer** - + + `solve(problem, Adam(η, β::Tuple))` - + + `η` is the learning rate + `β::Tuple` is the decay of momentums + Defaults: - + * `η = 0.001` * `β::Tuple = (0.9, 0.999)` - [`Optimisers.RAdam`](https://fluxml.ai/Optimisers.jl/dev/api/#Optimisers.RAdam): **Rectified Adam optimizer** - + + `solve(problem, RAdam(η, β::Tuple))` - + + `η` is the learning rate + `β::Tuple` is the decay of momentums + Defaults: - + * `η = 0.001` * `β::Tuple = (0.9, 0.999)` - [`Optimisers.RAdam`](https://fluxml.ai/Optimisers.jl/dev/api/#Optimisers.OAdam): **Optimistic Adam optimizer** - + + `solve(problem, OAdam(η, β::Tuple))` - + + `η` is the learning rate + `β::Tuple` is the decay of momentums + Defaults: - + * `η = 0.001` * `β::Tuple = (0.5, 0.999)` - [`Optimisers.AdaMax`](https://fluxml.ai/Optimisers.jl/dev/api/#Optimisers.AdaMax): **AdaMax optimizer** - + + `solve(problem, AdaMax(η, β::Tuple))` - + + `η` is the learning rate + `β::Tuple` is the decay of momentums + Defaults: - + * `η = 0.001` * `β::Tuple = (0.9, 0.999)` - [`Optimisers.ADAGrad`](https://fluxml.ai/Optimisers.jl/dev/api/#Optimisers.ADAGrad): **ADAGrad optimizer** - + + `solve(problem, ADAGrad(η))` - + + `η` is the learning rate + Defaults: - + * `η = 0.1` - [`Optimisers.ADADelta`](https://fluxml.ai/Optimisers.jl/dev/api/#Optimisers.ADADelta): **ADADelta optimizer** - + + `solve(problem, ADADelta(ρ))` - + + `ρ` is the gradient decay factor + Defaults: - + * `ρ = 0.9` - [`Optimisers.AMSGrad`](https://fluxml.ai/Optimisers.jl/dev/api/#Optimisers.ADAGrad): **AMSGrad optimizer** - + + `solve(problem, AMSGrad(η, β::Tuple))` - + + `η` is the learning rate + `β::Tuple` is the decay of momentums + Defaults: - + * `η = 0.001` * `β::Tuple = (0.9, 0.999)` - [`Optimisers.NAdam`](https://fluxml.ai/Optimisers.jl/dev/api/#Optimisers.NAdam): **Nesterov variant of the Adam optimizer** - + + `solve(problem, NAdam(η, β::Tuple))` - + + `η` is the learning rate + `β::Tuple` is the decay of momentums + Defaults: - + * `η = 0.001` * `β::Tuple = (0.9, 0.999)` - [`Optimisers.AdamW`](https://fluxml.ai/Optimisers.jl/dev/api/#Optimisers.AdamW): **AdamW optimizer** - + + `solve(problem, AdamW(η, β::Tuple))` - + + `η` is the learning rate + `β::Tuple` is the decay of momentums + `decay` is the decay to weights + Defaults: - + * `η = 0.001` * `β::Tuple = (0.9, 0.999)` * `decay = 0` - 
[`Optimisers.ADABelief`](https://fluxml.ai/Optimisers.jl/dev/api/#Optimisers.ADABelief): **ADABelief variant of Adam** - + + `solve(problem, ADABelief(η, β::Tuple))` - + + `η` is the learning rate + `β::Tuple` is the decay of momentums + Defaults: - + * `η = 0.001` * `β::Tuple = (0.9, 0.999)` diff --git a/ext/OptimizationEnzymeExt.jl b/ext/OptimizationEnzymeExt.jl index a2fb40bfe..faa0c1e01 100644 --- a/ext/OptimizationEnzymeExt.jl +++ b/ext/OptimizationEnzymeExt.jl @@ -116,11 +116,10 @@ function Optimization.instantiate_function(f::OptimizationFunction{true}, x, el .= zeros(length(θ)) end Enzyme.autodiff(Enzyme.Forward, - f2, - Enzyme.BatchDuplicated(θ, vdθ), - Enzyme.BatchDuplicated(bθ, vdbθ), - Const(fncs[i]), - ) + f2, + Enzyme.BatchDuplicated(θ, vdθ), + Enzyme.BatchDuplicated(bθ, vdbθ), + Const(fncs[i])) for j in eachindex(θ) res[i][j, :] .= vdbθ[j] diff --git a/ext/OptimizationReversediffExt.jl b/ext/OptimizationReversediffExt.jl index ce6459563..3e7a6b9a4 100644 --- a/ext/OptimizationReversediffExt.jl +++ b/ext/OptimizationReversediffExt.jl @@ -9,7 +9,6 @@ isdefined(Base, :get_extension) ? (using ReverseDiff, ReverseDiff.ForwardDiff) : function Optimization.instantiate_function(f, x, adtype::AutoReverseDiff, p = SciMLBase.NullParameters(), num_cons = 0) - _f = (θ, args...) -> first(f.f(θ, p, args...)) if f.grad === nothing @@ -84,13 +83,11 @@ end function Optimization.instantiate_function(f, cache::Optimization.ReInitCache, adtype::AutoReverseDiff, num_cons = 0) - _f = (θ, args...) -> first(f.f(θ, cache.p, args...)) if f.grad === nothing cfg = ReverseDiff.GradientConfig(cache.u0) - grad = (res, θ, args...) -> ReverseDiff.gradient!(res, x -> _f(x, args...), θ, - ) + grad = (res, θ, args...) -> ReverseDiff.gradient!(res, x -> _f(x, args...), θ) else grad = (G, θ, args...) -> f.grad(G, θ, cache.p, args...) end diff --git a/ext/OptimizationZygoteExt.jl b/ext/OptimizationZygoteExt.jl index 769127142..0967b0495 100644 --- a/ext/OptimizationZygoteExt.jl +++ b/ext/OptimizationZygoteExt.jl @@ -8,7 +8,6 @@ isdefined(Base, :get_extension) ? (using Zygote, Zygote.ForwardDiff) : function Optimization.instantiate_function(f, x, adtype::AutoZygote, p, num_cons = 0) - _f = (θ, args...) -> f(θ, p, args...)[1] if f.grad === nothing grad = (res, θ, args...) -> false ? diff --git a/test/ADtests.jl b/test/ADtests.jl index 4bd7fdb5c..1e8f3a331 100644 --- a/test/ADtests.jl +++ b/test/ADtests.jl @@ -135,7 +135,6 @@ H3 = [Array{Float64}(undef, 2, 2), Array{Float64}(undef, 2, 2)] optprob.cons_h(H3, x0) H3 == [[2.0 0.0; 0.0 2.0], [-0.0 1.0; 1.0 0.0]] - optf = OptimizationFunction(rosenbrock, Optimization.AutoModelingToolkit(true, true), cons = con2_c) optprob = Optimization.instantiate_function(optf, x0,
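
A minimal usage sketch of the constraint support this patch series adds, mirroring the new tests in test/ADtests.jl. The `rosenbrock` objective and the two-constraint function `con2_c` are defined earlier in that test file and are not shown in this patch; the definitions below are assumptions, reconstructed to be consistent with the Jacobian and Hessian values the new tests check.

```julia
using Optimization, ReverseDiff  # loading ReverseDiff activates the OptimizationReversediffExt extension

# Assumed definitions (not part of this patch), consistent with the values asserted in the tests:
rosenbrock(x, p = nothing) = (1 - x[1])^2 + 100 * (x[2] - x[1]^2)^2
con2_c(res, x, p) = (res .= [x[1]^2 + x[2]^2, x[2] * sin(x[1]) - x[1]])

x0 = zeros(2)
optf = OptimizationFunction(rosenbrock, Optimization.AutoReverseDiff(), cons = con2_c)
optprob = Optimization.instantiate_function(optf, x0, Optimization.AutoReverseDiff(),
    nothing, 2)

res = zeros(2)
optprob.cons(res, x0)            # in-place constraint evaluation

J = zeros(2, 2)
optprob.cons_j(J, [5.0, 3.0])    # constraint Jacobian via ReverseDiff

H = [zeros(2, 2), zeros(2, 2)]
optprob.cons_h(H, x0)            # per-constraint Hessians (ForwardDiff over ReverseDiff gradients)
```

Per the added tests, the same pattern applies with `Optimization.AutoZygote()` in place of `Optimization.AutoReverseDiff()`, where the constraint Jacobians come from `Zygote.jacobian` and the constraint Hessians from `Zygote.hessian`.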