Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fearless taylor version #109

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"

[[deps.ChainRules]]
deps = ["Adapt", "ChainRulesCore", "Compat", "Distributed", "GPUArraysCore", "IrrationalConstants", "LinearAlgebra", "Random", "RealDot", "SparseArrays", "Statistics", "StructArrays"]
git-tree-sha1 = "c46adabdd0348f0ee8de91142cfc4a72a613ac0a"
git-tree-sha1 = "fdde4d8a31cf82b1d136cf6cb53924e8744a832b"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "1.46.1"
version = "1.47.0"

[[deps.ChainRulesCore]]
deps = ["Compat", "LinearAlgebra", "SparseArrays"]
Expand Down Expand Up @@ -265,9 +265,9 @@ version = "1.10.0"

[[deps.StaticArrays]]
deps = ["LinearAlgebra", "Random", "StaticArraysCore", "Statistics"]
git-tree-sha1 = "129703d62117c374c4f2db6d13a027741c46eafd"
git-tree-sha1 = "cee507162ecbb677450f20058ca83bd559b6b752"
uuid = "90137ffa-7385-5640-81b9-e52037218182"
version = "1.5.13"
version = "1.5.14"

[[deps.StaticArraysCore]]
git-tree-sha1 = "6b7ba252635a5eff6a0b0664a41ee140a1c9e72a"
Expand Down
10 changes: 5 additions & 5 deletions src/extra_rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ function (::∂⃖{N})(f::typeof(*), args...) where {N}
end
return z
else
∂⃖p = ∂⃖{minus1(N)}()
∂⃖p = ∂⃖{N-1}()
@destruct z, z̄ = ∂⃖p(rrule_times, f, args...)
if z === nothing
return ∂⃖recurse{N}()(f, args...)
Expand Down Expand Up @@ -130,15 +130,15 @@ end
# Field-less callable marker types carrying three type parameters `N`, `O`, `P`.
struct NonDiffEven{N, O, P}; end
struct NonDiffOdd{N, O, P}; end

# Calling either marker ignores its argument(s) and returns zero tangents paired with
# a marker of the other kind whose second parameter is incremented by one:
# `NonDiffOdd` yields an `N`-tuple of `ZeroTangent()`, `NonDiffEven` a single one.
# NOTE(review): each generic method appears twice (`plus1(O)` and `O+1`) because this
# text is a diff view whose +/- markers were lost; the `plus1` lines look like the
# removed side — confirm against the PR.
(::NonDiffOdd{N, O, P})(Δ) where {N, O, P} = (ntuple(_->ZeroTangent(), N), NonDiffEven{N, plus1(O), P}())
(::NonDiffEven{N, O, P})(Δ...) where {N, O, P} = (ZeroTangent(), NonDiffOdd{N, plus1(O), P}())
(::NonDiffOdd{N, O, P})(Δ) where {N, O, P} = (ntuple(_->ZeroTangent(), N), NonDiffEven{N, O+1, P}())
(::NonDiffEven{N, O, P})(Δ...) where {N, O, P} = (ZeroTangent(), NonDiffOdd{N, O+1, P}())
# When the second parameter equals the third, `NonDiffOdd` returns only the `N`-tuple
# of zero tangents, with no follow-up marker.
(::NonDiffOdd{N, O, O})(Δ) where {N, O} = ntuple(_->ZeroTangent(), N)

# This should not happen: a `NonDiffEven` whose second and third parameters are equal
# throws unconditionally.
(::NonDiffEven{N, O, O})(Δ...) where {N, O} = error()

# Reverse rule for `Core.apply_type` under `DiffractorRuleConfig`: returns the applied
# type together with a `NonDiffOdd` marker whose first parameter is `length(args) + 2`
# and whose other two parameters are both `1`.
# NOTE(review): the header and the body each appear twice because this text is a diff
# view whose +/- markers were lost; the `@Base.pure` / `plus1` pair looks like the
# removed side and the `@Base.assume_effects :total` pair like the added side — confirm.
@Base.pure function ChainRulesCore.rrule(::DiffractorRuleConfig, ::typeof(Core.apply_type), head, args...)
Core.apply_type(head, args...), NonDiffOdd{plus1(plus1(length(args))), 1, 1}()
@Base.assume_effects :total function ChainRulesCore.rrule(::DiffractorRuleConfig, ::typeof(Core.apply_type), head, args...)
Core.apply_type(head, args...), NonDiffOdd{length(args)+2, 1, 1}()
end

function ChainRulesCore.rrule(::DiffractorRuleConfig, ::typeof(Core.tuple), args...)
Expand Down
12 changes: 5 additions & 7 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ dx(x) = error("Cotangent space not defined for `$(typeof(x))`. Try a real-valued
For `x` in a one dimensional manifold, map x to the trivial, unital, 1st order
tangent bundle. It should hold that `∀x ⟨∂x(x), dx(x)⟩ = 1`
"""
∂x(x::Real) = ExplicitTangentBundle{1}(x, (one(x),))
∂x(x::Real) = TaylorBundle{1}(x, (one(x),))
# Fallback for any input without a more specific `∂x` method: always throws, naming
# the offending type. The type name is wrapped in a balanced pair of backticks, as in
# the corresponding `dx` error message.
∂x(x) = error("Tangent space not defined for `$(typeof(x))`.")

# Field-less type parameterised by `N`; its methods are not visible in this hunk.
struct ∂xⁿ{N}; end
Expand Down Expand Up @@ -143,11 +143,9 @@ Base.show(io::IO, f::PrimeDerivativeBack{N}) where {N} = print(io, f.f, "'"^N)

# This improves performance for nested derivatives by short-cutting some
# recursion into the PrimeDerivative constructor.
# `minus1` / `plus1` are decrement / increment helpers marked `@Base.pure`.
@Base.pure minus1(N) = N - 1
@Base.pure plus1(N) = N + 1
# `lower_pd` rebuilds the wrapper around the same stored `f` with the first type
# parameter reduced by one; for `PrimeDerivativeBack{1}` it returns the stored `f`.
# NOTE(review): the generic `lower_pd` / `raise_pd` lines appear twice (helper-call and
# plain-arithmetic spellings) because this text is a diff view whose +/- markers were
# lost; the `minus1` / `plus1` spellings look like the removed side — confirm.
lower_pd(f::PrimeDerivativeBack{N,T}) where {N,T} = PrimeDerivativeBack{minus1(N),T}(getfield(f, :f))
lower_pd(f::PrimeDerivativeBack{N,T}) where {N,T} = PrimeDerivativeBack{N-1,T}(getfield(f, :f))
lower_pd(f::PrimeDerivativeBack{1}) = getfield(f, :f)
# `raise_pd` rebuilds the wrapper around the same stored `f` with the first type
# parameter increased by one.
raise_pd(f::PrimeDerivativeBack{N,T}) where {N,T} = PrimeDerivativeBack{plus1(N),T}(getfield(f, :f))
raise_pd(f::PrimeDerivativeBack{N,T}) where {N,T} = PrimeDerivativeBack{N+1,T}(getfield(f, :f))

# Reverse rules: the primal result is the lowered / raised wrapper, and the pullback
# returns `ZeroTangent()` in the first slot and passes `Δ` through in the second.
ChainRulesCore.rrule(::typeof(lower_pd), f) = lower_pd(f), Δ->(ZeroTangent(), Δ)
ChainRulesCore.rrule(::typeof(raise_pd), f) = raise_pd(f), Δ->(ZeroTangent(), Δ)
Expand All @@ -170,8 +168,8 @@ end
# Wrapping a plain `f` produces a `PrimeDerivativeFwd` with first type parameter 1;
# wrapping an existing `PrimeDerivativeFwd` delegates to `raise_pd` instead.
PrimeDerivativeFwd(f) = PrimeDerivativeFwd{1, typeof(f)}(f)
PrimeDerivativeFwd(f::PrimeDerivativeFwd{N, T}) where {N, T} = raise_pd(f)

# `lower_pd` on a `PrimeDerivativeFwd` calls `error()` first, so it always throws; the
# constructor expression after the `;` is never reached.
# `raise_pd` rebuilds the wrapper around the same stored `f` with `N` increased by one.
# NOTE(review): each method appears twice (`minus1` / `plus1` and `N-1` / `N+1`
# spellings) because this text is a diff view whose +/- markers were lost — confirm
# which side is current against the PR.
lower_pd(f::PrimeDerivativeFwd{N,T}) where {N,T} = (error(); PrimeDerivativeFwd{minus1(N),T}(getfield(f, :f)))
lower_pd(f::PrimeDerivativeFwd{N,T}) where {N,T} = (error(); PrimeDerivativeFwd{N-1,T}(getfield(f, :f)))
raise_pd(f::PrimeDerivativeFwd{N,T}) where {N,T} = PrimeDerivativeFwd{plus1(N),T}(getfield(f, :f))
raise_pd(f::PrimeDerivativeFwd{N,T}) where {N,T} = PrimeDerivativeFwd{N+1,T}(getfield(f, :f))

# With first type parameter 0 the wrapper just calls the stored `f` on `x`.
(f::PrimeDerivativeFwd{0})(x) = getfield(f, :f)(x)

Expand Down
45 changes: 17 additions & 28 deletions src/jet.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
struct Jet{T, N}
struct Jet{S, T, N}

Represents the truncated (N-1)-th order Taylor series

Expand All @@ -15,8 +15,8 @@ For a jet `j`, several operations are supported:
derivatives. Mathematically this corresponds to an infinitesimal ball
around `a`.
"""
struct Jet{T, N}
# NOTE(review): two `struct` headers and two `a` fields appear because this text is a
# diff view whose +/- markers were lost; the `Jet{S, T, N}` / `a::S` pair is the one
# that matches the three-parameter method signatures elsewhere in this file — confirm.
a::T
struct Jet{S, T, N}
a::S  # typed by `S`, independently of the `T` used for the stored values
f₀::T  # the value `getindex` returns for index 0
fₙ::NTuple{N, T}  # `fₙ[i]` is the value `getindex` returns for `1 <= i <= N`
end
Expand All @@ -25,13 +25,13 @@ function ChainRulesCore.rrule(::typeof(Base.getproperty), j::Jet, s)
error("Raw getproperty not allowed in AD code")
end

# Adds two jets of identical type: asserts that both share the very same `a` (`===`),
# then sums `f₀` and adds the `fₙ` tuples element-wise, keeping `j1.a`.
# NOTE(review): the header and the constructor line each appear twice (`Jet{T, N}` and
# `Jet{S, T, N}` spellings) because this text is a diff view whose +/- markers were
# lost; the two-parameter spelling looks like the removed side — confirm.
function Base.:+(j1::Jet{T, N}, j2::Jet{T, N}) where {T, N}
function Base.:+(j1::Jet{S, T, N}, j2::Jet{S, T, N}) where {S, T, N}
@assert j1.a === j2.a
Jet{T, N}(j1.a, j1.f₀ + j2.f₀, map(+, j1.fₙ, j2.fₙ))
Jet{S, T, N}(j1.a, j1.f₀ + j2.f₀, map(+, j1.fₙ, j2.fₙ))
end

# Adds a value of the jet's value type `T` to `f₀` only; `a` and `fₙ` are carried over
# unchanged. The same old / new duplication as above applies to this method.
function Base.:+(j::Jet{T, N}, x::T) where {T, N}
Jet{T, N}(j.a, j.f₀+x, j.fₙ)
function Base.:+(j::Jet{S, T, N}, x::T) where {S, T, N}
Jet{S, T, N}(j.a, j.f₀+x, j.fₙ)
end

# Field-less marker type.
struct One; end
Expand All @@ -44,28 +44,28 @@ function ChainRulesCore.rrule(::typeof(+), j::Jet, x::One)
j + x, Δ->(NoTangent(), One(), ZeroTangent())
end

# Returns a jet with the same `a` and the same type parameters as `j`, in which `f₀`
# and every one of the `N` entries of `fₙ` is `zero(j[0])`.
# NOTE(review): the header and the constructor line each appear twice (`Jet{T, N}` and
# `Jet{S, T, N}` spellings) because this text is a diff view whose +/- markers were
# lost; the two-parameter spelling looks like the removed side — confirm.
function Base.zero(j::Jet{T, N}) where {T, N}
function Base.zero(j::Jet{S, T, N}) where {S, T, N}
let z = zero(j[0])
Jet{T, N}(j.a, z,
Jet{S, T, N}(j.a, z,
ntuple(_->z, N))
end
end
# Reverse rule for `zero` on a jet: the pullback ignores `Δ` and returns
# `NoTangent()` followed by `ZeroTangent()`.
function ChainRulesCore.rrule(::typeof(Base.zero), j::Jet)
zero(j), Δ->(NoTangent(), ZeroTangent())
end

# Indexes a jet: `j[0]` returns `f₀` and `j[i]` returns `fₙ[i]` for `1 <= i <= N`.
# Throws `BoundsError(j, i)` for any index outside `0:N`.
# NOTE(review): the header appears twice (`Jet{T, N}` and `Jet{S, T, N}` spellings)
# because this text is a diff view whose +/- markers were lost — confirm which is current.
function Base.getindex(j::Jet{T, N}, i::Integer) where {T, N}
function Base.getindex(j::Jet{S, T, N}, i::Integer) where {S, T, N}
(0 <= i <= N) || throw(BoundsError(j, i))
i == 0 && return j.f₀
return j.fₙ[i]
end

# `deriv` shifts the stored values down by one slot: the first entry of `fₙ` becomes
# the new `f₀`, the remaining entries become the new `fₙ`, and `N` drops to `N-1`.
# NOTE(review): in both functions below the header and body appear twice (`Jet{T, N}`
# and `Jet{S, T, N}` spellings) because this text is a diff view whose +/- markers
# were lost; the two-parameter spelling looks like the removed side — confirm.
function deriv(j::Jet{T, N}) where {T, N}
Jet{T, N-1}(j.a, j.fₙ[1], Base.tail(j.fₙ))
function deriv(j::Jet{S, T, N}) where {S, T, N}
Jet{S, T, N-1}(j.a, j.fₙ[1], Base.tail(j.fₙ))
end

# `integrate` shifts the stored values up by one slot: the new `f₀` is
# `zero(j.f₀)`, the old `f₀` is prepended to `fₙ`, and `N` grows to `N+1`.
function integrate(j::Jet{T, N}) where {T, N}
Jet{T, N+1}(j.a, zero(j.f₀), tuple(j.f₀, j.fₙ...))
function integrate(j::Jet{S, T, N}) where {S, T, N}
Jet{S, T, N+1}(j.a, zero(j.f₀), tuple(j.f₀, j.fₙ...))
end

# `deriv` of a `NoTangent` is again `NoTangent`.
deriv(::NoTangent) = NoTangent()
Expand Down Expand Up @@ -187,9 +187,8 @@ function (∂⃖ₙ::∂⃖{N})(::typeof(map), f, a::Array) where {N}
∂f = ∂☆{N}()(ZeroBundle{N}(f),
TaylorBundle{N}(x,
(one(x), (zero(x) for i = 1:(N-1))...,)))
@assert isa(∂f, TaylorBundle) || isa(∂f, ExplicitTangentBundle{1})
Jet{typeof(x), N}(x, ∂f.primal,
isa(∂f, ExplicitTangentBundle) ? ∂f.tangent.partials : ∂f.tangent.coeffs)
@assert isa(∂f, TaylorBundle)
Jet{typeof(x), typeof(x), N}(x, ∂f.primal, ∂f.tangent.coeffs)
end
∂⃖ₙ(mapev, js, a)
end
Expand Down Expand Up @@ -239,7 +238,7 @@ expressions for the t′ᵢ that are hopefully easier on the compiler.
end...)
end

@generated function (j::Jet{T, N} where T)(x::TaylorBundle{M}) where {N, M}
@generated function (j::Jet{S, T, N} where {S, T})(x::TaylorBundle{M}) where {N, M}
O = min(M,N)
quote
domain_check(j, x.primal)
Expand All @@ -248,13 +247,3 @@ end
($((:(jet_taylor_ev(Val{$i}(), coeffs, j)) for i = 1:O)...),))
end
end

# Evaluates a jet with `N == 1` on an `ExplicitTangentBundle{1}`: runs `domain_check`
# against the bundle's primal, then returns a bundle whose primal is `j[0]` and whose
# single partial comes from `jet_taylor_ev(Val{1}(), coeffs, j)`.
# NOTE(review): the hunk header above (`-248,13 +247,3`) suggests both methods in this
# span are on the removed side of the diff — confirm against the PR.
function (j::Jet{T, 1} where T)(x::ExplicitTangentBundle{1})
domain_check(j, x.primal)
coeffs = x.tangent.partials
ExplicitTangentBundle{1}(j[0], (jet_taylor_ev(Val{1}(), coeffs, j),))
end

# Unimplemented general case: always throws with the message "TODO".
function (j::Jet{T, N} where T)(x::ExplicitTangentBundle{N, M}) where {N, M}
error("TODO")
end
68 changes: 11 additions & 57 deletions src/stage1/forward.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# `partial(x, i)` extracts the `i`-th partial, dispatching on the container:
#  - a `TangentBundle` forwards to its `tangent` field;
#  - an `ExplicitTangent` indexes its `partials` field, a `TaylorTangent` its `coeffs`;
#  - a `UniformTangent` returns its `val` field and ignores `i`;
#  - a `ProductTangent` applies `partial(·, i)` to every factor and re-wraps the result.
# NOTE(review): this text is a diff view whose +/- markers were lost; the
# `ExplicitTangent` method may be on the removed side — confirm against the PR.
partial(x::TangentBundle, i) = partial(getfield(x, :tangent), i)
partial(x::ExplicitTangent, i) = getfield(getfield(x, :partials), i)
partial(x::TaylorTangent, i) = getfield(getfield(x, :coeffs), i)
partial(x::UniformTangent, i) = getfield(x, :val)
partial(x::ProductTangent, i) = ProductTangent(map(x->partial(x, i), getfield(x, :factors)))
Expand All @@ -23,22 +22,13 @@ my_frule(::ZeroBundle{1, typeof(my_frule)}, args::ATB{1}...) = nothing
# Applying `∂☆{N}` to a zero-bundled `my_frule`, followed by another zero-bundled
# `my_frule` bundle, ignores the remaining arguments and returns `ZeroBundle{N}(nothing)`.
(::∂☆{N})(::ZeroBundle{N, typeof(my_frule)}, ::ZeroBundle{N, ZeroBundle{1, typeof(my_frule)}}, args::ATB{N}...) where {N} = ZeroBundle{N}(nothing)

# `shuffle_down` re-expresses an order-`N` bundle as an order-`N-1` bundle whose
# primal is itself an order-1 bundle.
# NOTE(review): this text is a diff view whose +/- markers were lost. The
# `UniformBundle` right-hand side appears twice (`minus1(N)` directly below and `N-1`
# after the `ExplicitTangentBundle` method), and the `TaylorBundle` method shows both an
# `ExplicitTangentBundle{1}` and a `TaylorBundle{1}` spelling of each inner wrapper —
# confirm which side is current against the PR.
#
# Uniform case: the new primal wraps `b.primal` together with `b.tangent.val`, and the
# same `b.tangent.val` is reused as the outer tangent value.
shuffle_down(b::UniformBundle{N, B, U}) where {N, B, U} =
UniformBundle{minus1(N), <:Any, U}(UniformBundle{1, B, U}(b.primal, b.tangent.val), b.tangent.val)

# Explicit case: the new primal pairs `b.primal` with partial 1, and each of the
# `2^(N-1)-1` new partials pairs partial `2i` with partial `2i+1`.
function shuffle_down(b::ExplicitTangentBundle{N, B}) where {N, B}
# N.B: This depends on the special properties of the canonical tangent index order
ExplicitTangentBundle{N-1}(
ExplicitTangentBundle{1}(b.primal, (partial(b, 1),)),
ntuple(2^(N-1)-1) do i
ExplicitTangentBundle{1}(partial(b, 2*i), (partial(b, 2*i+1),))
end)
end
UniformBundle{N-1, <:Any, U}(UniformBundle{1, B, U}(b.primal, b.tangent.val), b.tangent.val)

# Taylor case: the new primal pairs `b.primal` with coefficient 1, and each of the
# `N-1` new entries pairs coefficient `i` with coefficient `i+1`.
function shuffle_down(b::TaylorBundle{N, B}) where {N, B}
TaylorBundle{N-1}(
ExplicitTangentBundle{1}(b.primal, (b.tangent.coeffs[1],)),
TaylorBundle{1}(b.primal, (b.tangent.coeffs[1],)),
ntuple(N-1) do i
ExplicitTangentBundle{1}(b.tangent.coeffs[i], (b.tangent.coeffs[i+1],))
TaylorBundle{1}(b.tangent.coeffs[i], (b.tangent.coeffs[i+1],))
end)
end

Expand All @@ -54,40 +44,16 @@ end
# Combines the two entries of an order-1 `CompositeBundle` into one order-2 bundle.
# It reads the primal and first partial of `r.tup[1]` (`z₀`, `z₁`) and of `r.tup[2]`
# (`z₂`, `z₁₂`).
# NOTE(review): this text is a diff view whose +/- markers were lost. It shows both a
# conditional form (a `TaylorBundle{2}` when `z₁ == z₂`, otherwise an
# `ExplicitTangentBundle{2}` that keeps `z₂`) and an unconditional
# `return TaylorBundle{2}(z₀, (z₁, z₁₂))`; the conditional form looks like the removed
# side — confirm against the PR.
function shuffle_up(r::CompositeBundle{1})
z₀ = primal(r.tup[1])
z₁ = partial(r.tup[1], 1)
z₂ = primal(r.tup[2])
z₁₂ = partial(r.tup[2], 1)
if z₁ == z₂
return TaylorBundle{2}(z₀, (z₁, z₁₂))
else
return ExplicitTangentBundle{2}(z₀, (z₁, z₂, z₁₂))
end
return TaylorBundle{2}(z₀, (z₁, z₁₂))
end

# Returns `true` only if `primal(b)` is identical (`===`) to Taylor index 1 of `a` and,
# for every `i` in `1:N-1`, Taylor index `i` of `b` is identical to Taylor index `i+1`
# of `a`.
# NOTE(review): this text is a diff view whose +/- markers were lost; `taylor_compatible`
# and `isswifty` may be on the removed side — confirm against the PR.
function taylor_compatible(a::ATB{N}, b::ATB{N}) where {N}
primal(b) === a[TaylorTangentIndex(1)] || return false
return all(1:(N-1)) do i
b[TaylorTangentIndex(i)] === a[TaylorTangentIndex(i+1)]
end
end

# Check whether the tangent bundle element is taylor-like: true for `TaylorBundle` and
# `UniformBundle`, true for a `CompositeBundle` when every entry is, false otherwise.
isswifty(::TaylorBundle) = true
isswifty(::UniformBundle) = true
isswifty(b::CompositeBundle) = all(isswifty, b.tup)
isswifty(::Any) = false

# Combines the two entries `a, b` of an order-`N` `CompositeBundle` into one bundle of
# order `N+1`. In the Taylor form, entries `1:N` come from `a` and entry `N+1` is Taylor
# index `N` of `b`.
# NOTE(review): both a conditional form (with a `TangentBundle{N+1}` fallback built
# from the explicit partials) and an unconditional `TaylorBundle{N+1}` return appear
# here; the conditional form looks like the removed side — confirm against the PR.
function shuffle_up(r::CompositeBundle{N}) where {N}
a, b = r.tup
if isswifty(a) && isswifty(b) && taylor_compatible(a, b)
return TaylorBundle{N+1}(primal(a),
ntuple(i->i == N+1 ?
b[TaylorTangentIndex(i-1)] : a[TaylorTangentIndex(i)],
N+1))
else
return TangentBundle{N+1}(r.tup[1].primal,
(r.tup[1].tangent.partials..., primal(b),
ntuple(i->partial(b,i), 2^(N+1)-1)...))
end
return TaylorBundle{N+1}(primal(a),
ntuple(i->i == N+1 ?
b[TaylorTangentIndex(i-1)] : a[TaylorTangentIndex(i)],
N+1))
end

function shuffle_up(r::UniformBundle{N, B, U}) where {N, B, U}
Expand Down Expand Up @@ -118,14 +84,14 @@ function (::∂☆internal{1})(args::AbstractTangentBundle{1}...)
end

# `frule_via_ad` hook for `DiffractorRuleConfig`: pairs each argument with its partial
# in an order-1 bundle, runs `∂☆internal{1}` on the bundles, and returns the primal
# and the first partial of the result.
# NOTE(review): the `bundles = ...` line appears twice (`ExplicitTangentBundle{1}` and
# `TaylorBundle{1}` spellings) because this text is a diff view whose +/- markers were
# lost; the `ExplicitTangentBundle` spelling looks like the removed side — confirm.
function ChainRulesCore.frule_via_ad(::DiffractorRuleConfig, partials, args...)
bundles = map((p,a) -> ExplicitTangentBundle{1}(a, (p,)), partials, args)
bundles = map((p,a) -> TaylorBundle{1}(a, (p,)), partials, args)
result = ∂☆internal{1}()(bundles...)
primal(result), first_partial(result)
end

# Order-`N` evaluation by reduction to order `N-1`: every argument is passed through
# `shuffle_down`, and `∂☆{N-1}` is applied to a zero-bundled `my_frule` followed by
# those arguments.
# NOTE(review): both body lines appear twice (`minus1(N)` and `N-1` spellings) because
# this text is a diff view whose +/- markers were lost; the `minus1` spelling looks
# like the removed side — confirm against the PR.
function (::∂☆shuffle{N})(args::AbstractTangentBundle{N}...) where {N}
∂☆p = ∂☆{minus1(N)}()
∂☆p(ZeroBundle{minus1(N)}(my_frule), map(shuffle_down, args)...)
∂☆p = ∂☆{N-1}()
∂☆p(ZeroBundle{N-1}(my_frule), map(shuffle_down, args)...)
end

function (::∂☆internal{N})(args::AbstractTangentBundle{N}...) where {N}
Expand All @@ -139,18 +105,6 @@ end
# Generic entry point: `∂☆{N}` on order-`N` bundles forwards to `∂☆internal{N}`.
(::∂☆{N})(args::AbstractTangentBundle{N}...) where {N} = ∂☆internal{N}()(args...)

# Special case rules for performance
# `getfield` on a `TangentBundle`: the primal is `getfield` of the primal, and each
# partial is mapped through `lifted_getfield` with the same (unwrapped) field selector.
# NOTE(review): this text is a diff view whose +/- markers were lost; both
# `TangentBundle` methods below may be on the removed side — confirm against the PR.
@Base.constprop :aggressive function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::TangentBundle{N}, s::AbstractTangentBundle{N}) where {N}
s = primal(s)
ExplicitTangentBundle{N}(getfield(primal(x), s),
map(x->lifted_getfield(x, s), x.tangent.partials))
end

# Same as above with an extra bundled `inbounds` argument, whose primal is forwarded to
# the primal `getfield` call only.
@Base.constprop :aggressive function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::TangentBundle{N}, s::ATB{N}, inbounds::ATB{N}) where {N}
s = primal(s)
ExplicitTangentBundle{N}(getfield(primal(x), s, primal(inbounds)),
map(x->lifted_getfield(x, s), x.tangent.partials))
end

@Base.constprop :aggressive function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::TaylorBundle{N}, s::AbstractTangentBundle{N}) where {N}
s = primal(s)
TaylorBundle{N}(getfield(primal(x), s),
Expand Down
Loading