Skip to content

Commit

Permalink
Assume unknown functions are non-linear in hessian_sparsity
Browse files Browse the repository at this point in the history
  • Loading branch information
devmotion committed Dec 12, 2024
1 parent ce8b3f6 commit 639daa5
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 38 deletions.
26 changes: 7 additions & 19 deletions src/diff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -646,24 +646,13 @@ let
linearity_rules = [
@rule +(~~xs) => reduce(+, filter(isidx, ~~xs), init=_scalar)
@rule *(~~xs) => reduce(*, filter(isidx, ~~xs), init=_scalar)
@rule (~f)(~x::(!isidx)) => _scalar

@rule (~f)(~x::isidx) => if haslinearity_1(~f)
combine_terms_1(linearity_1(~f), ~x)
else
error("Function of unknown linearity used: ", ~f)
end
@rule (~f)(~x) => isidx(~x) ? combine_terms_1(linearity_1(~f), ~x) : _scalar
@rule (^)(~x::isidx, ~y) => ~y isa Number && isone(~y) ? ~x : (~x) * (~x)
@rule (~f)(~x, ~y) => begin
if haslinearity_2(~f)
a = isidx(~x) ? ~x : _scalar
b = isidx(~y) ? ~y : _scalar
combine_terms_2(linearity_2(~f), a, b)
else
error("Function of unknown linearity used: ", ~f)
end
end
@rule ~x::issym => 0]
@rule (~f)(~x, ~y) => combine_terms_2(linearity_2(~f), isidx(~x) ? ~x : _scalar, isidx(~y) ? ~y : _scalar)

@rule ~x::issym => 0
]
linearity_propagator = Fixpoint(Postwalk(Chain(linearity_rules); maketerm=basic_mkterm))

global hessian_sparsity
Expand Down Expand Up @@ -696,9 +685,8 @@ let
@assert !(expr isa AbstractArray)
expr = value(expr)
u = map(value, vars)
idx(i) = TermCombination(Set([Dict(i=>1)]))
dict = Dict(u .=> idx.(1:length(u)))
f = Rewriters.Prewalk(x->haskey(dict, x) ? dict[x] : x; maketerm=basic_mkterm)(expr)
dict = Dict(ui => TermCombination(Set([Dict(i=>1)])) for (i, ui) in enumerate(u))
f = Rewriters.Prewalk(x-> get(dict, x, x); maketerm=basic_mkterm)(expr)
lp = linearity_propagator(f)
S = _sparse(lp, length(u))
S = full ? S : tril(S)
Expand Down
22 changes: 3 additions & 19 deletions src/linearity.jl
Original file line number Diff line number Diff line change
@@ -1,61 +1,45 @@
using SpecialFunctions
import Base.Broadcast


const linearity_known_1 = IdDict{Function,Bool}()
const linearity_known_2 = IdDict{Function,Bool}()

const linearity_map_1 = IdDict{Function, Bool}()
const linearity_map_2 = IdDict{Function, Tuple{Bool, Bool, Bool}}()

# 1-arg

const monadic_linear = [deg2rad, +, rad2deg, transpose, -, conj]

const monadic_nonlinear = [asind, log1p, acsch, erfc, digamma, acos, asec, acosh, airybiprime, acsc, cscd, log, tand, log10, csch, asinh, airyai, abs2, gamma, lgamma, erfcx, bessely0, cosh, sin, cos, atan, cospi, cbrt, acosd, bessely1, acoth, erfcinv, erf, dawson, inv, acotd, airyaiprime, erfinv, trigamma, asecd, besselj1, exp, acot, sqrt, sind, sinpi, asech, log2, tan, invdigamma, airybi, exp10, sech, erfi, coth, asin, cotd, cosd, sinh, abs, besselj0, csc, tanh, secd, atand, sec, acscd, cot, exp2, expm1, atanh, slog, ssqrt, scbrt]

# We store 3 bools even for 1-arg functions for type stability
const three_trues = (true, true, true)
for f in monadic_linear
linearity_known_1[f] = true
linearity_map_1[f] = true
end

for f in monadic_nonlinear
linearity_known_1[f] = true
linearity_map_1[f] = false
end

# 2-arg
for f in [+, rem2pi, -, >, isless, <, isequal, max, min, convert, <=, >=]
linearity_known_2[f] = true
linearity_map_2[f] = (true, true, true)
end

for f in [*]
linearity_known_2[f] = true
linearity_map_2[f] = (true, true, false)
end

for f in [/]
linearity_known_2[f] = true
linearity_map_2[f] = (true, false, false)
end
for f in [\]
linearity_known_2[f] = true
linearity_map_2[f] = (false, true, false)
end

for f in [hypot, atan, mod, rem, lbeta, ^, beta]
linearity_known_2[f] = true
linearity_map_2[f] = (false, false, false)
end

haslinearity_1(@nospecialize(f)) = get(linearity_known_1, f, false)
haslinearity_2(@nospecialize(f)) = get(linearity_known_2, f, false)

linearity_1(@nospecialize(f)) = linearity_map_1[f]
linearity_2(@nospecialize(f)) = linearity_map_2[f]
# Fallback assumption: Function is not linear, i.e., derivatives are non-zero
linearity_1(@nospecialize(f)) = get(linearity_map_1, f, false)
linearity_2(@nospecialize(f)) = get(linearity_map_2, f, (false, false, false))

# TermCombination datastructure

Expand Down
58 changes: 58 additions & 0 deletions test/diff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -407,3 +407,61 @@ let
@test isequal(expand_derivatives(D(Symbolics.scbrt(1 + x ^ 2))), simplify((2x) / (3Symbolics.scbrt(1 + x^2)^2)))
@test isequal(expand_derivatives(D(Symbolics.slog(1 + x ^ 2))), simplify((2x) / (1 + x ^ 2)))
end

# Hessian sparsity involving unknown functions
let
@variables x₁ x₂ p q[1:1]
expr = 3x₁^2 + 4x₁ * x₂
@test Matrix(Symbolics.hessian_sparsity(expr, [x₁, x₂])) == [true true; true false]

expr = 3x₁^2 + 4x₁ * x₂ + p
@test Matrix(Symbolics.hessian_sparsity(expr, [x₁, x₂])) == [true true; true false]

# issue 643: example test2_num
expr = 3x₁^2 + 4x₁ * x₂ + q[1]
@test Matrix(Symbolics.hessian_sparsity(expr, [x₁, x₂])) == [true true; true false]

# Custom function: By default assumed to be non-linear
myexp(x) = exp(x)
@register_symbolic myexp(x)
expr = 3x₁^2 + 4x₁ * x₂ + myexp(p)
@test Matrix(Symbolics.hessian_sparsity(expr, [x₁, x₂])) == [true true; true false]
expr = 3x₁^2 + 4x₁ * x₂ + myexp(x₂)
@test Matrix(Symbolics.hessian_sparsity(expr, [x₁, x₂])) == [true true; true true]

mylogaddexp(x, y) = log(exp(x) + exp(y))
@register_symbolic mylogaddexp(x, y)
expr = 3x₁^2 + 4x₁ * x₂ + mylogaddexp(p, 2)
@test Matrix(Symbolics.hessian_sparsity(expr, [x₁, x₂])) == [true true; true false]
expr = 3x₁^2 + 4x₁ * x₂ + mylogaddexp(3, p)
@test Matrix(Symbolics.hessian_sparsity(expr, [x₁, x₂])) == [true true; true false]
expr = 3x₁^2 + 4x₁ * x₂ + mylogaddexp(p, 2)
@test Matrix(Symbolics.hessian_sparsity(expr, [x₁, x₂])) == [true true; true false]
expr = 3x₁^2 + 4x₁ * x₂ + mylogaddexp(p, q[1])
@test Matrix(Symbolics.hessian_sparsity(expr, [x₁, x₂])) == [true true; true false]
expr = 3x₁^2 + 4x₁ * x₂ + mylogaddexp(p, x₂)
@test Matrix(Symbolics.hessian_sparsity(expr, [x₁, x₂])) == [true true; true true]
expr = 3x₁^2 + 4x₁ * x₂ + mylogaddexp(x₂, 4)
@test Matrix(Symbolics.hessian_sparsity(expr, [x₁, x₂])) == [true true; true true]

# Custom linear function: Possible to extend `Symbolics.linearity_1`/`Symbolics.linearity_2`
myidentity(x) = x
@register_symbolic myidentity(x)
Symbolics.linearity_1(::typeof(myidentity)) = true
expr = 3x₁^2 + 4x₁ * x₂ + myidentity(p)
@test Matrix(Symbolics.hessian_sparsity(expr, [x₁, x₂])) == [true true; true false]
expr = 3x₁^2 + 4x₁ * x₂ + myidentity(q[1])
@test Matrix(Symbolics.hessian_sparsity(expr, [x₁, x₂])) == [true true; true false]
expr = 3x₁^2 + 4x₁ * x₂ + myidentity(x₂)
@test Matrix(Symbolics.hessian_sparsity(expr, [x₁, x₂])) == [true true; true false]

mymul1plog(x, y) = x * (1 + log(y))
@register_symbolic mymul1plog(x, y)
Symbolics.linearity_2(::typeof(mymul1plog)) = (true, false, false)
expr = 3x₁^2 + 4x₁ * x₂ + mymul1plog(p, q[1])
@test Matrix(Symbolics.hessian_sparsity(expr, [x₁, x₂])) == [true true; true false]
expr = 3x₁^2 + 4x₁ * x₂ + mymul1plog(x₂, q[1])
@test Matrix(Symbolics.hessian_sparsity(expr, [x₁, x₂])) == [true true; true false]
expr = 3x₁^2 + 4x₁ * x₂ + mymul1plog(q[1], x₂)
@test Matrix(Symbolics.hessian_sparsity(expr, [x₁, x₂])) == [true true; true true]
end

0 comments on commit 639daa5

Please sign in to comment.