From e6ab8738e4e86db331ee1d5dd26c38e7067bfef9 Mon Sep 17 00:00:00 2001 From: Songchen Tan Date: Tue, 1 Oct 2024 18:39:08 -0400 Subject: [PATCH] Migrate to value + partials --- src/chainrules.jl | 25 +++++++------- src/derivative.jl | 14 ++++---- src/primitive.jl | 85 +++++++++++++++++++++++++++------------------- src/scalar.jl | 70 +++++++++++++++++--------------------- src/utils.jl | 7 ++-- test/derivative.jl | 6 ++-- test/primitive.jl | 12 +++---- 7 files changed, 115 insertions(+), 104 deletions(-) diff --git a/src/chainrules.jl b/src/chainrules.jl index 4fcf08b..6e2d5e4 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -1,29 +1,30 @@ import ChainRulesCore: rrule, RuleConfig, ProjectTo, backing, @opt_out using Base.Broadcast: broadcasted -function contract(a::TaylorScalar{T, N}, b::TaylorScalar{S, N}) where {T, S, N} - mapreduce(*, +, value(a), value(b)) +function rrule(::Type{TaylorScalar{T, N}}, v::T, p::NTuple{N, T}) where {N, T} + taylor_scalar_pullback(t̄) = NoTangent(), value(t̄), partials(t̄) + return TaylorScalar{T, N}(v, p), taylor_scalar_pullback end -function rrule(::Type{TaylorScalar{T, N}}, v::NTuple{N, T}) where {N, T} - taylor_scalar_pullback(t̄) = NoTangent(), value(t̄) - return TaylorScalar(v), taylor_scalar_pullback +function rrule(::typeof(value), t::TaylorScalar{T, N}) where {N, T} + value_pullback(v̄::T) = NoTangent(), TaylorScalar{T, N}(v̄) + return value(t), value_pullback end -function rrule(::typeof(value), t::TaylorScalar{T, N}) where {N, T} - value_pullback(v̄::NTuple{N, T}) = NoTangent(), TaylorScalar(v̄) +function rrule(::typeof(partials), t::TaylorScalar{T, N}) where {N, T} + value_pullback(v̄::NTuple{N, T}) = NoTangent(), TaylorScalar(0, v̄) # for structural tangent, convert to tuple function value_pullback(v̄::Tangent{P, NTuple{N, T}}) where {P} - NoTangent(), TaylorScalar{T, N}(backing(v̄)) + NoTangent(), TaylorScalar{T, N}(zero(T), backing(v̄)) end - value_pullback(v̄) = NoTangent(), TaylorScalar{T, N}(map(x -> convert(T, x), Tuple(v̄))) - return value(t), value_pullback + value_pullback(v̄) = NoTangent(), TaylorScalar{T, N}(zero(T), map(x -> convert(T, x), Tuple(v̄))) + return partials(t), value_pullback end function rrule(::typeof(extract_derivative), t::TaylorScalar{T, N}, i::Integer) where {N, T} function extract_derivative_pullback(d̄) - NoTangent(), TaylorScalar{T, N}(ntuple(j -> j === i ? d̄ : zero(T), Val(N))), + NoTangent(), TaylorScalar{T, N}(zero(T), ntuple(j -> j === i ? d̄ : zero(T), Val(N))), NoTangent() end return extract_derivative(t, i), extract_derivative_pullback @@ -53,7 +54,7 @@ function rrule(::typeof(*), A::AbstractMatrix{S}, return A * B, gemm_pullback end -(project::ProjectTo{T})(dx::TaylorScalar{T, N}) where {N, T <: Number} = primal(dx) +(project::ProjectTo{T})(dx::TaylorScalar{T, N}) where {N, T <: Number} = value(dx) # opt-outs diff --git a/src/derivative.jl b/src/derivative.jl index 9b49b09..434ddb8 100644 --- a/src/derivative.jl +++ b/src/derivative.jl @@ -5,7 +5,7 @@ export derivative, derivative!, derivatives, make_seed derivative(f, x, l, ::Val{N}) derivative(f!, y, x, l, ::Val{N}) -Computes `order`-th directional derivative of `f` w.r.t. vector `x` in direction `l`. +Computes `N`-th directional derivative of `f` w.r.t. vector `x` in direction `l`. """ function derivative end @@ -21,7 +21,7 @@ function derivative! end derivatives(f, x, l, ::Val{N}) derivatives(f!, y, x, l, ::Val{N}) -Computes all derivatives of `f` at `x` up to order `N - 1`. +Computes all derivatives of `f` at `x` up to order `N`. """ function derivatives end @@ -32,12 +32,12 @@ function derivatives end # Convenience wrappers for converting orders to value types # and forward work to core APIs -@inline derivative(f, x, l, order::Int64) = derivative(f, x, l, Val{order + 1}()) -@inline derivative(f!, y, x, l, order::Int64) = derivative(f!, y, x, l, Val{order + 1}()) +@inline derivative(f, x, l, order::Int64) = derivative(f, x, l, Val{order}()) +@inline derivative(f!, y, x, l, order::Int64) = derivative(f!, y, x, l, Val{order}()) @inline derivative!(result, f, x, l, order::Int64) = derivative!( - result, f, x, l, Val{order + 1}()) + result, f, x, l, Val{order}()) @inline derivative!(result, f!, y, x, l, order::Int64) = derivative!( - result, f!, y, x, l, Val{order + 1}()) + result, f!, y, x, l, Val{order}()) # Core APIs @@ -69,6 +69,6 @@ end @inline function derivatives(f!, y::AbstractArray{T}, x, l, vN::Val{N}) where {T, N} buffer = similar(y, TaylorScalar{T, N}) f!(buffer, make_seed(x, l, vN)) - map!(primal, y, buffer) + map!(value, y, buffer) return buffer end diff --git a/src/primitive.jl b/src/primitive.jl index caff022..a2b02fc 100644 --- a/src/primitive.jl +++ b/src/primitive.jl @@ -8,15 +8,15 @@ import Base: hypot, max, min import Base: tail # Unary -@inline +(a::Number, b::TaylorScalar) = TaylorScalar((a + value(b)[1]), tail(value(b))...) -@inline -(a::Number, b::TaylorScalar) = TaylorScalar((a - value(b)[1]), .-tail(value(b))...) -@inline *(a::Number, b::TaylorScalar) = TaylorScalar((a .* value(b))...) +@inline +(a::Number, b::TaylorScalar) = TaylorScalar(a + value(b), partials(b)) +@inline -(a::Number, b::TaylorScalar) = TaylorScalar(a - value(b), map(-, partials(b))) +@inline *(a::Number, b::TaylorScalar) = TaylorScalar(a * value(b), a .* partials(b)) @inline /(a::Number, b::TaylorScalar) = /(promote(a, b)...) -@inline +(a::TaylorScalar, b::Number) = TaylorScalar((value(a)[1] + b), tail(value(a))...) -@inline -(a::TaylorScalar, b::Number) = TaylorScalar((value(a)[1] - b), tail(value(a))...) -@inline *(a::TaylorScalar, b::Number) = TaylorScalar((value(a) .* b)...) -@inline /(a::TaylorScalar, b::Number) = TaylorScalar((value(a) ./ b)...) +@inline +(a::TaylorScalar, b::Number) = TaylorScalar(value(a) + b, partials(a)) +@inline -(a::TaylorScalar, b::Number) = TaylorScalar(value(a) - b, partials(a)) +@inline *(a::TaylorScalar, b::Number) = TaylorScalar(value(a) * b, partials(a) .* b) +@inline /(a::TaylorScalar, b::Number) = TaylorScalar(value(a) / b, partials(a) ./ b) ## Delegated @@ -27,10 +27,10 @@ import Base: tail for func in (:exp, :expm1, :exp2, :exp10) @eval @generated function $func(t::TaylorScalar{T, N}) where {T, N} ex = quote - v = value(t) + v = flatten(t) v1 = $($(QuoteNode(func)) == :expm1 ? :(exp(v[1])) : :($$func(v[1]))) end - for i in 2:N + for i in 2:(N + 1) ex = quote $ex $(Symbol('v', i)) = +($([:($(binomial(i - 2, j - 1)) * $(Symbol('v', j)) * @@ -46,7 +46,7 @@ for func in (:exp, :expm1, :exp2, :exp10) if $(QuoteNode(func)) == :expm1 ex = :($ex; v1 = expm1(v[1])) end - ex = :($ex; TaylorScalar{T, N}(tuple($([Symbol('v', i) for i in 1:N]...)))) + ex = :($ex; TaylorScalar(tuple($([Symbol('v', i) for i in 1:(N + 1)]...)))) return :(@inbounds $ex) end end @@ -54,11 +54,11 @@ end for func in (:sin, :cos) @eval @generated function $func(t::TaylorScalar{T, N}) where {T, N} ex = quote - v = value(t) + v = flatten(t) s1 = sin(v[1]) c1 = cos(v[1]) end - for i in 2:N + for i in 2:(N + 1) ex = :($ex; $(Symbol('s', i)) = +($([:($(binomial(i - 2, j - 1)) * $(Symbol('c', j)) * @@ -69,9 +69,9 @@ for func in (:sin, :cos) v[$(i + 1 - j)]) for j in 1:(i - 1)]...))) end if $(QuoteNode(func)) == :sin - ex = :($ex; TaylorScalar($([Symbol('s', i) for i in 1:N]...))) + ex = :($ex; TaylorScalar(tuple($([Symbol('s', i) for i in 1:(N + 1)]...)))) else - ex = :($ex; TaylorScalar($([Symbol('c', i) for i in 1:N]...))) + ex = :($ex; TaylorScalar(tuple($([Symbol('c', i) for i in 1:(N + 1)]...)))) end return quote @inbounds $ex @@ -94,24 +94,27 @@ for op in [:>, :<, :(==), :(>=), :(<=)] @eval @inline $op(a::TaylorScalar, b::TaylorScalar) = $op(value(a)[1], value(b)[1]) end -@inline +(a::TaylorScalar, b::TaylorScalar) = TaylorScalar(map(+, value(a), value(b))) -@inline -(a::TaylorScalar, b::TaylorScalar) = TaylorScalar(map(-, value(a), value(b))) +@inline +(a::TaylorScalar, b::TaylorScalar) = TaylorScalar( + value(a) + value(b), map(+, partials(a), partials(b))) +@inline -(a::TaylorScalar, b::TaylorScalar) = TaylorScalar( + value(a) - value(b), map(-, partials(a), partials(b))) @generated function *(a::TaylorScalar{T, N}, b::TaylorScalar{T, N}) where {T, N} return quote - va, vb = value(a), value(b) - @inbounds TaylorScalar($([:(+($([:($(binomial(i - 1, j - 1)) * va[$j] * - vb[$(i + 1 - j)]) for j in 1:i]...))) - for i in 1:N]...)) + va, vb = flatten(a), flatten(b) + r = tuple($([:(+($([:($(binomial(i - 1, j - 1)) * va[$j] * + vb[$(i + 1 - j)]) for j in 1:i]...))) + for i in 1:(N + 1)]...)) + @inbounds TaylorScalar(r[1], r[2:end]) end end @generated function /(a::TaylorScalar{T, N}, b::TaylorScalar{T, N}) where {T, N} ex = quote - va, vb = value(a), value(b) + va, vb = flatten(a), flatten(b) v1 = va[1] / vb[1] end - for i in 2:N + for i in 2:(N + 1) ex = quote $ex $(Symbol('v', i)) = (va[$i] - @@ -120,24 +123,28 @@ end for j in 1:(i - 1)]...))) / vb[1] end end - ex = :($ex; TaylorScalar($([Symbol('v', i) for i in 1:N]...))) + ex = quote + $ex + v = tuple($([Symbol('v', i) for i in 1:(N + 1)]...)) + TaylorScalar(v) + end return :(@inbounds $ex) end for R in (Integer, Real) @eval @generated function ^(t::TaylorScalar{T, N}, n::S) where {S <: $R, T, N} ex = quote - v = value(t) + v = flatten(t) w11 = 1 u1 = ^(v[1], n) end - for k in 1:N + for k in 1:(N + 1) ex = quote $ex $(Symbol('p', k)) = ^(v[1], n - $(k - 1)) end end - for i in 2:N + for i in 2:(N + 1) subex = quote $(Symbol('w', i, 1)) = 0 end @@ -158,7 +165,11 @@ for R in (Integer, Real) for k in 2:i]...)) end end - ex = :($ex; TaylorScalar($([Symbol('u', i) for i in 1:N]...))) + ex = quote + $ex + v = tuple($([Symbol('u', i) for i in 1:(N + 1)]...)) + TaylorScalar(v) + end return :(@inbounds $ex) end @eval function ^(a::S, t::TaylorScalar{T, N}) where {S <: $R, T, N} @@ -172,11 +183,11 @@ end t::TaylorScalar{T, N}) where {T, M, N} # M + 1 == N return quote $(Expr(:meta, :inline)) - vdf, vt = value(df), value(t) - @inbounds TaylorScalar(f, - $([:(+($([:($(binomial(i - 1, j - 1)) * vdf[$j] * - vt[$(i + 2 - j)]) for j in 1:i]...))) - for i in 1:M]...)) + vdf, vt = flatten(df), flatten(t) + partials = tuple($([:(+($([:($(binomial(i - 1, j - 1)) * vdf[$j] * + vt[$(i + 2 - j)]) for j in 1:i]...))) + for i in 1:(M + 1)]...)) + @inbounds TaylorScalar(f, partials) end end @@ -185,10 +196,10 @@ raise(::T, df::S, t::TaylorScalar{T, N}) where {S <: Number, T, N} = df * t @generated function raiseinv(f::T, df::TaylorScalar{T, M}, t::TaylorScalar{T, N}) where {T, M, N} # M + 1 == N ex = quote - vdf, vt = value(df), value(t) + vdf, vt = flatten(df), flatten(t) v1 = vt[2] / vdf[1] end - for i in 2:M + for i in 2:(M + 1) ex = quote $ex $(Symbol('v', i)) = (vt[$(i + 1)] - @@ -197,6 +208,10 @@ raise(::T, df::S, t::TaylorScalar{T, N}) where {S <: Number, T, N} = df * t for j in 1:(i - 1)]...))) / vdf[1] end end - ex = :($ex; TaylorScalar(f, $([Symbol('v', i) for i in 1:M]...))) + ex = quote + $ex + v = tuple($([Symbol('v', i) for i in 1:(M + 1)]...)) + TaylorScalar(f, v) + end return :(@inbounds $ex) end diff --git a/src/scalar.jl b/src/scalar.jl index ca523fb..0fe158a 100644 --- a/src/scalar.jl +++ b/src/scalar.jl @@ -1,5 +1,3 @@ -import Base: zero, one, adjoint, conj, transpose -import Base: +, -, *, / import Base: convert, promote_rule export TaylorScalar @@ -25,55 +23,52 @@ Representation of Taylor polynomials. # Fields -- `value::NTuple{N, T}`: i-th element of this stores the (i-1)-th derivative +- `value::T`: zeroth order coefficient +- `partials::NTuple{N, T}`: i-th element of this stores the i-th derivative """ struct TaylorScalar{T, N} <: Real - value::NTuple{N, T} - function TaylorScalar{T, N}(value::NTuple{N, T}) where {T, N} + value::T + partials::NTuple{N, T} + function TaylorScalar{T, N}(value::T, partials::NTuple{N, T}) where {T, N} can_taylorize(T) || throw_cannot_taylorize(T) - new{T, N}(value) + new{T, N}(value, partials) end end -TaylorScalar(value::NTuple{N, T}) where {T, N} = TaylorScalar{T, N}(value) -TaylorScalar(value::Vararg{T, N}) where {T, N} = TaylorScalar{T, N}(value) +function TaylorScalar(value::T, partials::NTuple{N, T}) where {T, N} + TaylorScalar{T, N}(value, partials) +end + +function TaylorScalar(value_and_partials::NTuple{N, T}) where {T, N} + TaylorScalar(value_and_partials[1], value_and_partials[2:end]) +end """ TaylorScalar{T, N}(x::T) where {T, N} Construct a Taylor polynomial with zeroth order coefficient. """ -@generated function TaylorScalar{T, N}(x::S) where {T, S <: Real, N} - return quote - $(Expr(:meta, :inline)) - TaylorScalar((T(x), $(zeros(T, N - 1)...))) - end -end +TaylorScalar{T, N}(x::S) where {T, S <: Real, N} = TaylorScalar( + T(x), ntuple(i -> zero(T), Val(N))) """ TaylorScalar{T, N}(x::T, d::T) where {T, N} Construct a Taylor polynomial with zeroth and first order coefficient, acting as a seed. """ -@generated function TaylorScalar{T, N}(x::S, d::S) where {T, S <: Real, N} - return quote - $(Expr(:meta, :inline)) - TaylorScalar((T(x), T(d), $(zeros(T, N - 2)...))) - end -end - -@generated function TaylorScalar{T, N}(t::TaylorScalar{T, M}) where {T, N, M} - N <= M ? quote - $(Expr(:meta, :inline)) - TaylorScalar(value(t)[1:N]) - end : quote - $(Expr(:meta, :inline)) - TaylorScalar((value(t)..., $(zeros(T, N - M)...))) - end +TaylorScalar{T, N}(x::S, d::S) where {T, S <: Real, N} = TaylorScalar( + T(x), ntuple(i -> i == 1 ? T(d) : zero(T), Val(N))) + +function TaylorScalar{T, N}(t::TaylorScalar{T, M}) where {T, N, M} + v = value(t) + p = partials(t) + N <= M ? TaylorScalar(v, p[1:N]) : + TaylorScalar(v, ntuple(i -> i <= M ? p[i] : zero(T), Val(N))) end @inline value(t::TaylorScalar) = t.value -@inline extract_derivative(t::TaylorScalar, i::Integer) = t.value[i] +@inline partials(t::TaylorScalar) = t.partials +@inline extract_derivative(t::TaylorScalar, i::Integer) = t.partials[i] @inline function extract_derivative(v::AbstractArray{T}, i::Integer) where {T <: TaylorScalar} map(t -> extract_derivative(t, i), v) @@ -83,7 +78,8 @@ end i::Integer) where {T <: TaylorScalar} map!(t -> extract_derivative(t, i), result, v) end -@inline primal(t::TaylorScalar) = extract_derivative(t, 1) + +@inline flatten(t::TaylorScalar) = (value(t), partials(t)...) function promote_rule(::Type{TaylorScalar{T, N}}, ::Type{S}) where {T, S, N} @@ -91,20 +87,18 @@ function promote_rule(::Type{TaylorScalar{T, N}}, end function (::Type{F})(x::TaylorScalar{T, N}) where {T, N, F <: AbstractFloat} - F(primal(x)) + F(value(x)) end -function Base.nextfloat(x::TaylorScalar{T, N}) where {T, N} - TaylorScalar{T, N}(ntuple(i -> i == 1 ? nextfloat(value(x)[i]) : value(x)[i], N)) -end +const COVARIANT_OPS = Symbol[:nextfloat, :prevfloat] -function Base.prevfloat(x::TaylorScalar{T, N}) where {T, N} - TaylorScalar{T, N}(ntuple(i -> i == 1 ? prevfloat(value(x)[i]) : value(x)[i], N)) +for op in COVARIANT_OPS + @eval Base.$(op)(x::TaylorScalar{T, N}) where {T, N} = TaylorScalar($(op)(value(x)), partials(x)) end const UNARY_PREDICATES = Symbol[ :isinf, :isnan, :isfinite, :iseven, :isodd, :isreal, :isinteger] for pred in UNARY_PREDICATES - @eval Base.$(pred)(x::TaylorScalar) = $(pred)(primal(x)) + @eval Base.$(pred)(x::TaylorScalar) = $(pred)(value(x)) end diff --git a/src/utils.jl b/src/utils.jl index 08cce53..4cc9c23 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -10,10 +10,11 @@ dummy = (NoTangent(), 1) function define_unary_function(func, m) F = typeof(func) # base case - @eval m function (op::$F)(t::TaylorScalar{T, 2}) where {T} - t0, t1 = value(t) + @eval m function (op::$F)(t::TaylorScalar{T, 1}) where {T} + t0 = value(t) + t1 = first(partials(t)) f0, f1 = frule((NoTangent(), t1), op, t0) - TaylorScalar{T, 2}(f0, zero_tangent(f0) + f1) + TaylorScalar{T, 1}(f0, zero_tangent(f0) + f1) end der = frule(dummy, func, z)[2] term, raiser = der isa Pow && der.exp == -1 ? (der.base, raiseinv) : (der, raise) diff --git a/test/derivative.jl b/test/derivative.jl index 80d8434..3c00ff6 100644 --- a/test/derivative.jl +++ b/test/derivative.jl @@ -20,12 +20,12 @@ end end x = 2.0 y = [0.0, 0.0] - @test derivative(g!, y, x, 1.0, Val{2}()) ≈ [4.0, 1.0] + @test derivative(g!, y, x, 1.0, Val{1}()) ≈ [4.0, 1.0] end @testset "O-function, I-derivative" begin g(x) = x .^ 2 - @test derivative!(zeros(2), g, [1.0, 2.0], [1.0, 0.0], Val{2}()) ≈ [2.0, 0.0] + @test derivative!(zeros(2), g, [1.0, 2.0], [1.0, 0.0], Val{1}()) ≈ [2.0, 0.0] end @testset "I-function, I-derivative" begin @@ -35,5 +35,5 @@ end end x = [2.0, 3.0] y = [0.0, 0.0] - @test derivative!(y, g!, zeros(2), x, [1.0, 0.0], Val{2}()) ≈ [4.0, 0.0] + @test derivative!(y, g!, zeros(2), x, [1.0, 0.0], Val{1}()) ≈ [4.0, 0.0] end diff --git a/test/primitive.jl b/test/primitive.jl index 974e845..6313a61 100644 --- a/test/primitive.jl +++ b/test/primitive.jl @@ -50,12 +50,12 @@ end @testset "Corner cases" begin offenders = ( - TaylorDiff.TaylorScalar{Float64, 4}((Inf, 1.0, 0.0, 0.0)), - TaylorDiff.TaylorScalar{Float64, 4}((Inf, 0.0, 0.0, 0.0)), - TaylorDiff.TaylorScalar{Float64, 4}((1.0, 0.0, 0.0, 0.0)), - TaylorDiff.TaylorScalar{Float64, 4}((1.0, Inf, 0.0, 0.0)), - TaylorDiff.TaylorScalar{Float64, 4}((0.0, 1.0, 0.0, 0.0)), - TaylorDiff.TaylorScalar{Float64, 4}((0.0, Inf, 0.0, 0.0)) # Others ? + TaylorDiff.TaylorScalar{Float64, 3}(Inf, (1.0, 0.0, 0.0)), + TaylorDiff.TaylorScalar{Float64, 3}(Inf, (0.0, 0.0, 0.0)), + TaylorDiff.TaylorScalar{Float64, 3}(1.0, (0.0, 0.0, 0.0)), + TaylorDiff.TaylorScalar{Float64, 3}(1.0, (Inf, 0.0, 0.0)), + TaylorDiff.TaylorScalar{Float64, 3}(0.0, (1.0, 0.0, 0.0)), + TaylorDiff.TaylorScalar{Float64, 3}(0.0, (Inf, 0.0, 0.0)) # Others ? ) f_id = ( :id => x -> x,