Skip to content

Commit

Permalink
Use Zygote.jacobian etc. (#128)
Browse files Browse the repository at this point in the history
  • Loading branch information
devmotion authored Feb 7, 2024
1 parent 462225b commit 2bc18d0
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 7 deletions.
20 changes: 18 additions & 2 deletions ext/AbstractDifferentiationZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,27 @@ else
using ..Zygote: Zygote
end

AD.ZygoteBackend() = AD.ReverseRuleConfigBackend(Zygote.ZygoteRuleConfig())

# Context should not persist between different AD calls: fixes #69
function AD.ruleconfig(::AD.ReverseRuleConfigBackend{<:Zygote.ZygoteRuleConfig})
return Zygote.ZygoteRuleConfig()
end

function AD.value_and_pullback_function(::AD.ZygoteBackend, f, args...)
return Zygote.pullback(f, args...)
end

AD.gradient(::AD.ZygoteBackend, f, args...) = Zygote.gradient(f, args...)
function AD.value_and_gradient(::AD.ZygoteBackend, f, args...)
res = Zygote.withgradient(f, args...)
return res.val, res.grad
end

AD.jacobian(::AD.ZygoteBackend, f, args...) = Zygote.jacobian(f, args...)
function AD.value_and_jacobian(::AD.ZygoteBackend, f, args...)
res = Zygote.withjacobian(f, args...)
return res.val, res.grad
end

AD.hessian(::AD.ZygoteBackend, f, arg) = Zygote.hessian(f, arg)

end # module
7 changes: 4 additions & 3 deletions src/backends.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,13 +71,14 @@ end
ruleconfig(ba::ReverseRuleConfigBackend) = ba.ruleconfig

"""
ZygoteBackend()
ZygoteBackend
Create an AD backend that uses reverse mode with [Zygote.jl](https://github.com/FluxML/Zygote.jl).
It is a special case of [`ReverseRuleConfigBackend`](@ref).
Alternatively, you can perform AD with Zygote using a special [`ReverseRuleConfigBackend`](@ref), namely `ReverseRuleConfigBackend(Zygote.ZygoteRuleConfig())`.
Note, however, that the behaviour of this backend is not equivalent to `ZygoteBackend()` since the former uses a generic implementation of jacobian etc. for ChainRules-compatible AD backends whereas `ZygoteBackend` uses implementations in Zygote.jl.
!!! note
To be able to use this backend, you have to load Zygote.
"""
function ZygoteBackend end
struct ZygoteBackend <: AbstractReverseMode end
26 changes: 24 additions & 2 deletions test/ruleconfig.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@ using Test
using Zygote

@testset "ReverseRuleConfigBackend(ZygoteRuleConfig())" begin
backends = [@inferred(AD.ZygoteBackend())]
backends = [
@inferred(AD.ZygoteBackend()),
@inferred(AD.ReverseRuleConfigBackend(Zygote.ZygoteRuleConfig()))
]
@testset for backend in backends
@testset "Derivative" begin
test_derivatives(backend)
Expand Down Expand Up @@ -37,7 +40,7 @@ using Zygote

# issue #69
@testset "Zygote context" begin
ad = AD.ZygoteBackend()
ad = AD.ReverseRuleConfigBackend(Zygote.ZygoteRuleConfig())

# example in #69: context is not mutated
@test ad.ruleconfig.context.cache === nothing
Expand All @@ -56,6 +59,13 @@ using Zygote
end
@test AD.jacobian(ad, f, [1, 2, 3], 3) ==
([6.0 0.0 0.0; 0.0 6.0 0.0; 0.0 0.0 6.0], [2.0, 4.0, 6.0])

# With `AD.ZygoteBackend`:
ad = AD.ZygoteBackend()
@test AD.derivative(ad, exp, 1.0) === (exp(1.0),)
@test AD.derivative(ad, exp, 1.0) === (exp(1.0),)
@test AD.jacobian(ad, f, [1, 2, 3], 3) ==
([6.0 0.0 0.0; 0.0 6.0 0.0; 0.0 0.0 6.0], [2.0, 4.0, 6.0])
end

# issue #57
Expand All @@ -68,5 +78,17 @@ using Zygote

@test_logs Zygote.gradient(myfunc, 1) # nothing is logged
@test_logs AD.derivative(AD.ZygoteBackend(), myfunc, 1) # nothing is logged
@test_logs AD.derivative(
AD.ReverseRuleConfigBackend(Zygote.ZygoteRuleConfig()), myfunc, 1
) # nothing is logged
end

# issue #54
@testset "allocations of jacobian" begin
f(x) = x .^ 2
x = rand(100)
ad = AD.ZygoteBackend()
@test AD.jacobian(ad, f, x) == Zygote.jacobian(f, x)
@test @allocated(AD.jacobian(ad, f, x)) == @allocated(Zygote.jacobian(f, x))
end
end

0 comments on commit 2bc18d0

Please sign in to comment.