Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AD fix for PDBijector #280

Merged
merged 23 commits into from
Aug 12, 2023
Merged

AD fix for PDBijector #280

merged 23 commits into from
Aug 12, 2023

Conversation

torfjelde
Copy link
Member

Addresses the issues described in TuringLang/Turing.jl#2018 (comment) for ReverseDiff.jl by adding a few new functions and corresponding rrules:

  • cholesky_lower/cholesky_upper: does what cholesky_factor does, but always returning a raw Matrix, thus making it compatible with the likes of ReverseDiff.
  • Defers ReverseDiff-differentiation of permutedims to ChainRules.

In addition, fixes a missing import in ReverseDiffExt (IMO we should never import but instead always qualify the methods we're overloading in extensions to avoid these sorts of issues, but I'll make a separate PR for this after this has gone through).

Note that this does not address the issue for Tracker.jl.

@@ -19,6 +19,65 @@ cholesky_factor(X::Cholesky) = X.U
cholesky_factor(X::UpperTriangular) = X
cholesky_factor(X::LowerTriangular) = X

# TODO: Add `check` as an argument?
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is the last remaining question @devmotion . I'm thinking "let's not, until we start using it"?

src/utils.jl Outdated
This is a thin wrapper around `cholesky(Hermitian(X)).L`
but with a custom `ChainRulesCore.rrule` implementation.
"""
cholesky_lower(X::AbstractMatrix) = lower_triangular(parent(cholesky(Hermitian(X)).L))
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wrap in Hermitian to effectively do the same as the current implementation of cholesky_factor but I believe cholesky(::Hermitian) is only valid starting from Julia 1.8 (going by a comment in BijectorsReverseDiffExt), so we need to fix this.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is actually not a problem anymore since we're now defining the adjoint to circumvent the cholesky on tracked completely.

src/utils.jl Outdated Show resolved Hide resolved
src/utils.jl Outdated Show resolved Hide resolved
src/utils.jl Outdated Show resolved Hide resolved
@devmotion
Copy link
Member

Is there a particular reason for defining a ChainRule-rule and applying it to ReverseDiff even though the only (?) broken backends are ReverseDiff and Tracker? It would seem a bit more natural to define a rule for ReverseDiff directly (ignoring Tracker as discussed in Turing).

@torfjelde
Copy link
Member Author

torfjelde commented Aug 6, 2023

Is there a particular reason for defining a ChainRule-rule and applying it to ReverseDiff even though the only (?) broken backends are ReverseDiff and Tracker? It would seem a bit more natural to define a rule for ReverseDiff directly (ignoring Tracker as discussed in Turing).

Nah. I had the same thought, but a) I'm fairly familiar with defining rrules using ChainRules, and less so with ReverseDiff, b) it's easy to make use of the rrule in defining a Tracker.@grad (but since we dropped official support for Tracker, we're no longer testing this), and c) I figured this wouldn't hurt Zygote perf vs. not defining this 🤷

@grad_from_chainrules Bijectors.cholesky_lower(X::TrackedMatrix)
@grad_from_chainrules Bijectors.cholesky_upper(X::TrackedMatrix)

# TODO: Type-piracy; probably shouldn't do this.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, this should really not be defined in Bijectors. I'm sure this will lead to surprising debugging and issues when Bijectors is (not) loaded.

@devmotion
Copy link
Member

One can just define the pullback in a function and reuse it without defining an rrule? From the code it seems the rrule does not provide any performance improvements over the already existing rrules, so there's no benefit for ChainRules-compatible AD backends. Defining a ReverseDiff-rule is (sometimes) very similar to defining a Tracker rule (see, e.g., https://github.com/TuringLang/DistributionsAD.jl/blob/5847e86f7783ea8f745a7465c2f4f9020c729051/ext/DistributionsADReverseDiffExt.jl#L38-L43).

@torfjelde
Copy link
Member Author

torfjelde commented Aug 7, 2023

One can just define the pullback in a function and reuse it without defining an rrule? From the code it seems the rrule does not provide any performance improvements over the already existing rrules, so there's no benefit for ChainRules-compatible AD backends.

But isn't it fair to assume that an rrule will generally lead to improvements in type-stability, and thus such a rrule also benefitting the likes of Zygote?

But I'm happy to not use a rrule here, if that is preferred. I always just do it by default because we generally have good tools to ensure that this works as intended + given the amount of type-instabilities I've encountered with Zygote, I've "arrived" at the conclusion that writing a rrule to avoid tracing through the full callstack is generally considered to beneficial 🤷

EDIT: I have the change to not using rrule ready to go, but as I'm making the changes, my motivation for making the change is dwindling 🙃 Given how rarely I write rules for ReverseDiff and Tracker these days, there is genuinely an increased maintenance burden. And if this is the case for me, I imagine this is doubly so for new developers trying to contribute 😕 It's also "annoying" to remove these good ChainRules-practices from the rrule, e.g. usage of ProjectTo, just because the particular AD framework we're fixing doesn't support these (while when just "deferring" to ChainRules for these AD frameworks, these would just be no-ops anyways).

With that being said, I will still make the change because I ran into a super-weird error with ReverseDiff (I'll raise an issue in a sec) 🤦

ext/BijectorsReverseDiffExt.jl Outdated Show resolved Hide resolved
ext/BijectorsReverseDiffExt.jl Outdated Show resolved Hide resolved
@devmotion
Copy link
Member

Regarding JuliaDiff/ReverseDiff.jl#236, the macro has always had issues and limitations (one bug was just fixed recently), so in my experience it's not really the case (and possibly a too strong expectation) that it brings CR-compatibility to ReverseDiff.

test/ad/utils.jl Outdated Show resolved Hide resolved
@torfjelde
Copy link
Member Author

Regarding JuliaDiff/ReverseDiff.jl#236, the macro has always had issues and limitations (one bug was just fixed recently), so in my experience it's not really the case (and possibly a too strong expectation) that it brings CR-compatibility to ReverseDiff.

Yeaaah but it's so darn convenient 😞

@torfjelde
Copy link
Member Author

torfjelde commented Aug 7, 2023

Btw, @devmotion the fact that I'm facing issues when using @grad_from_chainrules is not the macro's fault, but it's because we're just calling value on the inputs, while @grad_from_chainrules does not (which is, arguably, the correct way of doing things).

EDIT: Naaah, I'm stupid. The value call is there 🤦

@torfjelde
Copy link
Member Author

So this is quite confusing. I'm trying to transition VecCorrBijector to also use the cholesky_lower, etc. and I'm now running into the following issue, despite CI passing for the current implementation (which doesn't have a custom adjoint for Zygote):

julia> d = 4
4

julia> dist = LKJ(d, 2.0)
LKJ{Float64, Int64}(
d: 4
η: 2.0
)


julia> b = bijector(dist)
Bijectors.VecCorrBijector()

julia> x = rand(dist)
4×4 Matrix{Float64}:
  1.0        -0.19264   -0.63806    0.0930006
 -0.19264     1.0        0.259633  -0.168056
 -0.63806     0.259633   1.0        0.170947
  0.0930006  -0.168056   0.170947   1.0

julia> # (✓) Works!
       Zygote.gradient(x) do x
           sum(cholesky(Hermitian(x)).U)
       end

([1.141585190493663 1.7498257982625538 1.7225797690347182 1.6454362799052624; 0.0 0.5205946318942621 1.0163692667364 1.0672427905346635; 0.0 0.0 0.4745109346480554 0.8467387954612068; 0.0 0.0 0.0 0.5399333180126097],)

julia> # (×) Fails!
       Zygote.gradient(x) do x
           sum(parent(cholesky(Hermitian(x)).U))
       end
ERROR: MethodError: no method matching UpperTriangular(::NamedTuple{(:data,), Tuple{FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}}})

Closest candidates are:
  UpperTriangular(::UpperTriangular)
   @ LinearAlgebra ~/.julia/juliaup/julia-1.9.2+0.x64.linux.gnu/share/julia/stdlib/v1.9/LinearAlgebra/src/triangular.jl:21
  UpperTriangular(::ChainRulesCore.AbstractThunk)
   @ ChainRulesCore ~/.julia/packages/ChainRulesCore/0t04l/src/tangent_types/thunks.jl:68
  UpperTriangular(::TrackedMatrix)
   @ DistributionsADTrackerExt ~/.julia/packages/DistributionsAD/Ufc05/ext/DistributionsADTrackerExt.jl:131
  ...

Stacktrace:
 [1] (::Zygote.var"#1010#1013"{Cholesky{Float64, Matrix{Float64}}})(Δ::NamedTuple{(:data,), Tuple{FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}}})
   @ Zygote ~/.julia/packages/Zygote/JeHtr/src/lib/array.jl:630
 [2] (::Zygote.var"#3461#back#1014"{Zygote.var"#1010#1013"{Cholesky{Float64, Matrix{Float64}}}})(Δ::NamedTuple{(:data,), Tuple{FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}}})
   @ Zygote ~/.julia/packages/ZygoteRules/OgCVT/src/adjoint.jl:71
 [3] Pullback
   @ ./REPL[29]:3 [inlined]
 [4] (::Zygote.Pullback{Tuple{var"#51#52", Matrix{Float64}}, Tuple{Zygote.var"#3461#back#1014"{Zygote.var"#1010#1013"{Cholesky{Float64, Matrix{Float64}}}}, Zygote.Pullback{Tuple{typeof(cholesky), Hermitian{Float64, Matrix{Float64}}}, Tuple{Zygote.ZBack{ChainRules.var"#cholesky_HermOrSym_pullback#2122"{Hermitian{Float64, Matrix{Float64}}, Cholesky{Float64, Matrix{Float64}}}}, Zygote.Pullback{Tuple{Type{NoPivot}}, Tuple{}}}}, Zygote.var"#3027#back#778"{Zygote.var"#772#776"{Matrix{Float64}}}, Zygote.Pullback{Tuple{typeof(parent), UpperTriangular{Float64, Matrix{Float64}}}, Tuple{Zygote.var"#2184#back#299"{Zygote.var"#back#298"{:data, Zygote.Context{false}, UpperTriangular{Float64, Matrix{Float64}}, Matrix{Float64}}}}}, Zygote.var"#3299#back#918"{Zygote.var"#back#917"{Hermitian{Float64, Matrix{Float64}}}}}})(Δ::Float64)
   @ Zygote ~/.julia/packages/Zygote/JeHtr/src/compiler/interface2.jl:0
 [5] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{var"#51#52", Matrix{Float64}}, Tuple{Zygote.var"#3461#back#1014"{Zygote.var"#1010#1013"{Cholesky{Float64, Matrix{Float64}}}}, Zygote.Pullback{Tuple{typeof(cholesky), Hermitian{Float64, Matrix{Float64}}}, Tuple{Zygote.ZBack{ChainRules.var"#cholesky_HermOrSym_pullback#2122"{Hermitian{Float64, Matrix{Float64}}, Cholesky{Float64, Matrix{Float64}}}}, Zygote.Pullback{Tuple{Type{NoPivot}}, Tuple{}}}}, Zygote.var"#3027#back#778"{Zygote.var"#772#776"{Matrix{Float64}}}, Zygote.Pullback{Tuple{typeof(parent), UpperTriangular{Float64, Matrix{Float64}}}, Tuple{Zygote.var"#2184#back#299"{Zygote.var"#back#298"{:data, Zygote.Context{false}, UpperTriangular{Float64, Matrix{Float64}}, Matrix{Float64}}}}}, Zygote.var"#3299#back#918"{Zygote.var"#back#917"{Hermitian{Float64, Matrix{Float64}}}}}}})(Δ::Float64)
   @ Zygote ~/.julia/packages/Zygote/JeHtr/src/compiler/interface.jl:45
 [6] gradient(f::Function, args::Matrix{Float64})
   @ Zygote ~/.julia/packages/Zygote/JeHtr/src/compiler/interface.jl:97
 [7] top-level scope
   @ REPL[29]:2

which then also breaks cholesky_lower 😕

@torfjelde
Copy link
Member Author

torfjelde commented Aug 7, 2023

Seems like something similar to https://github.com/FluxML/Zygote.jl/blob/29fa32a688fb4454d2ee81b9ba2a6484e4468bda/src/lib/array.jl#L356-L357 is needed.

Adding those, I indeed manage to get stuff running, but the resulting adjoint type is UpperTriangular no matter what 😕

julia> Zygote.@adjoint LinearAlgebra.parent(x::UpperTriangular) = parent(x), Δ -> (UpperTriangular(Δ),)

julia> Zygote.@adjoint LinearAlgebra.parent(x::LowerTriangular) = parent(x), Δ -> (LowerTriangular(Δ),)

julia> x
2×2 Matrix{Float64}:
  1.0       -0.612002
 -0.612002   1.0

julia> last(Zygote.gradient(x) do x
           sum(parent(cholesky(Hermitian(x)).U))
       end)
2×2 UpperTriangular{Float64, Matrix{Float64}}:
 1.0428  1.77385
        0.632226

julia> last(Zygote.gradient(x) do x
           sum(parent(cholesky(Hermitian(x)).L))
       end)
2×2 UpperTriangular{Float64, Matrix{Float64}}:
 1.0428  1.77385
        0.632226

EDIT: I believe the "issue" of returning UpperTriangular no matter what is because of Zygote's literal_property? It seems, for example, that the jacobian is correct:

julia> last(Zygote.jacobian(x) do x
           parent(cholesky(Hermitian(x)).U)
       end)
4×4 Matrix{Float64}:
 0.5       0.0  0.0       0.0
 0.0       0.0  0.0       0.0
 0.306001  0.0  1.0       0.0
 0.236798  0.0  0.773848  0.632226

julia> last(Zygote.jacobian(x) do x
           parent(cholesky(Hermitian(x)).L)
       end)
4×4 Matrix{Float64}:
 0.5       0.0  0.0       0.0
 0.306001  0.0  1.0       0.0
 0.0       0.0  0.0       0.0
 0.236798  0.0  0.773848  0.632226

* removed redundant imports to BijectorsZygoteExt

* use cholesky_upper and cholesky_lower instead of cholesky_factor, etc.

* added tests for CorrVecBijector

* name testset correctly

* use cholesky_lower and cholesky_upper instead of cholesky_factor

* removed now-redundant cholesky_factor

* Fix obsolete function references in tests.  (#282)

* Update chainrules.jl

* Update corr.jl

* Revert changes to transform.

* removed type-piracy that has been addressed upstream and bumped Zygote
version in test

* use :L for Hermitian in `cholesky_lower`

* fixed ForwardDiff tests for LKJCholesky

* fixed tests for matrix dists and added tests for both values of uplo
in LKJCholesky tests

* another attempt at fixing Julia 1.6 tests

---------

Co-authored-by: Hong Ge <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants