Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AD fix for PDBijector #280

Merged
merged 23 commits into from
Aug 12, 2023
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
3967e39
added cholesky_lower and cholesky_triangular
torfjelde Aug 4, 2023
394debc
updated PD to use new cholesky_lower and cholesky_upper
torfjelde Aug 4, 2023
64d87bf
simplified imports in BijectorsReverseDiffExtx
torfjelde Aug 4, 2023
d175513
added ChainRules as a dep since we need the chain rules for cholesky,…
torfjelde Aug 4, 2023
94f6a0e
forgot to update Project.toml in previous commit
torfjelde Aug 4, 2023
83fee94
added explicit implementation of with_logabsdet_jacobian for PDBijector
torfjelde Aug 4, 2023
4b390c6
Update src/utils.jl
torfjelde Aug 4, 2023
1185cad
added ProjectTo in rrules for cholesky_lower and cholesky_upper to be…
torfjelde Aug 4, 2023
6be9534
added ProjectTo for cholesky_upper too
torfjelde Aug 6, 2023
7675ea2
added transpose_eager as a alias for permutedims to allow definition
torfjelde Aug 7, 2023
15c47eb
allow usage of ForwardDiff gradient as ground-truth
torfjelde Aug 7, 2023
9322fda
added AD tests for PDVecBijector
torfjelde Aug 7, 2023
0bf8487
added AD tests for PDVecBijector to runtests and commented out all
torfjelde Aug 7, 2023
5d0cd2d
forgot to remove type-piracy def of ReverseDiff rule for permutedims
torfjelde Aug 7, 2023
29790dc
use ReverseDiff.@grad instead of ReverseDiff.@grad_from_chainrules
torfjelde Aug 7, 2023
3241936
only define cholesky_lower and cholesky_upper rules for ReverseDiff, …
torfjelde Aug 7, 2023
951028e
formatting
torfjelde Aug 7, 2023
4fe6085
parameterise gradient test for PD bijector properly instead of using
torfjelde Aug 7, 2023
1102266
reversed chagne to test_ad
torfjelde Aug 7, 2023
4e66a8d
reactivate tests
torfjelde Aug 7, 2023
4f1ecc8
updated doocstrings
torfjelde Aug 7, 2023
e87a2aa
improved PDVecBijector AD tests a bit
torfjelde Aug 7, 2023
52ee210
AD fix for CorrBijector (#281)
torfjelde Aug 12, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ version = "0.13.3"

[deps]
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Expand All @@ -22,20 +23,20 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[weakdeps]
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"

[extensions]
BijectorsDistributionsADExt = "DistributionsAD"
BijectorsForwardDiffExt = "ForwardDiff"
BijectorsLazyArraysExt = "LazyArrays"
BijectorsReverseDiffExt = "ReverseDiff"
BijectorsTrackerExt = "Tracker"
BijectorsZygoteExt = "Zygote"
BijectorsLazyArraysExt = "LazyArrays"
BijectorsDistributionsADExt = "DistributionsAD"

[compat]
ArgCheck = "1, 2"
Expand Down
58 changes: 55 additions & 3 deletions ext/BijectorsReverseDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ if isdefined(Base, :get_extension)
simplex_logabsdetjac_gradient,
Inverse
import Bijectors:
Bijectors,
_eps,
logabsdetjac,
_logabsdetjac_scale,
Expand All @@ -35,7 +36,8 @@ if isdefined(Base, :get_extension)
find_alpha,
pd_from_lower,
lower_triangular,
upper_triangular
upper_triangular,
transpose_eager

using Bijectors.LinearAlgebra
using Bijectors.Compat: eachcol
Expand All @@ -61,6 +63,7 @@ else
simplex_logabsdetjac_gradient,
Inverse
import ..Bijectors:
Bijectors,
_eps,
logabsdetjac,
_logabsdetjac_scale,
Expand All @@ -75,7 +78,8 @@ else
find_alpha,
pd_from_lower,
lower_triangular,
upper_triangular
upper_triangular,
transpose_eager

using ..Bijectors.LinearAlgebra
using ..Bijectors.Compat: eachcol
Expand Down Expand Up @@ -253,11 +257,59 @@ end
@grad_from_chainrules _transform_ordered(y::Union{TrackedVector,TrackedMatrix})
@grad_from_chainrules _transform_inverse_ordered(x::Union{TrackedVector,TrackedMatrix})

@grad_from_chainrules update_triu_from_vec(vals::TrackedVector{<:Real}, k::Int, dim::Int)
@grad_from_chainrules Bijectors.update_triu_from_vec(
vals::TrackedVector{<:Real}, k::Int, dim::Int
)

@grad_from_chainrules _link_chol_lkj(x::TrackedMatrix)
@grad_from_chainrules _inv_link_chol_lkj(x::TrackedVector)


torfjelde marked this conversation as resolved.
Show resolved Hide resolved
cholesky_lower(X::TrackedMatrix) = track(cholesky_lower, X)
@grad function cholesky_lower(X_tracked::TrackedMatrix)
X = value(X_tracked)
H, hermitian_pullback = ChainRulesCore.rrule(Hermitian, X, :L)
C, cholesky_pullback = ChainRulesCore.rrule(cholesky, H, Val(false))
function cholesky_lower_pullback(ΔL)
ΔC = ChainRulesCore.Tangent{typeof(C)}(; factors=(C.uplo === :L ? ΔL : ΔL'))
ΔH = cholesky_pullback(ΔC)[2]
Δx = hermitian_pullback(ΔH)[2]
# No need to add pullback for `lower_triangular`, because the pullback
# for `Hermitian` already produces the correct result (i.e. the lower-triangular
# part zeroed out).
return (Δx,)
end

return lower_triangular(parent(C.L)), cholesky_lower_pullback
end

cholesky_upper(X::TrackedMatrix) = track(cholesky_upper, X)
@grad function cholesky_upper(X_tracked::TrackedMatrix)
X = value(X_tracked)
H, hermitian_pullback = ChainRulesCore.rrule(Hermitian, X, :U)
C, cholesky_pullback = ChainRulesCore.rrule(cholesky, H, Val(false))
function cholesky_upper_pullback(ΔU)
ΔC = ChainRulesCore.Tangent{typeof(C)}(; factors=(C.uplo === :U ? ΔU : ΔU'))
ΔH = cholesky_pullback(ΔC)[2]
Δx = hermitian_pullback(ΔH)[2]
# No need to add pullback for `upper_triangular`, because the pullback
# for `Hermitian` already produces the correct result (i.e. the upper-triangular
# part zeroed out).
return (Δx,)
end

return upper_triangular(parent(C.U)), cholesky_upper_pullback
end

transpose_eager(X::TrackedMatrix) = track(transpose_eager, X)
@grad function transpose_eager(X_tracked::TrackedMatrix)
X = value(X_tracked)
y, y_pullback = ChainRulesCore.rrule(permutedims, X, (2, 1))
transpose_eager_pullback(Δ) = (y_pullback(Δ)[2],)
return y, transpose_eager_pullback
end


torfjelde marked this conversation as resolved.
Show resolved Hide resolved
if VERSION <= v"1.8.0-DEV.1526"
# HACK: This dispatch does not wrap X in Hermitian before calling cholesky.
# cholesky does not work with AbstractMatrix in julia versions before the compared one,
Expand Down
1 change: 1 addition & 0 deletions src/Bijectors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ import ChangesOfVariables: ChangesOfVariables, with_logabsdet_jacobian
import InverseFunctions: inverse

using ChainRulesCore: ChainRulesCore
using ChainRules: ChainRules
using Functors: Functors
using IrrationalConstants: IrrationalConstants
using LogExpFunctions: LogExpFunctions
Expand Down
31 changes: 11 additions & 20 deletions src/bijectors/pd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,46 +8,37 @@ function replace_diag(f, X)
return g.(1:size(X, 1), (1:size(X, 2))')
end
transform(b::PDBijector, X::AbstractMatrix{<:Real}) = pd_link(X)
function pd_link(X)
Y = lower_triangular(parent(cholesky(X; check=true).L))
return replace_diag(log, Y)
end
pd_link(X) = replace_diag(log, cholesky_lower(X))

function transform(ib::Inverse{PDBijector}, Y::AbstractMatrix{<:Real})
X = replace_diag(exp, Y)
return pd_from_lower(X)
end

function logabsdetjac(b::PDBijector, X::AbstractMatrix{<:Real})
T = eltype(X)
Xcf = cholesky(X; check=false)
if !issuccess(Xcf)
Xcf = cholesky(X + max(eps(T), eps(T) * norm(X)) * I)
end
return logabsdetjac_pdbijector_chol(Xcf)
L = cholesky_lower(X)
return logabsdetjac_pdbijector_chol(L)
end

function logabsdetjac_pdbijector_chol(Xcf::Cholesky)
# NOTE: Use `UpperTriangular` here because we only need `diag(U)`
# and `UL` is by default already constructed in `Cholesky`.
UL = Xcf.UL
d = size(UL, 1)
z = sum(((d + 1):(-1):2) .* log.(diag(UL)))
function logabsdetjac_pdbijector_chol(X::AbstractMatrix)
d = size(X, 1)
z = sum(((d + 1):(-1):2) .* log.(diag(X)))
return -(z + d * oftype(z, IrrationalConstants.logtwo))
end

# TODO: Implement explicitly.
function with_logabsdet_jacobian(b::PDBijector, X)
return transform(b, X), logabsdetjac(b, X)
L = cholesky_lower(X)
return replace_diag(log, L), logabsdetjac_pdbijector_chol(L)
end

struct PDVecBijector <: Bijector end

transform(::PDVecBijector, X::AbstractMatrix{<:Real}) = pd_vec_link(X)
pd_vec_link(X) = triu_to_vec(transpose(pd_link(X)))
# TODO: Implement `tril_to_vec` and remove `permutedims`.
pd_vec_link(X) = triu_to_vec(transpose_eager(pd_link(X)))

function transform(::Inverse{PDVecBijector}, y::AbstractVector{<:Real})
Y = permutedims(vec_to_triu(y))
Y = transpose_eager(vec_to_triu(y))
return transform(inverse(PDBijector()), Y)
end

Expand Down
28 changes: 28 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,34 @@ cholesky_factor(X::Cholesky) = X.U
cholesky_factor(X::UpperTriangular) = X
cholesky_factor(X::LowerTriangular) = X

# HACK: Allows us to define custom chain rules while we wait for upstream fixes.
transpose_eager(X::AbstractMatrix) = permutedims(X)

# TODO: Add `check` as an argument?
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is the last remaining question @devmotion . I'm thinking "let's not, until we start using it"?

"""
cholesky_lower(X)

Return the lower triangular Cholesky factor of `X` as a `Matrix`
rather than `LowerTriangular`.

!!! note
This is a thin wrapper around `cholesky(Hermitian(X)).L`
but with a custom `ChainRulesCore.rrule` implementation.
"""
cholesky_lower(X::AbstractMatrix) = lower_triangular(parent(cholesky(Hermitian(X)).L))
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wrap in Hermitian to effectively do the same as the current implementation of cholesky_factor but I believe cholesky(::Hermitian) is only valid starting from Julia 1.8 (going by a comment in BijectorsReverseDiffExt), so we need to fix this.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is actually not a problem anymore since we're now defining the adjoint to circumvent the cholesky on tracked completely.


"""
cholesky_upper(X)

Return the upper triangular Cholesky factor of `X` as a `Matrix`
rather than `UpperTriangular`.

!!! note
This is a thin wrapper around `cholesky(Hermitian(X)).U`
but with a custom `ChainRulesCore.rrule` implementation.
"""
cholesky_upper(X::AbstractMatrix) = upper_triangular(parent(cholesky(Hermitian(X)).U))

"""
triu_mask(X::AbstractMatrix, k::Int)

Expand Down
16 changes: 16 additions & 0 deletions test/ad/pd.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
@testset "AD for PD bijector" begin
d = 4
dist = Wishart(4, Matrix{Float64}(Distributions.I, d, d))
x = rand(dist)
b = bijector(dist)
binv = inverse(b)
y = b(x)

test_ad(vec(x); use_forwarddiff_as_truth=true) do x
sum(transform(b, reshape(x, d, d)))
end

test_ad(y; use_forwarddiff_as_truth=true) do y
sum(transform(binv, y))
end
end
24 changes: 14 additions & 10 deletions test/ad/utils.jl
Original file line number Diff line number Diff line change
@@ -1,32 +1,36 @@
# Figure out which AD backend to test
const AD = get(ENV, "AD", "All")

function test_ad(f, x, broken=(); rtol=1e-6, atol=1e-6)
finitediff = FiniteDifferences.grad(central_fdm(5, 1), f, x)[1]
function test_ad(f, x, broken=(); rtol=1e-6, atol=1e-6, use_forwarddiff_as_truth=false)
truth = if use_forwarddiff_as_truth
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
truth = ForwardDiff.gradient(f, x)
else
FiniteDifferences.grad(central_fdm(5, 1), f, x)[1]
end

if AD == "All" || AD == "ForwardDiff"
if !use_forwarddiff_as_truth && (AD == "All" || AD == "ForwardDiff")
if :ForwardDiff in broken
@test_broken ForwardDiff.gradient(f, x) ≈ finitediff rtol = rtol atol = atol
@test_broken ForwardDiff.gradient(f, x) ≈ truth rtol = rtol atol = atol
else
@test ForwardDiff.gradient(f, x) ≈ finitediff rtol = rtol atol = atol
@test ForwardDiff.gradient(f, x) ≈ truth rtol = rtol atol = atol
end
end

if AD == "All" || AD == "Zygote"
if :Zygote in broken
@test_broken Zygote.gradient(f, x)[1] ≈ finitediff rtol = rtol atol = atol
@test_broken Zygote.gradient(f, x)[1] ≈ truth rtol = rtol atol = atol
else
∇zygote = Zygote.gradient(f, x)[1]
@test (all(finitediff .== 0) && ∇zygote === nothing) ||
isapprox(∇zygote, finitediff; rtol=rtol, atol=atol)
@test (all(truth .== 0) && ∇zygote === nothing) ||
isapprox(∇zygote, truth; rtol=rtol, atol=atol)
end
end

if AD == "All" || AD == "ReverseDiff"
if :ReverseDiff in broken
@test_broken ReverseDiff.gradient(f, x) ≈ finitediff rtol = rtol atol = atol
@test_broken ReverseDiff.gradient(f, x) ≈ truth rtol = rtol atol = atol
else
@test ReverseDiff.gradient(f, x) ≈ finitediff rtol = rtol atol = atol
@test ReverseDiff.gradient(f, x) ≈ truth rtol = rtol atol = atol
end
end

Expand Down
33 changes: 17 additions & 16 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,22 +35,23 @@ const GROUP = get(ENV, "GROUP", "All")
include("ad/utils.jl")
include("bijectors/utils.jl")

if GROUP == "All" || GROUP == "Interface"
include("interface.jl")
include("transform.jl")
include("norm_flows.jl")
include("bijectors/permute.jl")
include("bijectors/rational_quadratic_spline.jl")
include("bijectors/named_bijector.jl")
include("bijectors/leaky_relu.jl")
include("bijectors/coupling.jl")
include("bijectors/ordered.jl")
include("bijectors/pd.jl")
include("bijectors/reshape.jl")
include("bijectors/corr.jl")
end
# if GROUP == "All" || GROUP == "Interface"
# include("interface.jl")
# include("transform.jl")
# include("norm_flows.jl")
# include("bijectors/permute.jl")
# include("bijectors/rational_quadratic_spline.jl")
# include("bijectors/named_bijector.jl")
# include("bijectors/leaky_relu.jl")
# include("bijectors/coupling.jl")
# include("bijectors/ordered.jl")
# include("bijectors/pd.jl")
# include("bijectors/reshape.jl")
# include("bijectors/corr.jl")
# end

if GROUP == "All" || GROUP == "AD"
include("ad/chainrules.jl")
include("ad/flows.jl")
# include("ad/chainrules.jl")
# include("ad/flows.jl")
include("ad/pd.jl")
end