AD fix for PDBijector #280
Conversation
```
@@ -19,6 +19,65 @@ cholesky_factor(X::Cholesky) = X.U
cholesky_factor(X::UpperTriangular) = X
cholesky_factor(X::LowerTriangular) = X

# TODO: Add `check` as an argument?
```
Should we?
I think this is the last remaining question, @devmotion. I'm thinking "let's not, until we start using it"?
src/utils.jl (outdated)
```
This is a thin wrapper around `cholesky(Hermitian(X)).L`
but with a custom `ChainRulesCore.rrule` implementation.
"""
cholesky_lower(X::AbstractMatrix) = lower_triangular(parent(cholesky(Hermitian(X)).L))
```
I wrap in `Hermitian` to effectively do the same as the current implementation of `cholesky_factor`, but I believe `cholesky(::Hermitian)` is only valid starting from Julia 1.8 (going by a comment in BijectorsReverseDiffExt), so we need to fix this.
This is actually not a problem anymore, since we're now defining the adjoint to circumvent the `cholesky` call on tracked arrays completely.
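A minimal sketch of the wrapper idea above; `lower_triangular_demo` and `cholesky_lower_demo` are hypothetical stand-ins for illustration, not the actual Bijectors implementations:

```julia
using LinearAlgebra

# Stand-in for a `lower_triangular`-style helper: densify while zeroing
# everything above the diagonal.
lower_triangular_demo(A::AbstractMatrix) = convert(typeof(A), LowerTriangular(A))

# Stand-in for `cholesky_lower`: always hand back a plain `Matrix`, which is
# easier for ReverseDiff/Tracker to handle than a `LowerTriangular` wrapper
# around the `Cholesky` factors.
cholesky_lower_demo(X::AbstractMatrix) = lower_triangular_demo(parent(cholesky(Hermitian(X)).L))

X = [4.0 2.0; 2.0 3.0]
cholesky(Hermitian(X)).L  # LowerTriangular{Float64, Matrix{Float64}}
cholesky_lower_demo(X)    # 2×2 Matrix{Float64} with zeros above the diagonal
```

Returning a raw `Matrix` means a single custom rule can cover the whole computation, instead of relying on each backend understanding `Cholesky`/`LowerTriangular` internals.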
Is there a particular reason for defining a ChainRules rule and applying it to ReverseDiff, even though the only (?) broken backends are ReverseDiff and Tracker? It would seem a bit more natural to define a rule for ReverseDiff directly (ignoring Tracker, as discussed in Turing).
Nah. I had the same thought, but a) I'm fairly familiar with defining rrules using ChainRules, and less so with ReverseDiff, and b) it's easy to make use of the `@grad_from_chainrules` macro to reuse the ChainRules rule for ReverseDiff.
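For context, a hedged sketch of that pattern with a toy function; `sum_of_squares` and its rule are made up for illustration, but the `rrule` → `@grad_from_chainrules` wiring mirrors what the extension does for `cholesky_lower`/`cholesky_upper`:

```julia
using ChainRulesCore
using ReverseDiff
using ReverseDiff: TrackedMatrix

# Toy stand-in for a function whose primal we don't want ReverseDiff to trace through.
sum_of_squares(X::AbstractMatrix) = sum(abs2, X)

# 1. Define the rule once, against the ChainRules interface.
function ChainRulesCore.rrule(::typeof(sum_of_squares), X::AbstractMatrix)
    y = sum_of_squares(X)
    sum_of_squares_pullback(Δ) = (NoTangent(), 2 .* X .* Δ)
    return y, sum_of_squares_pullback
end

# 2. Reuse that rule for ReverseDiff's tracked types.
ReverseDiff.@grad_from_chainrules sum_of_squares(X::TrackedMatrix)

# ReverseDiff now calls the ChainRules pullback instead of differentiating the primal.
ReverseDiff.gradient(X -> sum_of_squares(X), [1.0 2.0; 3.0 4.0])  # == 2 .* X
```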
ext/BijectorsReverseDiffExt.jl (outdated)
```
@grad_from_chainrules Bijectors.cholesky_lower(X::TrackedMatrix)
@grad_from_chainrules Bijectors.cholesky_upper(X::TrackedMatrix)

# TODO: Type-piracy; probably shouldn't do this.
```
No, this should really not be defined in Bijectors. I'm sure this will lead to surprising debugging and issues when Bijectors is (not) loaded.
One can just define the pullback in a function and reuse it without defining an `rrule`.
But isn't it fair to assume that an …

But I'm happy to not use a …

EDIT: I have the change to not using … With that being said, I will still make the change, because I ran into a super-weird error with ReverseDiff (I'll raise an issue in a sec) 🤦
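A rough sketch of the non-piratical pattern suggested above, with made-up names; `dense_upper`/`dense_lower` are toy functions "owned" here, and no rule is attached to anything from Base or LinearAlgebra:

```julia
using LinearAlgebra
using ChainRulesCore

# The shared pullback logic lives in plain functions...
_upper_cotangent(ΔY) = UpperTriangular(unthunk(ΔY))
_lower_cotangent(ΔY) = LowerTriangular(unthunk(ΔY))

# ...and is reused from rrules on functions we own, so nothing is pirated.
dense_upper(X::AbstractMatrix) = Matrix(UpperTriangular(X))
dense_lower(X::AbstractMatrix) = Matrix(LowerTriangular(X))

function ChainRulesCore.rrule(::typeof(dense_upper), X::AbstractMatrix)
    dense_upper_pullback(ΔY) = (NoTangent(), _upper_cotangent(ΔY))
    return dense_upper(X), dense_upper_pullback
end

function ChainRulesCore.rrule(::typeof(dense_lower), X::AbstractMatrix)
    dense_lower_pullback(ΔY) = (NoTangent(), _lower_cotangent(ΔY))
    return dense_lower(X), dense_lower_pullback
end
```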
* … of AD rules without type piracy
* … other tests for the sake of reproducing ReverseDiff bug
* … remove ChainRules rule defs
Regarding JuliaDiff/ReverseDiff.jl#236: the macro has always had issues and limitations (one bug was just fixed recently), so in my experience it's not really the case (and possibly too strong an expectation) that it brings CR-compatibility to ReverseDiff.
* … ForwardDiff as per suggestion of @devmotion
Yeaaah, but it's so darn convenient 😞

EDIT: Naaah, I'm stupid. The …
So this is quite confusing. I'm trying to transition …

```julia
julia> d = 4
4
julia> dist = LKJ(d, 2.0)
LKJ{Float64, Int64}(
d: 4
η: 2.0
)
julia> b = bijector(dist)
Bijectors.VecCorrBijector()
julia> x = rand(dist)
4×4 Matrix{Float64}:
1.0 -0.19264 -0.63806 0.0930006
-0.19264 1.0 0.259633 -0.168056
-0.63806 0.259633 1.0 0.170947
0.0930006 -0.168056 0.170947 1.0
julia> # (✓) Works!
Zygote.gradient(x) do x
sum(cholesky(Hermitian(x)).U)
end
([1.141585190493663 1.7498257982625538 1.7225797690347182 1.6454362799052624; 0.0 0.5205946318942621 1.0163692667364 1.0672427905346635; 0.0 0.0 0.4745109346480554 0.8467387954612068; 0.0 0.0 0.0 0.5399333180126097],)
julia> # (×) Fails!
Zygote.gradient(x) do x
sum(parent(cholesky(Hermitian(x)).U))
end
ERROR: MethodError: no method matching UpperTriangular(::NamedTuple{(:data,), Tuple{FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}}})
Closest candidates are:
UpperTriangular(::UpperTriangular)
@ LinearAlgebra ~/.julia/juliaup/julia-1.9.2+0.x64.linux.gnu/share/julia/stdlib/v1.9/LinearAlgebra/src/triangular.jl:21
UpperTriangular(::ChainRulesCore.AbstractThunk)
@ ChainRulesCore ~/.julia/packages/ChainRulesCore/0t04l/src/tangent_types/thunks.jl:68
UpperTriangular(::TrackedMatrix)
@ DistributionsADTrackerExt ~/.julia/packages/DistributionsAD/Ufc05/ext/DistributionsADTrackerExt.jl:131
...
Stacktrace:
[1] (::Zygote.var"#1010#1013"{Cholesky{Float64, Matrix{Float64}}})(Δ::NamedTuple{(:data,), Tuple{FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}}})
@ Zygote ~/.julia/packages/Zygote/JeHtr/src/lib/array.jl:630
[2] (::Zygote.var"#3461#back#1014"{Zygote.var"#1010#1013"{Cholesky{Float64, Matrix{Float64}}}})(Δ::NamedTuple{(:data,), Tuple{FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}}})
@ Zygote ~/.julia/packages/ZygoteRules/OgCVT/src/adjoint.jl:71
[3] Pullback
@ ./REPL[29]:3 [inlined]
[4] (::Zygote.Pullback{Tuple{var"#51#52", Matrix{Float64}}, Tuple{Zygote.var"#3461#back#1014"{Zygote.var"#1010#1013"{Cholesky{Float64, Matrix{Float64}}}}, Zygote.Pullback{Tuple{typeof(cholesky), Hermitian{Float64, Matrix{Float64}}}, Tuple{Zygote.ZBack{ChainRules.var"#cholesky_HermOrSym_pullback#2122"{Hermitian{Float64, Matrix{Float64}}, Cholesky{Float64, Matrix{Float64}}}}, Zygote.Pullback{Tuple{Type{NoPivot}}, Tuple{}}}}, Zygote.var"#3027#back#778"{Zygote.var"#772#776"{Matrix{Float64}}}, Zygote.Pullback{Tuple{typeof(parent), UpperTriangular{Float64, Matrix{Float64}}}, Tuple{Zygote.var"#2184#back#299"{Zygote.var"#back#298"{:data, Zygote.Context{false}, UpperTriangular{Float64, Matrix{Float64}}, Matrix{Float64}}}}}, Zygote.var"#3299#back#918"{Zygote.var"#back#917"{Hermitian{Float64, Matrix{Float64}}}}}})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/JeHtr/src/compiler/interface2.jl:0
[5] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{var"#51#52", Matrix{Float64}}, Tuple{Zygote.var"#3461#back#1014"{Zygote.var"#1010#1013"{Cholesky{Float64, Matrix{Float64}}}}, Zygote.Pullback{Tuple{typeof(cholesky), Hermitian{Float64, Matrix{Float64}}}, Tuple{Zygote.ZBack{ChainRules.var"#cholesky_HermOrSym_pullback#2122"{Hermitian{Float64, Matrix{Float64}}, Cholesky{Float64, Matrix{Float64}}}}, Zygote.Pullback{Tuple{Type{NoPivot}}, Tuple{}}}}, Zygote.var"#3027#back#778"{Zygote.var"#772#776"{Matrix{Float64}}}, Zygote.Pullback{Tuple{typeof(parent), UpperTriangular{Float64, Matrix{Float64}}}, Tuple{Zygote.var"#2184#back#299"{Zygote.var"#back#298"{:data, Zygote.Context{false}, UpperTriangular{Float64, Matrix{Float64}}, Matrix{Float64}}}}}, Zygote.var"#3299#back#918"{Zygote.var"#back#917"{Hermitian{Float64, Matrix{Float64}}}}}}})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/JeHtr/src/compiler/interface.jl:45
[6] gradient(f::Function, args::Matrix{Float64})
@ Zygote ~/.julia/packages/Zygote/JeHtr/src/compiler/interface.jl:97
[7] top-level scope
@ REPL[29]:2
```

which then also breaks …
Seems like something similar to https://github.com/FluxML/Zygote.jl/blob/29fa32a688fb4454d2ee81b9ba2a6484e4468bda/src/lib/array.jl#L356-L357 is needed. Adding those, I indeed manage to get stuff running, but the resulting adjoint type is `UpperTriangular`, even for the `.L` case:

```julia
julia> Zygote.@adjoint LinearAlgebra.parent(x::UpperTriangular) = parent(x), Δ -> (UpperTriangular(Δ),)
julia> Zygote.@adjoint LinearAlgebra.parent(x::LowerTriangular) = parent(x), Δ -> (LowerTriangular(Δ),)
julia> x
2×2 Matrix{Float64}:
1.0 -0.612002
-0.612002 1.0
julia> last(Zygote.gradient(x) do x
sum(parent(cholesky(Hermitian(x)).U))
end)
2×2 UpperTriangular{Float64, Matrix{Float64}}:
1.0428 1.77385
⋅ 0.632226
julia> last(Zygote.gradient(x) do x
sum(parent(cholesky(Hermitian(x)).L))
end)
2×2 UpperTriangular{Float64, Matrix{Float64}}:
1.0428 1.77385
⋅ 0.632226
```

EDIT: I believe the "issue" of returning …

```julia
julia> last(Zygote.jacobian(x) do x
parent(cholesky(Hermitian(x)).U)
end)
4×4 Matrix{Float64}:
0.5 0.0 0.0 0.0
0.0 0.0 0.0 0.0
0.306001 0.0 1.0 0.0
0.236798 0.0 0.773848 0.632226
julia> last(Zygote.jacobian(x) do x
parent(cholesky(Hermitian(x)).L)
end)
4×4 Matrix{Float64}:
0.5 0.0 0.0 0.0
0.306001 0.0 1.0 0.0
0.0 0.0 0.0 0.0
0.236798 0.0 0.773848 0.632226
```
* removed redundant imports to BijectorsZygoteExt
* use cholesky_upper and cholesky_lower instead of cholesky_factor, etc.
* added tests for CorrVecBijector
* name testset correctly
* use cholesky_lower and cholesky_upper instead of cholesky_factor
* removed now-redundant cholesky_factor
* Fix obsolete function references in tests. (#282)
* Update chainrules.jl
* Update corr.jl
* Revert changes to transform.
* removed type-piracy that has been addressed upstream and bumped Zygote version in test
* use :L for Hermitian in `cholesky_lower`
* fixed ForwardDiff tests for LKJCholesky
* fixed tests for matrix dists and added tests for both values of uplo in LKJCholesky tests
* another attempt at fixing Julia 1.6 tests

---------

Co-authored-by: Hong Ge <[email protected]>
Addresses the issues described in TuringLang/Turing.jl#2018 (comment) for ReverseDiff.jl by adding a few new functions and corresponding rrules:

* `cholesky_lower` / `cholesky_upper`: does what `cholesky_factor` does, but always returning a raw `Matrix`, thus making it compatible with the likes of ReverseDiff.
* `permutedims` to ChainRules.

In addition, fixes a missing import in ReverseDiffExt (IMO we should never `import` but instead always qualify the methods we're overloading in extensions to avoid these sorts of issues, but I'll make a separate PR for this after this has gone through).

Note that this does not address the issue for Tracker.jl.
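A hedged end-to-end usage sketch of what this enables, assuming a Bijectors version that includes this PR and that ReverseDiff is loaded so the extension's `@grad_from_chainrules` definitions are active:

```julia
using LinearAlgebra
using ReverseDiff
using Bijectors

# A small positive-definite matrix to play with.
X = [4.0 2.0 0.5;
     2.0 3.0 0.2;
     0.5 0.2 1.5]

# `Bijectors.cholesky_lower` returns a plain `Matrix`, and its derivative is
# routed through the ChainRules rrule, so ReverseDiff no longer chokes on the
# Cholesky factorization.
ReverseDiff.gradient(X) do X
    sum(Bijectors.cholesky_lower(X))
end
```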