Handle commutativity correctly in scalar rules #275
base: main
Current use of imag and im assumes real or complex
This reverts commit d0d1954.
Will review tomorrow.
Looks good, though I don't fully understand it.
Needs a version bump.
My only major comment is that we should move more of the testing code into ChainRulesTestUtils (and/or FiniteDifferences.jl).
Approving this now so that it can be merged once JuliaDiff/ChainRulesTestUtils.jl#61 is sorted.
@@ -22,6 +22,8 @@ if VERSION < v"1.3.0-DEV.142"
    import LinearAlgebra: dot
end

# numbers that we know commute under multiplication
const CommutativeMulNumber = Union{Real,Complex}
I thought the style guide said to put spaces here, but now that I look, I am not sure it mentions it (JuliaDiff/BlueStyle#77).
Still, it is what we do elsewhere.
- const CommutativeMulNumber = Union{Real,Complex}
+ const CommutativeMulNumber = Union{Real, Complex}
@scalar_rule acosh(x) inv(sqrt(x - 1) * sqrt(x + 1))
@scalar_rule acoth(x) inv(1 - x ^ 2)
@scalar_rule acsch(x) -(inv(x ^ 2 * sqrt(1 + x ^ -2)))
@scalar_rule acosh(x::CommutativeMulNumber) inv(sqrt(x - 1) * sqrt(x + 1))
What logic determines whether a univariate function needs the `CommutativeMulNumber` restriction?
If a function is extended to the complex numbers and matrices through a power series, then non-commutativity becomes a problem for non-commutative numbers, so I restrict the rule to `CommutativeMulNumber`.
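A minimal sketch of why a power-series term needs care, using 2×2 matrices as a stand-in for a non-commutative number type. The matrices and the helper `f` are assumptions for illustration only; the PR itself tests with quaternions. The scalar derivative `2x` of `x^2` only matches a finite difference when the factors commute:

```julia
using LinearAlgebra  # for `norm`

# Stand-ins for a non-commuting input and its tangent (illustrative values).
x  = [1.0 2.0; 3.0 4.0]
Δx = [0.0 1.0; 0.0 0.0]

f(x) = x * x  # the second-order term of a power series

# Directional derivative of f at x along Δx by central finite differences.
h = 1e-6
fd = (f(x + h * Δx) - f(x - h * Δx)) / (2h)

commutative_rule = 2 * x * Δx        # what a scalar rule `2x` would produce
ordered_rule     = Δx * x + x * Δx   # product rule with factor order preserved

println(norm(fd - ordered_rule))      # close to zero
println(norm(fd - commutative_rule))  # clearly nonzero
```

The same mismatch appears term by term in any power series, which is the motivation for restricting such rules to types known to commute.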
@@ -140,7 +140,36 @@ end
),
(!(islow | ishigh), islow, ishigh),
)
@scalar_rule x \ y (-(Ω / x), one(y) / x)
Should we move the other scalar rule for `muladd` here as well?

# product rule requires special care for arguments where `muladd` is non-commutative
function frule((_, Δx, Δy, Δz), ::typeof(muladd), x::Number, y::Number, z::Number)
    ∂xyz = muladd(Δx, y, muladd(x, Δy, Δz))
Should we use MuladdMacro.jl here?
It's already a dependency of ChainRulesCore.
I think then we can just write:
- ∂xyz = muladd(Δx, y, muladd(x, Δy, Δz))
+ @muladd ∂xyz = Δx*y + x*Δy + Δz
and then I think the macro takes care of the rearranging.
I don't know whether that really adds clarity. What do you think?
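For reference, the argument order in the nested `muladd` expression can be checked numerically. This is only a sketch: it uses 2×2 matrices as a non-commuting stand-in (an assumption, since the rule in the diff is typed on `Number`) and compares the expression against a central finite difference of `muladd` itself.

```julia
using LinearAlgebra  # for `norm`

# Illustrative non-commuting inputs and tangents (assumed values).
x, y, z    = [1.0 2.0; 3.0 4.0], [0.0 1.0; 1.0 0.0], [1.0 0.0; 0.0 1.0]
Δx, Δy, Δz = [0.0 1.0; 0.0 0.0], [1.0 0.0; 0.0 2.0], [0.5 0.0; 0.0 0.5]

# Tangent expression as written in the diff: Δx*y + x*Δy + Δz, order preserved.
pushforward = muladd(Δx, y, muladd(x, Δy, Δz))

# Same terms with each product's factors swapped, as a commuting rule might do.
swapped = muladd(y, Δx, muladd(Δy, x, Δz))

# Central finite difference of muladd along (Δx, Δy, Δz).
h = 1e-6
fd = (muladd(x + h * Δx, y + h * Δy, z + h * Δz) -
      muladd(x - h * Δx, y - h * Δy, z - h * Δz)) / (2h)

println(norm(fd - pushforward))  # close to zero
println(norm(fd - swapped))      # clearly nonzero
```

Either spelling of the tangent (nested `muladd` or the `@muladd` form) has to keep `Δx` on the left of `y` and `Δy` on the right of `x` for the check to pass.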
function FiniteDifferences.to_vec(q::Quaternion)
    function Quaternion_from_vec(q_vec)
        return Quaternion(q_vec[1], q_vec[2], q_vec[3], q_vec[4])
    end
    return [q.s, q.v1, q.v2, q.v3], Quaternion_from_vec
end
We should just have this defined in ChainRulesTestUtils.jl.
The whole point of that package is to avoid defining reusable code inside the tests.
It's still type piracy there, but it makes sense for us to define proper functionality for testing this in ChainRulesTestUtils.
JuliaDiff/ChainRulesTestUtils.jl#61
Or we could move it to FiniteDifferences.jl; that would also be acceptable, and would not be type piracy.
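To make the pattern under discussion concrete, here is a self-contained sketch of the "flatten to a real vector, return a reconstruction closure" shape. `Quat` and `quat_to_vec` are hypothetical names for illustration; they are not the `Quaternion` type or the `FiniteDifferences.to_vec` method from the diff, only the same structure without the package dependencies.

```julia
# Hypothetical stand-in for a quaternion type; field names mirror the diff.
struct Quat
    s::Float64
    v1::Float64
    v2::Float64
    v3::Float64
end

# Return the components as a flat real vector, plus a closure that
# rebuilds a Quat from such a vector.
function quat_to_vec(q::Quat)
    quat_from_vec(v) = Quat(v[1], v[2], v[3], v[4])
    return [q.s, q.v1, q.v2, q.v3], quat_from_vec
end

q = Quat(1.0, 2.0, 3.0, 4.0)
v, back = quat_to_vec(q)
println(back(v) == q)  # the round trip recovers the original value
```

Wherever the real method ends up living, the round-trip property is what the finite-difference tests rely on.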
end
@testset "/(::Quaternion, ::Real)" begin
x, ẋ = quatrand(), quatrand(), quatrand()
y, ẏ = randn(3)
Typo?
- y, ẏ = randn(3)
+ y, ẏ = randn(2)
rrule_test(f, ΔΩ, (x, x̄), (y, ȳ))
end
@testset "/(::Quaternion, ::Real)" begin
x, ẋ = quatrand(), quatrand(), quatrand()
- x, ẋ = quatrand(), quatrand(), quatrand()
+ x, ẋ = quatrand(), quatrand()
x, ẋ, x̄ = randn(3)
y, ẏ, ȳ = quatrand(), quatrand(), quatrand()
ΔΩ = quatrand()
Since we are not testing the rrule, we don't need these.
- x, ẋ, x̄ = randn(3)
- y, ẏ, ȳ = quatrand(), quatrand(), quatrand()
- ΔΩ = quatrand()
+ x, ẋ = randn(2)
+ y, ẏ = quatrand(), quatrand()
y, ẏ = randn(3)
frule_test(/, (x, ẋ), (y, ẏ))
# don't test rrule, because it doesn't project adjoint of y to the reals
# so fd won't agree
Do we have a way to test these?
Why don't we run into this problem for Complex numbers?
Many scalar rules defined on arguments of `Number` type assume the scalar commutes under multiplication, which fails for non-commutative numbers like quaternions. This PR restricts the type of such rules to `Union{Real,Complex}`. Where possible, it also adds generic rules defined for `Number` that don't assume commutativity.

If a user had implemented their own commutative number type that was not a `Real`, then before this PR, the rules may have worked for them, but now they will not. Hence, this is marked as a breaking change.

This PR requires JuliaDiff/ChainRulesTestUtils.jl#61.
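A small sketch of the breaking-change scenario described above. `MyNumber` and `has_commutative_rule` are hypothetical names used only to illustrate dispatch; they are not part of ChainRules. A method restricted to `Union{Real,Complex}` no longer applies to a user-defined `Number` subtype, even if that type happens to commute:

```julia
# Same alias as introduced in the diff.
const CommutativeMulNumber = Union{Real, Complex}

# Hypothetical user-defined number type that commutes under multiplication
# but is neither a Real nor a Complex.
struct MyNumber <: Number
    val::Float64
end

# Stand-in for "a rule exists for this argument type".
has_commutative_rule(::CommutativeMulNumber) = true
has_commutative_rule(::Number) = false

println(has_commutative_rule(1.0))            # true
println(has_commutative_rule(1.0 + 2.0im))    # true
println(has_commutative_rule(MyNumber(1.0)))  # false
```

A user in that situation would need to define rules for their own type after this change.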