Commit b4c8780: Add theory

tansongchen committed Oct 14, 2024
1 parent 9ef7cdb commit b4c8780

Showing 9 changed files with 201 additions and 96 deletions.
1 change: 1 addition & 0 deletions docs/make.jl
@@ -15,6 +15,7 @@ makedocs(;
assets = String[]),
pages = [
"Home" => "index.md",
"Theory" => "theory.md",
"API" => "api.md"
])

106 changes: 106 additions & 0 deletions docs/src/theory.md
@@ -0,0 +1,106 @@
```@meta
CurrentModule = TaylorDiff
```

# Theory

TaylorDiff.jl is an operator-overloading-based forward-mode automatic differentiation (AD) package. "Forward-mode" means that the basic capability of this package is the following: given a function $f:\mathbb R^n\to\mathbb R^m$, an evaluation point $x\in\mathbb R^n$ and a direction $v\in\mathbb R^n$, we compute
$$
f(x),\partial f(x)\times v,\partial^2f(x)\times v\times v,\cdots,\partial^pf(x)\times v\times\cdots\times v
$$
i.e., the function value and the directional derivatives up to order $p$. This notation might be unfamiliar to Julia users who have experience with other AD packages, but $\partial f(x)$ is simply the Jacobian $J$, and $\partial f(x)\times v$ is simply the Jacobian-vector product (JVP). In other words, this is a natural generalization of the Jacobian-vector product to the Hessian-vector-vector product, and to even higher orders.
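
For example, with the `derivative` API exported by this package (the same calls exercised in the test changes further down), these directional derivatives of a simple quadratic are computed in a single forward pass. The expected values below are worked out by hand and serve only as an illustration:

```julia
using TaylorDiff

f(x) = x[1] * x[1] + x[2] * x[2]  # Hessian is 2I
x = [1.0, 2.0]
v = [1.0, 0.0]

derivative(f, x, v, Val(1))  # ∂f(x) × v = 2x[1] = 2.0
derivative(f, x, v, Val(2))  # ∂²f(x) × v × v = vᵀ(2I)v = 2.0
```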

The main advantage of doing this instead of $p$ separate first-order Jacobian-vector products is that nesting first-order AD scales exponentially w.r.t. $p$, while this method, also known as Taylor mode, should scale (almost) linearly w.r.t. $p$. We will see the reason for this claim later.

To achieve this, assume that $f$ is a composition $f_k\circ\cdots\circ f_2\circ f_1$, where each $f_i$ is a basic and simple function called a "primitive". We need to figure out how to propagate the derivatives through each step. In first-order AD, this is achieved by the "dual" pair $x_0+x_1\varepsilon$, where $\varepsilon^2=0$, and for each primitive we make a method overload
$$
f(x_0+x_1\varepsilon)=f(x_0)+\partial f(x_0) x_1\varepsilon
$$
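As a minimal sketch of this idea (illustration only; this package's actual implementation is the `TaylorScalar` type together with the machinery in `src/primitive.jl`):

```julia
# A toy first-order dual number: value + partial * ε, with ε² = 0
struct Dual{T}
    value::T    # x₀
    partial::T  # x₁
end

# Method overload for the primitive sin:
# sin(x₀ + x₁ε) = sin(x₀) + cos(x₀)x₁ε
Base.sin(d::Dual) = Dual(sin(d.value), cos(d.value) * d.partial)

Base.sin(Dual(0.0, 1.0))  # Dual(0.0, 1.0): sin(0) = 0 and sin′(0) = 1
```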
Similarly, in higher-order AD, we need for each primitive a method overload for a truncated Taylor polynomial up to order $p$; in this polynomial we will use $t$ instead of $\varepsilon$ to denote the sensitivity. "Truncated" means $t^{p+1}=0$, similar to what we defined for dual numbers. So
$$
f(x_0+x_1t+x_2t^2+\cdots+x_pt^p)=?
$$
What expression should go in place of the question mark? That expression is called the "pushforward rule", and we discuss how to derive pushforward rules below.

## Arithmetic of polynomials

Before deriving pushforward rules, let's first introduce several basic properties of polynomials.

If $x(t)$ and $y(t)$ are both truncated Taylor polynomials, i.e.
$$
\begin{aligned}
x&=x_0+x_1t+\cdots+x_pt^p\\
y&=y_0+y_1t+\cdots+y_pt^p
\end{aligned}
$$
Then it's obvious that polynomial addition and subtraction are given by
$$
(x\pm y)_k=x_k\pm y_k
$$
And with some derivation we can also get the polynomial multiplication rule
$$
(x\times y)_k=\sum_{i=0}^kx_iy_{k-i}
$$
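
In code, the multiplication rule is a direct double loop. Here is a sketch on hypothetical plain coefficient vectors (the package itself stores coefficients in tuples inside `TaylorScalar`; math index $i$ lives at Julia index $i + 1$):

```julia
# (x × y)ₖ = Σᵢ xᵢ y₍ₖ₋ᵢ₎, truncated at the input order
function poly_mul(x::Vector{T}, y::Vector{T}) where {T}
    p = length(x) - 1
    z = zeros(T, p + 1)
    for k in 0:p, i in 0:k
        z[k + 1] += x[i + 1] * y[k - i + 1]
    end
    return z
end
```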
The polynomial division rule is less obvious, but if $x/y=z$, then equivalently $x=yz$, i.e.
$$
\left(\sum_{i=0}^py_it^i\right)\left(\sum_{i=0}^pz_it^i\right)=\sum_{i=0}^px_it^i
$$
If we relate the coefficients of $t^k$ on both sides, we get
$$
\sum_{i=0}^k z_iy_{k-i}=x_k
$$
so, equivalently,
$$
z_k=\frac1{y_0}\left(x_k-\sum_{i=0}^{k-1}z_iy_{k-i}\right)
$$
This is a recurrence relation: we first get $z_0=x_0/y_0$, then get $z_1$ using $z_0$, then get $z_2$ using $z_0,z_1$, and so on.
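
A sketch of this recurrence, using the same hypothetical coefficient-vector convention as above:

```julia
# zₖ = (xₖ - Σ_{i=0}^{k-1} zᵢ y₍ₖ₋ᵢ₎) / y₀
function poly_div(x::Vector{T}, y::Vector{T}) where {T}
    p = length(x) - 1
    z = similar(x)
    z[1] = x[1] / y[1]  # z₀ = x₀ / y₀
    for k in 1:p
        acc = x[k + 1]
        for i in 0:(k - 1)
            acc -= z[i + 1] * y[k - i + 1]
        end
        z[k + 1] = acc / y[1]
    end
    return z
end
```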

## Pushforward rule for elementary functions

Let's now consider how to derive the pushforward rule for elementary functions. We will use $\exp$ and $\log$ as two examples.

If $x(t)$ is a polynomial and we want to get $e(t)=\exp(x(t))$, we can obtain it by formulating an ordinary differential equation:
$$
e'(t)=\exp(x(t))x'(t);\quad e_0=\exp(x_0)
$$
If we expand both $e$ and $x$ in the equation, we will get
$$
\sum_{i=1}^pie_it^{i-1}=\left(\sum_{i=0}^{p-1} e_it^i\right)\left(\sum_{i=1}^pix_it^{i-1}\right)
$$
Relating the coefficients of $t^{k-1}$ on both sides, we get
$$
ke_k=\sum_{i=0}^{k-1}e_i\times (k-i)x_{k-i}
$$
This is, again, a recurrence relation, so we can obtain $e_1,\cdots,e_p$ step by step.
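
This recurrence is essentially what the hand-written `exp` overload in `src/primitive.jl` (shown in the diff below) implements; a sketch on plain coefficient vectors:

```julia
# k eₖ = Σ_{i=0}^{k-1} eᵢ (k - i) x₍ₖ₋ᵢ₎, with e₀ = exp(x₀)
function poly_exp(x::Vector{T}) where {T}
    p = length(x) - 1
    e = similar(x)
    e[1] = exp(x[1])
    for k in 1:p
        acc = zero(T)
        for i in 0:(k - 1)
            acc += e[i + 1] * (k - i) * x[k - i + 1]
        end
        e[k + 1] = acc / k
    end
    return e
end
```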

If $x(t)$ is a polynomial and we want to get $l(t)=\log(x(t))$, we can obtain it by formulating an ordinary differential equation:
$$
l'(t)=\frac1{x(t)}x'(t);\quad l_0=\log(x_0)
$$
If we expand both $l$ and $x$ in the equation, the RHS is simply a polynomial division, and we get
$$
l_k=\frac1{x_0}\left(x_k-\frac1k\sum_{i=1}^{k-1}il_ix_{k-i}\right)
$$
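
A sketch of this rule in the same style (note that the $i=0$ term of the sum vanishes because of the factor $i$):

```julia
# lₖ = (xₖ - (1/k) Σ_{i=1}^{k-1} i lᵢ x₍ₖ₋ᵢ₎) / x₀, with l₀ = log(x₀)
function poly_log(x::Vector{T}) where {T}
    p = length(x) - 1
    l = similar(x)
    l[1] = log(x[1])
    for k in 1:p
        acc = zero(T)
        for i in 1:(k - 1)
            acc += i * l[i + 1] * x[k - i + 1]
        end
        l[k + 1] = (x[k + 1] - acc / k) / x[1]
    end
    return l
end
```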

---

Now notice the difference between the rules for $\exp$ and $\log$: the derivative of the exponential is the exponential itself, so the coefficients follow from a recurrence relation; the derivative of the logarithm is $1/x$, an algebraic expression in $x$, so it can be computed directly. Similarly, we have $(\tan x)'=1+\tan^2x$ but $(\arctan x)'=(1+x^2)^{-1}$. We summarize (omitting the proof) that

- The derivative of every $\exp$-like function ($\sin$, $\cos$, $\tan$, $\sinh$, ...) is recursive in the function itself
- The derivative of every $\log$-like function ($\arcsin$, $\arccos$, $\arctan$, $\operatorname{arcsinh}$, ...) is algebraic

So every elementary function has an easy pushforward rule that can be computed in $O(p^2)$ time. Note that this is an elegant and straightforward corollary of the definition of "elementary function" in differential algebra.

## Generic pushforward rule

For a generic $f(x)$, if we don't bother deriving a specific recurrence rule for it, we can still generate a pushforward rule automatically in the following manner. Denote the derivative of $f$ w.r.t. $x$ by $d(x)$; then for $f(t)=f(x(t))$ we have
$$
f'(t)=d(x(t))x'(t);\quad f(0)=f(x_0)
$$
When we expand $f$ and $x$ up to order $p$ in this equation, we notice that only order $p-1$ is needed for $d(x(t))$. In other words, we turn the problem of finding the $p$-th-order pushforward for $f$ into the problem of finding the $(p-1)$-th-order pushforward for $d$, and we can recurse down to first order. The first-order derivative expressions are captured from ChainRules.jl, which makes this process fully automatic.
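
This raising step is what the reworked `raise` helper in `src/primitive.jl` (see the diff below) expresses as `integrate(differentiate(t) * d, f0)`. Spelled out on hypothetical coefficient vectors, one step of the recursion reads:

```julia
# Given f₀ = f(x₀) and the order-(p-1) pushforward d of f′(x(t)),
# recover f(x(t)) up to order p via fₖ = (1/k) Σ_{i=0}^{k-1} dᵢ (k - i) x₍ₖ₋ᵢ₎
function raise_step(f0::T, d::Vector{T}, x::Vector{T}) where {T}
    p = length(x) - 1
    f = similar(x)
    f[1] = f0
    for k in 1:p
        acc = zero(T)
        for i in 0:(k - 1)
            acc += d[i + 1] * (k - i) * x[k - i + 1]
        end
        f[k + 1] = acc / k
    end
    return f
end
```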

This strategy is in principle equivalent to nesting first-order differentiation, which could potentially lead to exponential scaling; however, in practice there is a huge difference. The generation of the pushforward rule happens at **compile time**, which gives the compiler a chance to eliminate redundant expressions and optimize the rule down to quadratic time. The compiler has stack limits, but this should work for orders up to at least 100.

In the current implementation of TaylorDiff.jl, all $\log$-like functions' pushforward rules are generated by this strategy, since their derivatives are simple algebraic expressions; some $\exp$-like functions, such as `sinh`, are also generated; the most frequently used $\exp$-like functions are hand-written with hand-derived recurrence relations.

If you find that the code generated by this strategy is slow, please file an issue and we will look into it.
57 changes: 19 additions & 38 deletions src/derivative.jl
@@ -1,14 +1,26 @@
export derivative, derivative!, derivatives

export derivative, derivative!, derivatives, make_seed
# Added to help Zygote infer types
@inline make_seed(x::T, l::T, ::Val{P}) where {T <: Real, P} = TaylorScalar{P}(x, l)
@inline make_seed(x::A, l::A, ::Val{P}) where {A <: AbstractArray, P} = broadcast(
make_seed, x, l, Val{P}())

"""
derivative(f, x, ::Val{P})
derivative(f, x, l, ::Val{P})
derivative(f!, y, x, l, ::Val{P})
Computes `P`-th directional derivative of `f` w.r.t. vector `x` in direction `l`.
Computes `P`-th directional derivative of `f` w.r.t. vector `x` in direction `l`. If `x` is a Number, the direction `l` can be omitted.
"""
function derivative end

@inline derivative(f, x::Number, p::Val{P}) where {P} = extract_derivative(
derivatives(f, x, one(x), p), p)
@inline derivative(f, x, l, p::Val{P}) where {P} = extract_derivative(
derivatives(f, x, l, p), p)
@inline derivative(f!, y, x, l, p::Val{P}) where {P} = extract_derivative(
derivatives(f!, y, x, l, p), p)

"""
derivative!(result, f, x, l, ::Val{P})
derivative!(result, f!, y, x, l, ::Val{P})
@@ -17,6 +29,11 @@ In-place derivative calculation APIs. `result` is expected to be pre-allocated a
"""
function derivative! end

@inline derivative!(result, f, x, l, p::Val{P}) where {P} = extract_derivative!(
result, derivatives(f, x, l, p), p)
@inline derivative!(result, f!, y, x, l, p::Val{P}) where {P} = extract_derivative!(
result, derivatives(f!, y, x, l, p), p)

"""
derivatives(f, x, l, ::Val{P})
derivatives(f!, y, x, l, ::Val{P})
@@ -25,43 +42,7 @@ Computes all derivatives of `f` at `x` up to order `P`.
"""
function derivatives end

# Convenience wrapper for adding unit seed to the input

@inline derivative(f, x, p::Int64) = derivative(f, x, broadcast(one, x), p)

# Convenience wrappers for converting ps to value types
# and forward work to core APIs

@inline derivative(f, x, l, p::Int64) = derivative(f, x, l, Val{p}())
@inline derivative(f!, y, x, l, p::Int64) = derivative(f!, y, x, l, Val{p}())
@inline derivative!(result, f, x, l, p::Int64) = derivative!(
result, f, x, l, Val{p}())
@inline derivative!(result, f!, y, x, l, p::Int64) = derivative!(
result, f!, y, x, l, Val{p}())

# Core APIs

# Added to help Zygote infer types
@inline make_seed(x::T, l::T, ::Val{P}) where {T <: Real, P} = TaylorScalar{P}(x, l)
@inline make_seed(x::A, l::A, ::Val{P}) where {A <: AbstractArray, P} = broadcast(
make_seed, x, l, Val{P}())

# `derivative` API: computes the `P - 1`-th derivative of `f` at `x`
@inline derivative(f, x, l, p::Val{P}) where {P} = extract_derivative(
derivatives(f, x, l, p), p)
@inline derivative(f!, y, x, l, p::Val{P}) where {P} = extract_derivative(
derivatives(f!, y, x, l, p), p)
@inline derivative!(result, f, x, l, p::Val{P}) where {P} = extract_derivative!(
result, derivatives(f, x, l, p), p)
@inline derivative!(result, f!, y, x, l, p::Val{P}) where {P} = extract_derivative!(
result, derivatives(f!, y, x, l, p), p)

# `derivatives` API: computes all derivatives of `f` at `x` up to order `P - 1`

# Out-of-place function
@inline derivatives(f, x, l, p::Val{P}) where {P} = f(make_seed(x, l, p))

# In-place function
@inline function derivatives(f!, y, x, l, p::Val{P}) where {P}
buffer = similar(y, TaylorScalar{eltype(y), P})
f!(buffer, make_seed(x, l, p))
28 changes: 14 additions & 14 deletions src/primitive.jl
@@ -44,7 +44,7 @@ end

## Hand-written exp, sin, cos

@to_static function exp(t::TaylorScalar{T, P}) where {P, T}
@immutable function exp(t::TaylorScalar{T, P}) where {P, T}
f = flatten(t)
v[0] = exp(f[0])
for i in 1:P
@@ -58,7 +58,7 @@ end
end

for func in (:sin, :cos)
@eval @to_static function $func(t::TaylorScalar{T, P}) where {T, P}
@eval @immutable function $func(t::TaylorScalar{T, P}) where {T, P}
f = flatten(t)
s[0], c[0] = sincos(f[0])
for i in 1:P
@@ -104,7 +104,7 @@ end
@inline -(a::TaylorScalar, b::TaylorScalar) = TaylorScalar(
value(a) - value(b), map(-, partials(a), partials(b)))

@to_static function *(a::TaylorScalar{T, P}, b::TaylorScalar{T, P}) where {T, P}
@immutable function *(a::TaylorScalar{T, P}, b::TaylorScalar{T, P}) where {T, P}
va, vb = flatten(a), flatten(b)
for i in 0:P
v[i] = zero(T)
@@ -115,7 +115,7 @@ end
TaylorScalar(v)
end

@to_static function /(a::TaylorScalar{T, P}, b::TaylorScalar{T, P}) where {T, P}
@immutable function /(a::TaylorScalar{T, P}, b::TaylorScalar{T, P}) where {T, P}
va, vb = flatten(a), flatten(b)
v[0] = va[0] / vb[0]
for i in 1:P
@@ -130,13 +130,13 @@

@inline literal_pow(::typeof(^), x::TaylorScalar, ::Val{0}) = one(x)
@inline literal_pow(::typeof(^), x::TaylorScalar, ::Val{1}) = x
@inline literal_pow(::typeof(^), x::TaylorScalar, ::Val{2}) = x*x
@inline literal_pow(::typeof(^), x::TaylorScalar, ::Val{3}) = x*x*x
@inline literal_pow(::typeof(^), x::TaylorScalar, ::Val{2}) = x * x
@inline literal_pow(::typeof(^), x::TaylorScalar, ::Val{3}) = x * x * x
@inline literal_pow(::typeof(^), x::TaylorScalar, ::Val{-1}) = inv(x)
@inline literal_pow(::typeof(^), x::TaylorScalar, ::Val{-2}) = (i=inv(x); i*i)
@inline literal_pow(::typeof(^), x::TaylorScalar, ::Val{-2}) = (i = inv(x); i * i)

for R in (Integer, Real)
@eval @to_static function ^(t::TaylorScalar{T, P}, n::S) where {S <: $R, T, P}
@eval @immutable function ^(t::TaylorScalar{T, P}, n::S) where {S <: $R, T, P}
f = flatten(t)
v[0] = f[0]^n
for i in 1:P
@@ -153,14 +153,14 @@

^(t::TaylorScalar, s::TaylorScalar) = exp(s * log(t))

@inline function lower(t::TaylorScalar{T, P}) where {T, P}
@inline function differentiate(t::TaylorScalar{T, P}) where {T, P}
s = partials(t)
TaylorScalar(ntuple(i -> s[i] * i, Val(P)))
end
@inline function higher(t::TaylorScalar{T, P}) where {T, P}
@inline function integrate(t::TaylorScalar{T, P}, C::T) where {T, P}
s = flatten(t)
ntuple(i -> s[i] / i, Val(P + 1))
TaylorScalar(C, ntuple(i -> s[i] / i, Val(P + 1)))
end
@inline raise(f, df::TaylorScalar, t) = TaylorScalar(f, higher(lower(t) * df))
@inline raise(f, df::Number, t) = df * t
@inline raiseinv(f, df, t) = TaylorScalar(f, higher(lower(t) / df))
@inline raise(f0, d::TaylorScalar, t) = integrate(differentiate(t) * d, f0)
@inline raise(f0, d::Number, t) = d * t
@inline raiseinv(f0, d, t) = integrate(differentiate(t) / d, f0)
1 change: 1 addition & 0 deletions src/scalar.jl
@@ -40,6 +40,7 @@ Convenience function: construct a Taylor polynomial with zeroth and first order
TaylorScalar{P}(value::T, seed::T) where {T, P} = TaylorScalar(
value, ntuple(i -> i == 1 ? seed : zero(T), Val(P)))

# Truncate or extend the order of a Taylor polynomial.
function TaylorScalar{P}(t::TaylorScalar{T, Q}) where {T, P, Q}
v = value(t)
p = partials(t)
18 changes: 16 additions & 2 deletions src/utils.jl
@@ -1,3 +1,6 @@
# This file is a bunch of compiler magics to cleverly define pushforward rules.
# If you are only interested in data structures and pushforward rules, you can skip this file.

using ChainRules
using ChainRulesCore
using Symbolics: @variables, @rule, unwrap, isdiv
@@ -6,7 +9,9 @@ using MacroTools
using MacroTools: prewalk, postwalk

"""
Pick a strategy for raising the derivative of a function. If the derivative is like 1 over something, raise with the division rule; otherwise, raise with the multiplication rule.
Pick a strategy for raising the derivative of a function.
If the derivative is like 1 over something, raise with the division rule;
otherwise, raise with the multiplication rule.
"""
function get_term_raiser(func)
@variables z
@@ -95,7 +100,16 @@ function process(d, expr)
end
end

macro to_static(def)
"""
immutable(def)
Transform a function definition to a @generated function.
1. Allocations are removed by replacing the output with scalar variables;
2. Loops are unrolled;
3. Indices are modified to use 1-based indexing;
"""
macro immutable(def)
dict = splitdef(def)
pairs = Any[]
for symbol in dict[:whereparams]
14 changes: 7 additions & 7 deletions test/derivative.jl
@@ -1,16 +1,16 @@

@testset "O-function, O-derivative" begin
g(x) = x^3
@test derivative(g, 1.0, 1) ≈ 3
@test derivative(g, 1.0, Val(1)) ≈ 3

h(x) = x .^ 3
@test derivative(h, [2.0 3.0], 1) ≈ [12.0 27.0]
@test derivative(h, [2.0 3.0], [1.0 1.0], Val(1)) ≈ [12.0 27.0]

g1(x) = x[1] * x[1] + x[2] * x[2]
@test derivative(g1, [1.0, 2.0], [1.0, 0.0], 1) ≈ 2.0
@test derivative(g1, [1.0, 2.0], [1.0, 0.0], Val(1)) ≈ 2.0

h1(x) = sum(x, dims = 1)
@test derivative(h1, [1.0 2.0; 2.0 3.0], [1.0 1.0; 1.0 1.0], 1) ≈ [2.0 2.0]
@test derivative(h1, [1.0 2.0; 2.0 3.0], [1.0 1.0; 1.0 1.0], Val(1)) ≈ [2.0 2.0]
end

@testset "I-function, O-derivative" begin
@@ -20,12 +20,12 @@ end
end
x = 2.0
y = [0.0, 0.0]
@test derivative(g!, y, x, 1.0, Val{1}()) ≈ [4.0, 1.0]
@test derivative(g!, y, x, 1.0, Val(1)) ≈ [4.0, 1.0]
end

@testset "O-function, I-derivative" begin
g(x) = x .^ 2
@test derivative!(zeros(2), g, [1.0, 2.0], [1.0, 0.0], Val{1}()) ≈ [2.0, 0.0]
@test derivative!(zeros(2), g, [1.0, 2.0], [1.0, 0.0], Val(1)) ≈ [2.0, 0.0]
end

@testset "I-function, I-derivative" begin
Expand All @@ -35,5 +35,5 @@ end
end
x = [2.0, 3.0]
y = [0.0, 0.0]
@test derivative!(y, g!, zeros(2), x, [1.0, 0.0], Val{1}()) ≈ [4.0, 0.0]
@test derivative!(y, g!, zeros(2), x, [1.0, 0.0], Val(1)) ≈ [4.0, 0.0]
end
