diff --git a/Project.toml b/Project.toml index 0d06a992..45394b15 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRulesTestUtils" uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a" -version = "0.5.3" +version = "0.5.4" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" @@ -14,4 +14,11 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" ChainRulesCore = "0.9.1" Compat = "3" FiniteDifferences = "0.11.2" +Quaternions = "0.4" julia = "1" + +[extras] +Quaternions = "94ee1d12-ae83-5a48-8b1c-48b8ff168ae0" + +[targets] +test = ["Quaternions"] diff --git a/src/testers.jl b/src/testers.jl index 5ea83f01..4f4544ee 100644 --- a/src/testers.jl +++ b/src/testers.jl @@ -100,6 +100,21 @@ function _make_jvp_call(fdm, f, xs, ẋs, ignores) return jvp(fdm, f2, sigargs...) end +""" + _basis_vectors(x::T) -> Vector{T} + +Get a set of basis (co)tangent vectors for `x`. + +This function assumes that the (co)tangent vectors are of the same type as `x` and requires +that `FiniteDifferences.to_vec` be implemented for inputs of the same type as `x`. +""" +function _basis_vectors(x) + v, from_vec = FiniteDifferences.to_vec(x) + basis_coords = Diagonal(ones(eltype(v), length(v))) + basis_vecs = [from_vec(@view basis_coords[:, i]) for i in axes(basis_coords, 2)] + return basis_vecs +end + """ test_scalar(f, z; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1), fkwargs=NamedTuple(), kwargs...) @@ -112,61 +127,48 @@ at input point `z` to confirm that there are correct `frule` and `rrule`s provid `fkwargs` are passed to `f` as keyword arguments. All keyword arguments except for `fdm` and `fkwargs` are passed to `isapprox`. + +To use this tester for a scalar type `MyNumber <: Number`, +`FiniteDifferences.to_vec(::MyNumber)` must be implemented. """ function test_scalar(f, z; rtol=1e-9, atol=1e-9, fdm=_fdm, fkwargs=NamedTuple(), kwargs...) _ensure_not_running_on_functor(f, "test_scalar") - # z = x + im * y - # Ω = u(x, y) + im * v(x, y) Ω = f(z; fkwargs...) + Δzs = _basis_vectors(z) + Δx = first(Δzs) + ΔΩs = _basis_vectors(Ω) + # test jacobian using forward mode - Δx = one(z) - @testset "$f at $z, with tangent $Δx" begin - # check ∂u_∂x and (if Ω is complex) ∂v_∂x via forward mode - frule_test(f, (z, Δx); rtol=rtol, atol=atol, fdm=fdm, fkwargs=fkwargs, kwargs...) - if z isa Complex - # check that same tangent is produced for tangent 1.0 and 1.0 + 0.0im + @testset "$f at $z, with tangent $Δz" for (i, Δz) in enumerate(Δzs) + frule_test(f, (z, Δz); rtol=rtol, atol=atol, fdm=fdm, fkwargs=fkwargs, kwargs...) + if !isa(Δz, Real) && i == 1 + # check that same tangent is produced for tangent real(one(z)) and one(z) @test isapprox( - frule((Zero(), real(Δx)), f, z; fkwargs...)[2], - frule((Zero(), Δx), f, z; fkwargs...)[2], + frule((Zero(), real(Δz)), f, z; fkwargs...)[2], + frule((Zero(), Δz), f, z; fkwargs...)[2], rtol=rtol, atol=atol, kwargs..., ) end end - if z isa Complex - Δy = one(z) * im - @testset "$f at $z, with tangent $Δy" begin - # check ∂u_∂y and (if Ω is complex) ∂v_∂y via forward mode - frule_test(f, (z, Δy); rtol=rtol, atol=atol, fdm=fdm, fkwargs=fkwargs, kwargs...) - end - end # test jacobian transpose using reverse mode - Δu = one(Ω) - @testset "$f at $z, with cotangent $Δu" begin - # check ∂u_∂x and (if z is complex) ∂u_∂y via reverse mode - rrule_test(f, Δu, (z, Δx); rtol=rtol, atol=atol, fdm=fdm, fkwargs=fkwargs, kwargs...) - if Ω isa Complex - # check that same cotangent is produced for cotangent 1.0 and 1.0 + 0.0im + @testset "$f at $z, with cotangent $ΔΩ" for (i, ΔΩ) in enumerate(ΔΩs) + rrule_test(f, ΔΩ, (z, Δx); rtol=rtol, atol=atol, fdm=fdm, fkwargs=fkwargs, kwargs...) + if !isa(ΔΩ, Real) && i == 1 + # check that same cotangent is produced for cotangent real(one(Ω)) and one(Ω) back = rrule(f, z)[2] @test isapprox( - extern(back(real(Δu))[2]), - extern(back(Δu)[2]), + extern(back(real(ΔΩ))[2]), + extern(back(ΔΩ)[2]), rtol=rtol, atol=atol, kwargs..., ) end end - if Ω isa Complex - Δv = one(Ω) * im - @testset "$f at $z, with cotangent $Δv" begin - # check ∂v_∂x and (if z is complex) ∂v_∂y via reverse mode - rrule_test(f, Δv, (z, Δx); rtol=rtol, atol=atol, fdm=fdm, fkwargs=fkwargs, kwargs...) - end - end end """ diff --git a/test/runtests.jl b/test/runtests.jl index 79ba9496..80dd8118 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,6 +1,8 @@ using ChainRulesCore using ChainRulesTestUtils +using FiniteDifferences using LinearAlgebra +using Quaternions using Random using Test diff --git a/test/testers.jl b/test/testers.jl index bcfcc96f..82753284 100644 --- a/test/testers.jl +++ b/test/testers.jl @@ -8,6 +8,8 @@ sinconj(x) = sin(x) primalapprox(x) = x +quatfun(q::Quaternion) = Quaternion(q.v3, 2 * q.v1, 3 * q.s, 4 * q.v2) + @testset "testers.jl" begin @testset "test_scalar" begin @testset "Ensure correct rules succeed" begin @@ -367,4 +369,29 @@ primalapprox(x) = x @test fails(()->rrule_test(my_identity2, 4.1, (2.2, 3.3))) end end + + @testset "test quaternion non-standard scalar" begin + function FiniteDifferences.to_vec(q::Quaternion) + function Quaternion_from_vec(q_vec) + return Quaternion(q_vec[1], q_vec[2], q_vec[3], q_vec[4]) + end + return [q.s, q.v1, q.v2, q.v3], Quaternion_from_vec + end + + function ChainRulesCore.frule((_, Δq), ::typeof(quatfun), q) + ∂q = Quaternion(Δq) + return quatfun(q), Quaternion(∂q.v3, 2 * ∂q.v1, 3 * ∂q.s, 4 * ∂q.v2) + end + + function ChainRulesCore.rrule(::typeof(quatfun), q) + function quatfun_pullback(ΔΩ) + ∂Ω = Quaternion(ΔΩ) + return (NO_FIELDS, Quaternion(3 * ∂Ω.v2, 2 * ∂Ω.v1, 4 * ∂Ω.v3, ∂Ω.s)) + end + return quatfun(q), quatfun_pullback + end + + q = quatrand() + test_scalar(quatfun, q) + end end