Merge pull request #577 from SciML/revcons
Add constraints support to ReverseDiff and Zygote
Vaibhavdixit02 authored Aug 23, 2023
2 parents 422eacb + e62eedf commit 6fdfdcc
Showing 5 changed files with 243 additions and 81 deletions.
96 changes: 47 additions & 49 deletions docs/src/optimization_packages/optimisers.md
@@ -12,22 +12,21 @@ Pkg.add("OptimizationOptimisers");
In addition to the optimisation algorithms provided by the Optimisers.jl package, this subpackage
also provides the Sophia optimisation algorithm.


## Local Unconstrained Optimizers

- `Sophia`: Based on the recent paper https://arxiv.org/abs/2305.14342. It incorporates second-order information
in the form of the diagonal of the Hessian matrix, avoiding the need to compute the complete Hessian. It has been shown to converge faster than other first-order methods such as Adam and SGD (see the usage sketch after this list).
+ `solve(problem, Sophia(; η, βs, ϵ, λ, k, ρ))`

+ `η` is the learning rate
+ `βs` are the momentum decay coefficients
+ `ϵ` is the epsilon value
+ `λ` is the weight decay parameter
+ `k` is the interval, in iterations, at which the diagonal of the Hessian matrix is re-computed
+ `ρ` is the momentum
+ Defaults:

* `η = 0.001`
* `βs = (0.9, 0.999)`
* `ϵ = 1e-8`
@@ -36,139 +35,138 @@ also provides the Sophia optimisation algorithm.
* `ρ = 0.04`

- [`Optimisers.Descent`](https://fluxml.ai/Optimisers.jl/dev/api/#Optimisers.Descent): **Classic gradient descent optimizer with learning rate**

+ `solve(problem, Descent(η))`

+ `η` is the learning rate
+ Defaults:

* `η = 0.1`

- [`Optimisers.Momentum`](https://fluxml.ai/Optimisers.jl/dev/api/#Optimisers.Momentum): **Classic gradient descent optimizer with learning rate and momentum**

+ `solve(problem, Momentum(η, ρ))`

+ `η` is the learning rate
+ `ρ` is the momentum
+ Defaults:

* `η = 0.01`
* `ρ = 0.9`
- [`Optimisers.Nesterov`](https://fluxml.ai/Optimisers.jl/dev/api/#Optimisers.Nesterov): **Gradient descent optimizer with learning rate and Nesterov momentum**

+ `solve(problem, Nesterov(η, ρ))`

+ `η` is the learning rate
+ `ρ` is the Nesterov momentum
+ Defaults:

* `η = 0.01`
* `ρ = 0.9`
- [`Optimisers.RMSProp`](https://fluxml.ai/Optimisers.jl/dev/api/#Optimisers.RMSProp): **RMSProp optimizer**

+ `solve(problem, RMSProp(η, ρ))`

+ `η` is the learning rate
+ `ρ` is the momentum
+ Defaults:

* `η = 0.001`
* `ρ = 0.9`
- [`Optimisers.Adam`](https://fluxml.ai/Optimisers.jl/dev/api/#Optimisers.Adam): **Adam optimizer**

+ `solve(problem, Adam(η, β::Tuple))`

+ `η` is the learning rate
+ `β::Tuple` is the pair of momentum decay rates
+ Defaults:

* `η = 0.001`
* `β::Tuple = (0.9, 0.999)`
- [`Optimisers.RAdam`](https://fluxml.ai/Optimisers.jl/dev/api/#Optimisers.RAdam): **Rectified Adam optimizer**

+ `solve(problem, RAdam(η, β::Tuple))`

+ `η` is the learning rate
+ `β::Tuple` is the pair of momentum decay rates
+ Defaults:

* `η = 0.001`
* `β::Tuple = (0.9, 0.999)`
- [`Optimisers.OAdam`](https://fluxml.ai/Optimisers.jl/dev/api/#Optimisers.OAdam): **Optimistic Adam optimizer**

+ `solve(problem, OAdam(η, β::Tuple))`

+ `η` is the learning rate
+ `β::Tuple` is the pair of momentum decay rates
+ Defaults:

* `η = 0.001`
* `β::Tuple = (0.5, 0.999)`
- [`Optimisers.AdaMax`](https://fluxml.ai/Optimisers.jl/dev/api/#Optimisers.AdaMax): **AdaMax optimizer**

+ `solve(problem, AdaMax(η, β::Tuple))`

+ `η` is the learning rate
+ `β::Tuple` is the pair of momentum decay rates
+ Defaults:

* `η = 0.001`
* `β::Tuple = (0.9, 0.999)`
- [`Optimisers.ADAGrad`](https://fluxml.ai/Optimisers.jl/dev/api/#Optimisers.ADAGrad): **ADAGrad optimizer**

+ `solve(problem, ADAGrad(η))`

+ `η` is the learning rate
+ Defaults:

* `η = 0.1`
- [`Optimisers.ADADelta`](https://fluxml.ai/Optimisers.jl/dev/api/#Optimisers.ADADelta): **ADADelta optimizer**

+ `solve(problem, ADADelta(ρ))`

+ `ρ` is the gradient decay factor
+ Defaults:

* `ρ = 0.9`
- [`Optimisers.AMSGrad`](https://fluxml.ai/Optimisers.jl/dev/api/#Optimisers.AMSGrad): **AMSGrad optimizer**

+ `solve(problem, AMSGrad(η, β::Tuple))`

+ `η` is the learning rate
+ `β::Tuple` is the pair of momentum decay rates
+ Defaults:

* `η = 0.001`
* `β::Tuple = (0.9, 0.999)`
- [`Optimisers.NAdam`](https://fluxml.ai/Optimisers.jl/dev/api/#Optimisers.NAdam): **Nesterov variant of the Adam optimizer**

+ `solve(problem, NAdam(η, β::Tuple))`

+ `η` is the learning rate
+ `β::Tuple` is the pair of momentum decay rates
+ Defaults:

* `η = 0.001`
* `β::Tuple = (0.9, 0.999)`
- [`Optimisers.AdamW`](https://fluxml.ai/Optimisers.jl/dev/api/#Optimisers.AdamW): **AdamW optimizer**

+ `solve(problem, AdamW(η, β::Tuple, decay))`

+ `η` is the learning rate
+ `β::Tuple` is the pair of momentum decay rates
+ `decay` is the weight decay coefficient
+ Defaults:

* `η = 0.001`
* `β::Tuple = (0.9, 0.999)`
* `decay = 0`
- [`Optimisers.ADABelief`](https://fluxml.ai/Optimisers.jl/dev/api/#Optimisers.ADABelief): **ADABelief variant of Adam**

+ `solve(problem, ADABelief(η, β::Tuple))`

+ `η` is the learning rate
+ `β::Tuple` is the pair of momentum decay rates
+ Defaults:

* `η = 0.001`
* `β::Tuple = (0.9, 0.999)`
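A usage sketch for the optimisers documented above. This is a minimal example, not part of the commit: the Rosenbrock objective, the `AutoZygote` backend, the initial point, and `maxiters = 1000` are illustrative assumptions, and `Adam`/`Sophia` are assumed to be available unqualified through OptimizationOptimisers, as the call signatures above suggest.

```julia
using Optimization, OptimizationOptimisers, Zygote

# Illustrative objective: Rosenbrock with parameters p
rosenbrock(x, p) = (p[1] - x[1])^2 + p[2] * (x[2] - x[1]^2)^2

optf = OptimizationFunction(rosenbrock, Optimization.AutoZygote())
prob = OptimizationProblem(optf, zeros(2), [1.0, 100.0])

# Adam with the default learning rate; any optimiser from the list above can be swapped in
sol = solve(prob, Adam(0.001); maxiters = 1000)

# e.g. the Sophia algorithm provided by this subpackage, with its documented defaults
# sol = solve(prob, Sophia(); maxiters = 1000)
```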
9 changes: 4 additions & 5 deletions ext/OptimizationEnzymeExt.jl
@@ -116,11 +116,10 @@ function Optimization.instantiate_function(f::OptimizationFunction{true}, x,
el .= zeros(length(θ))
end
Enzyme.autodiff(Enzyme.Forward,
-    f2,
-    Enzyme.BatchDuplicated(θ, vdθ),
-    Enzyme.BatchDuplicated(bθ, vdbθ),
-    Const(fncs[i]),
-    )
f2,
Enzyme.BatchDuplicated(θ, vdθ),
Enzyme.BatchDuplicated(bθ, vdbθ),
Const(fncs[i]))

for j in eachindex(θ)
res[i][j, :] .= vdbθ[j]
99 changes: 83 additions & 16 deletions ext/OptimizationReversediffExt.jl
@@ -9,13 +9,11 @@ isdefined(Base, :get_extension) ? (using ReverseDiff, ReverseDiff.ForwardDiff) :
function Optimization.instantiate_function(f, x, adtype::AutoReverseDiff,
p = SciMLBase.NullParameters(),
num_cons = 0)
-    num_cons != 0 && error("AutoReverseDiff does not currently support constraints")

_f = (θ, args...) -> first(f.f(θ, p, args...))

if f.grad === nothing
-    grad = (res, θ, args...) -> ReverseDiff.gradient!(res, x -> _f(x, args...), θ,
-        ReverseDiff.GradientConfig(θ))
cfg = ReverseDiff.GradientConfig(x)
grad = (res, θ, args...) -> ReverseDiff.gradient!(res, x -> _f(x, args...), θ, cfg)
else
grad = (G, θ, args...) -> f.grad(G, θ, p, args...)
end
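The rewritten gradient path above builds a `ReverseDiff.GradientConfig` once and reuses it for every gradient call instead of constructing a new config per evaluation. A minimal standalone sketch of that pattern; the objective and input below are illustrative, not taken from this PR:

```julia
using ReverseDiff

f(x) = sum(abs2, x)                    # illustrative objective
x0 = rand(3)

cfg = ReverseDiff.GradientConfig(x0)   # built once, sized like the input
g = similar(x0)
ReverseDiff.gradient!(g, f, x0, cfg)   # cfg is reused for every call with same-shaped input
```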
@@ -41,22 +39,55 @@ function Optimization.instantiate_function(f, x, adtype::AutoReverseDiff,
hv = f.hv
end

-    return OptimizationFunction{false}(f, adtype; grad = grad, hess = hess, hv = hv,
-        cons = nothing, cons_j = nothing, cons_h = nothing,
if f.cons === nothing
cons = nothing
else
cons = (res, θ) -> f.cons(res, θ, p)
cons_oop = (x) -> (_res = zeros(eltype(x), num_cons); cons(_res, x); _res)
end

if cons !== nothing && f.cons_j === nothing
cjconfig = ReverseDiff.JacobianConfig(x)
cons_j = function (J, θ)
ReverseDiff.jacobian!(J, cons_oop, θ, cjconfig)
end
else
cons_j = (J, θ) -> f.cons_j(J, θ, p)
end

if cons !== nothing && f.cons_h === nothing
fncs = [(x) -> cons_oop(x)[i] for i in 1:num_cons]
cons_h = function (res, θ)
for i in 1:num_cons
res[i] .= ForwardDiff.jacobian(θ) do θ
ReverseDiff.gradient(fncs[i], θ)
end
end
end
else
cons_h = (res, θ) -> f.cons_h(res, θ, p)
end

if f.lag_h === nothing
lag_h = nothing # Consider implementing this
else
lag_h = (res, θ, σ, μ) -> f.lag_h(res, θ, σ, μ, p)
end
return OptimizationFunction{true}(f.f, adtype; grad = grad, hess = hess, hv = hv,
cons = cons, cons_j = cons_j, cons_h = cons_h,
hess_prototype = f.hess_prototype,
-        cons_jac_prototype = nothing,
-        cons_hess_prototype = nothing)
cons_jac_prototype = f.cons_jac_prototype,
cons_hess_prototype = f.cons_hess_prototype,
lag_h, f.lag_hess_prototype)
end
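The constraint support added above uses a two-level AD pattern: `ReverseDiff.jacobian!` with a reusable `JacobianConfig` for the constraint Jacobian, and a `ForwardDiff.jacobian` of a `ReverseDiff.gradient` for each constraint's Hessian. A sketch of the same pattern on an illustrative two-constraint function (the constraints and evaluation point are assumptions for the example):

```julia
using ReverseDiff, ForwardDiff

cons_oop(x) = [x[1]^2 + x[2]^2, x[1] * x[2]]   # illustrative constraints
x0 = [1.0, 2.0]
num_cons = 2

# Constraint Jacobian with a reusable config
cjconfig = ReverseDiff.JacobianConfig(x0)
J = zeros(num_cons, length(x0))
ReverseDiff.jacobian!(J, cons_oop, x0, cjconfig)

# Per-constraint Hessians: ForwardDiff Jacobian of the ReverseDiff gradient
fncs = [x -> cons_oop(x)[i] for i in 1:num_cons]
H = [zeros(length(x0), length(x0)) for _ in 1:num_cons]
for i in 1:num_cons
    H[i] .= ForwardDiff.jacobian(θ -> ReverseDiff.gradient(fncs[i], θ), x0)
end
```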

function Optimization.instantiate_function(f, cache::Optimization.ReInitCache,
adtype::AutoReverseDiff, num_cons = 0)
-    num_cons != 0 && error("AutoReverseDiff does not currently support constraints")

_f = (θ, args...) -> first(f.f(θ, cache.p, args...))

if f.grad === nothing
-    grad = (res, θ, args...) -> ReverseDiff.gradient!(res, x -> _f(x, args...), θ,
-        ReverseDiff.GradientConfig(θ))
cfg = ReverseDiff.GradientConfig(cache.u0)
grad = (res, θ, args...) -> ReverseDiff.gradient!(res, x -> _f(x, args...), θ, cfg)
else
grad = (G, θ, args...) -> f.grad(G, θ, cache.p, args...)
end
@@ -82,11 +113,47 @@ function Optimization.instantiate_function(f, cache::Optimization.ReInitCache,
hv = f.hv
end

-    return OptimizationFunction{false}(f, adtype; grad = grad, hess = hess, hv = hv,
-        cons = nothing, cons_j = nothing, cons_h = nothing,
if f.cons === nothing
cons = nothing
else
cons = (res, θ) -> f.cons(res, θ, cache.p)
cons_oop = (x) -> (_res = zeros(eltype(x), num_cons); cons(_res, x); _res)
end

if cons !== nothing && f.cons_j === nothing
cjconfig = ReverseDiff.JacobianConfig(cache.u0)
cons_j = function (J, θ)
ReverseDiff.jacobian!(J, cons_oop, θ, cjconfig)
end
else
cons_j = (J, θ) -> f.cons_j(J, θ, cache.p)
end

if cons !== nothing && f.cons_h === nothing
fncs = [(x) -> cons_oop(x)[i] for i in 1:num_cons]
cons_h = function (res, θ)
for i in 1:num_cons
res[i] .= ForwardDiff.jacobian(θ) do θ
ReverseDiff.gradient(fncs[i], θ)
end
end
end
else
cons_h = (res, θ) -> f.cons_h(res, θ, cache.p)
end

if f.lag_h === nothing
lag_h = nothing # Consider implementing this
else
lag_h = (res, θ, σ, μ) -> f.lag_h(res, θ, σ, μ, cache.p)
end

return OptimizationFunction{true}(f.f, adtype; grad = grad, hess = hess, hv = hv,
cons = cons, cons_j = cons_j, cons_h = cons_h,
hess_prototype = f.hess_prototype,
-        cons_jac_prototype = nothing,
-        cons_hess_prototype = nothing)
cons_jac_prototype = f.cons_jac_prototype,
cons_hess_prototype = f.cons_hess_prototype,
lag_h, f.lag_hess_prototype)
end

end
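With `cons`, `cons_j`, and `cons_h` now wired up, `AutoReverseDiff` can drive constrained problems end to end. A hedged sketch of what this enables; the objective, constraint, bounds, and the choice of `IPNewton` from OptimizationOptimJL are assumptions for illustration, not part of this diff:

```julia
using Optimization, OptimizationOptimJL, ReverseDiff

rosenbrock(x, p) = (p[1] - x[1])^2 + p[2] * (x[2] - x[1]^2)^2

# One inequality constraint, in the in-place form Optimization.jl expects: cons(res, x, p)
cons(res, x, p) = (res .= [x[1]^2 + x[2]^2]; nothing)

optf = OptimizationFunction(rosenbrock, Optimization.AutoReverseDiff(); cons = cons)
prob = OptimizationProblem(optf, [0.5, 0.5], [1.0, 100.0]; lcons = [-Inf], ucons = [1.0])

# IPNewton consumes the constraint Jacobian and Hessians this extension now provides
sol = solve(prob, IPNewton())
```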