Skip to content

Commit

Permalink
Migrate to value + partials
Browse files Browse the repository at this point in the history
  • Loading branch information
tansongchen committed Oct 1, 2024
1 parent e06c8fd commit e6ab873
Show file tree
Hide file tree
Showing 7 changed files with 115 additions and 104 deletions.
25 changes: 13 additions & 12 deletions src/chainrules.jl
Original file line number Diff line number Diff line change
@@ -1,29 +1,30 @@
import ChainRulesCore: rrule, RuleConfig, ProjectTo, backing, @opt_out
using Base.Broadcast: broadcasted

function contract(a::TaylorScalar{T, N}, b::TaylorScalar{S, N}) where {T, S, N}
mapreduce(*, +, value(a), value(b))
function rrule(::Type{TaylorScalar{T, N}}, v::T, p::NTuple{N, T}) where {N, T}
taylor_scalar_pullback(t̄) = NoTangent(), value(t̄), partials(t̄)
return TaylorScalar{T, N}(v, p), taylor_scalar_pullback
end

function rrule(::Type{TaylorScalar{T, N}}, v::NTuple{N, T}) where {N, T}
taylor_scalar_pullback(t̄) = NoTangent(), value(t̄)
return TaylorScalar(v), taylor_scalar_pullback
function rrule(::typeof(value), t::TaylorScalar{T, N}) where {N, T}
value_pullback(v̄::T) = NoTangent(), TaylorScalar{T, N}(v̄)
return value(t), value_pullback
end

function rrule(::typeof(value), t::TaylorScalar{T, N}) where {N, T}
value_pullback(v̄::NTuple{N, T}) = NoTangent(), TaylorScalar(v̄)
function rrule(::typeof(partials), t::TaylorScalar{T, N}) where {N, T}
value_pullback(v̄::NTuple{N, T}) = NoTangent(), TaylorScalar(0, v̄)
# for structural tangent, convert to tuple
function value_pullback(v̄::Tangent{P, NTuple{N, T}}) where {P}
NoTangent(), TaylorScalar{T, N}(backing(v̄))
NoTangent(), TaylorScalar{T, N}(zero(T), backing(v̄))
end
value_pullback(v̄) = NoTangent(), TaylorScalar{T, N}(map(x -> convert(T, x), Tuple(v̄)))
return value(t), value_pullback
value_pullback(v̄) = NoTangent(), TaylorScalar{T, N}(zero(T), map(x -> convert(T, x), Tuple(v̄)))
return partials(t), value_pullback
end

function rrule(::typeof(extract_derivative), t::TaylorScalar{T, N},
i::Integer) where {N, T}
function extract_derivative_pullback(d̄)
NoTangent(), TaylorScalar{T, N}(ntuple(j -> j === i ?: zero(T), Val(N))),
NoTangent(), TaylorScalar{T, N}(zero(T), ntuple(j -> j === i ?: zero(T), Val(N))),
NoTangent()
end
return extract_derivative(t, i), extract_derivative_pullback
Expand Down Expand Up @@ -53,7 +54,7 @@ function rrule(::typeof(*), A::AbstractMatrix{S},
return A * B, gemm_pullback
end

(project::ProjectTo{T})(dx::TaylorScalar{T, N}) where {N, T <: Number} = primal(dx)
(project::ProjectTo{T})(dx::TaylorScalar{T, N}) where {N, T <: Number} = value(dx)

# opt-outs

Expand Down
14 changes: 7 additions & 7 deletions src/derivative.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ export derivative, derivative!, derivatives, make_seed
derivative(f, x, l, ::Val{N})
derivative(f!, y, x, l, ::Val{N})
Computes `order`-th directional derivative of `f` w.r.t. vector `x` in direction `l`.
Computes `N`-th directional derivative of `f` w.r.t. vector `x` in direction `l`.
"""
function derivative end

Expand All @@ -21,7 +21,7 @@ function derivative! end
derivatives(f, x, l, ::Val{N})
derivatives(f!, y, x, l, ::Val{N})
Computes all derivatives of `f` at `x` up to order `N - 1`.
Computes all derivatives of `f` at `x` up to order `N`.
"""
function derivatives end

Expand All @@ -32,12 +32,12 @@ function derivatives end
# Convenience wrappers for converting orders to value types
# and forward work to core APIs

@inline derivative(f, x, l, order::Int64) = derivative(f, x, l, Val{order + 1}())
@inline derivative(f!, y, x, l, order::Int64) = derivative(f!, y, x, l, Val{order + 1}())
@inline derivative(f, x, l, order::Int64) = derivative(f, x, l, Val{order}())
@inline derivative(f!, y, x, l, order::Int64) = derivative(f!, y, x, l, Val{order}())
@inline derivative!(result, f, x, l, order::Int64) = derivative!(
result, f, x, l, Val{order + 1}())
result, f, x, l, Val{order}())
@inline derivative!(result, f!, y, x, l, order::Int64) = derivative!(
result, f!, y, x, l, Val{order + 1}())
result, f!, y, x, l, Val{order}())

# Core APIs

Expand Down Expand Up @@ -69,6 +69,6 @@ end
@inline function derivatives(f!, y::AbstractArray{T}, x, l, vN::Val{N}) where {T, N}
buffer = similar(y, TaylorScalar{T, N})
f!(buffer, make_seed(x, l, vN))
map!(primal, y, buffer)
map!(value, y, buffer)
return buffer
end
85 changes: 50 additions & 35 deletions src/primitive.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,15 @@ import Base: hypot, max, min
import Base: tail

# Unary
@inline +(a::Number, b::TaylorScalar) = TaylorScalar((a + value(b)[1]), tail(value(b))...)
@inline -(a::Number, b::TaylorScalar) = TaylorScalar((a - value(b)[1]), .-tail(value(b))...)
@inline *(a::Number, b::TaylorScalar) = TaylorScalar((a .* value(b))...)
@inline +(a::Number, b::TaylorScalar) = TaylorScalar(a + value(b), partials(b))
@inline -(a::Number, b::TaylorScalar) = TaylorScalar(a - value(b), map(-, partials(b)))
@inline *(a::Number, b::TaylorScalar) = TaylorScalar(a * value(b), a .* partials(b))
@inline /(a::Number, b::TaylorScalar) = /(promote(a, b)...)

@inline +(a::TaylorScalar, b::Number) = TaylorScalar((value(a)[1] + b), tail(value(a))...)
@inline -(a::TaylorScalar, b::Number) = TaylorScalar((value(a)[1] - b), tail(value(a))...)
@inline *(a::TaylorScalar, b::Number) = TaylorScalar((value(a) .* b)...)
@inline /(a::TaylorScalar, b::Number) = TaylorScalar((value(a) ./ b)...)
@inline +(a::TaylorScalar, b::Number) = TaylorScalar(value(a) + b, partials(a))
@inline -(a::TaylorScalar, b::Number) = TaylorScalar(value(a) - b, partials(a))
@inline *(a::TaylorScalar, b::Number) = TaylorScalar(value(a) * b, partials(a) .* b)
@inline /(a::TaylorScalar, b::Number) = TaylorScalar(value(a) / b, partials(a) ./ b)

## Delegated

Expand All @@ -27,10 +27,10 @@ import Base: tail
for func in (:exp, :expm1, :exp2, :exp10)
@eval @generated function $func(t::TaylorScalar{T, N}) where {T, N}
ex = quote
v = value(t)
v = flatten(t)
v1 = $($(QuoteNode(func)) == :expm1 ? :(exp(v[1])) : :($$func(v[1])))
end
for i in 2:N
for i in 2:(N + 1)
ex = quote
$ex
$(Symbol('v', i)) = +($([:($(binomial(i - 2, j - 1)) * $(Symbol('v', j)) *
Expand All @@ -46,19 +46,19 @@ for func in (:exp, :expm1, :exp2, :exp10)
if $(QuoteNode(func)) == :expm1
ex = :($ex; v1 = expm1(v[1]))
end
ex = :($ex; TaylorScalar{T, N}(tuple($([Symbol('v', i) for i in 1:N]...))))
ex = :($ex; TaylorScalar(tuple($([Symbol('v', i) for i in 1:(N + 1)]...))))
return :(@inbounds $ex)
end
end

for func in (:sin, :cos)
@eval @generated function $func(t::TaylorScalar{T, N}) where {T, N}
ex = quote
v = value(t)
v = flatten(t)
s1 = sin(v[1])
c1 = cos(v[1])
end
for i in 2:N
for i in 2:(N + 1)
ex = :($ex;
$(Symbol('s', i)) = +($([:($(binomial(i - 2, j - 1)) *
$(Symbol('c', j)) *
Expand All @@ -69,9 +69,9 @@ for func in (:sin, :cos)
v[$(i + 1 - j)]) for j in 1:(i - 1)]...)))
end
if $(QuoteNode(func)) == :sin
ex = :($ex; TaylorScalar($([Symbol('s', i) for i in 1:N]...)))
ex = :($ex; TaylorScalar(tuple($([Symbol('s', i) for i in 1:(N + 1)]...))))
else
ex = :($ex; TaylorScalar($([Symbol('c', i) for i in 1:N]...)))
ex = :($ex; TaylorScalar(tuple($([Symbol('c', i) for i in 1:(N + 1)]...))))
end
return quote
@inbounds $ex
Expand All @@ -94,24 +94,27 @@ for op in [:>, :<, :(==), :(>=), :(<=)]
@eval @inline $op(a::TaylorScalar, b::TaylorScalar) = $op(value(a)[1], value(b)[1])
end

@inline +(a::TaylorScalar, b::TaylorScalar) = TaylorScalar(map(+, value(a), value(b)))
@inline -(a::TaylorScalar, b::TaylorScalar) = TaylorScalar(map(-, value(a), value(b)))
@inline +(a::TaylorScalar, b::TaylorScalar) = TaylorScalar(
value(a) + value(b), map(+, partials(a), partials(b)))
@inline -(a::TaylorScalar, b::TaylorScalar) = TaylorScalar(
value(a) - value(b), map(-, partials(a), partials(b)))

@generated function *(a::TaylorScalar{T, N}, b::TaylorScalar{T, N}) where {T, N}
return quote
va, vb = value(a), value(b)
@inbounds TaylorScalar($([:(+($([:($(binomial(i - 1, j - 1)) * va[$j] *
vb[$(i + 1 - j)]) for j in 1:i]...)))
for i in 1:N]...))
va, vb = flatten(a), flatten(b)
r = tuple($([:(+($([:($(binomial(i - 1, j - 1)) * va[$j] *
vb[$(i + 1 - j)]) for j in 1:i]...)))
for i in 1:(N + 1)]...))
@inbounds TaylorScalar(r[1], r[2:end])
end
end

@generated function /(a::TaylorScalar{T, N}, b::TaylorScalar{T, N}) where {T, N}
ex = quote
va, vb = value(a), value(b)
va, vb = flatten(a), flatten(b)
v1 = va[1] / vb[1]
end
for i in 2:N
for i in 2:(N + 1)
ex = quote
$ex
$(Symbol('v', i)) = (va[$i] -
Expand All @@ -120,24 +123,28 @@ end
for j in 1:(i - 1)]...))) / vb[1]
end
end
ex = :($ex; TaylorScalar($([Symbol('v', i) for i in 1:N]...)))
ex = quote
$ex
v = tuple($([Symbol('v', i) for i in 1:(N + 1)]...))
TaylorScalar(v)
end
return :(@inbounds $ex)
end

for R in (Integer, Real)
@eval @generated function ^(t::TaylorScalar{T, N}, n::S) where {S <: $R, T, N}
ex = quote
v = value(t)
v = flatten(t)
w11 = 1
u1 = ^(v[1], n)
end
for k in 1:N
for k in 1:(N + 1)
ex = quote
$ex
$(Symbol('p', k)) = ^(v[1], n - $(k - 1))
end
end
for i in 2:N
for i in 2:(N + 1)
subex = quote
$(Symbol('w', i, 1)) = 0
end
Expand All @@ -158,7 +165,11 @@ for R in (Integer, Real)
for k in 2:i]...))
end
end
ex = :($ex; TaylorScalar($([Symbol('u', i) for i in 1:N]...)))
ex = quote
$ex
v = tuple($([Symbol('u', i) for i in 1:(N + 1)]...))
TaylorScalar(v)
end
return :(@inbounds $ex)
end
@eval function ^(a::S, t::TaylorScalar{T, N}) where {S <: $R, T, N}
Expand All @@ -172,11 +183,11 @@ end
t::TaylorScalar{T, N}) where {T, M, N} # M + 1 == N
return quote
$(Expr(:meta, :inline))
vdf, vt = value(df), value(t)
@inbounds TaylorScalar(f,
$([:(+($([:($(binomial(i - 1, j - 1)) * vdf[$j] *
vt[$(i + 2 - j)]) for j in 1:i]...)))
for i in 1:M]...))
vdf, vt = flatten(df), flatten(t)
partials = tuple($([:(+($([:($(binomial(i - 1, j - 1)) * vdf[$j] *
vt[$(i + 2 - j)]) for j in 1:i]...)))
for i in 1:(M + 1)]...))
@inbounds TaylorScalar(f, partials)
end
end

Expand All @@ -185,10 +196,10 @@ raise(::T, df::S, t::TaylorScalar{T, N}) where {S <: Number, T, N} = df * t
@generated function raiseinv(f::T, df::TaylorScalar{T, M},
t::TaylorScalar{T, N}) where {T, M, N} # M + 1 == N
ex = quote
vdf, vt = value(df), value(t)
vdf, vt = flatten(df), flatten(t)
v1 = vt[2] / vdf[1]
end
for i in 2:M
for i in 2:(M + 1)
ex = quote
$ex
$(Symbol('v', i)) = (vt[$(i + 1)] -
Expand All @@ -197,6 +208,10 @@ raise(::T, df::S, t::TaylorScalar{T, N}) where {S <: Number, T, N} = df * t
for j in 1:(i - 1)]...))) / vdf[1]
end
end
ex = :($ex; TaylorScalar(f, $([Symbol('v', i) for i in 1:M]...)))
ex = quote
$ex
v = tuple($([Symbol('v', i) for i in 1:(M + 1)]...))
TaylorScalar(f, v)
end
return :(@inbounds $ex)
end
Loading

0 comments on commit e6ab873

Please sign in to comment.