Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle commutativity correctly in scalar rules #275

Open
wants to merge 9 commits into
base: main
Choose a base branch
from

Conversation

sethaxen
Copy link
Member

@sethaxen sethaxen commented Oct 7, 2020

Many scalar rules defined on arguments of Number type assume the scalar commutes under multiplication, which fails for non-commutative numbers like quaternions. This PR restricts the type of such rules to Union{Real,Complex}. Where possible, it also adds generic rules defined for Number that don't assume commutativity.

If a user had implemented their own commutative number type that was not a Real, then before this PR, the rules may have worked for them, but now they will not. Hence, this is marked as a breaking change.

This PR requires JuliaDiff/ChainRulesTestUtils.jl#61.

@sethaxen sethaxen requested a review from oxinabox October 7, 2020 05:30
@oxinabox
Copy link
Member

oxinabox commented Oct 7, 2020

will review tomorrow

Copy link
Member

@oxinabox oxinabox left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good, I don't fully understand it.
Needs to bump the version.

Only major comment is that we should move more of the testing stuff into ChainRulesTestUtils.
(and/or FiniteDifferences.jl)

Approving this one now so once we have that sorted in
JuliaDiff/ChainRulesTestUtils.jl#61
it can be merged

@@ -22,6 +22,8 @@ if VERSION < v"1.3.0-DEV.142"
import LinearAlgebra: dot
end

# numbers that we know commute under multiplication
const CommutativeMulNumber = Union{Real,Complex}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought the style guide said to put spaces here.
But now that i look, I am not sure that it mentions it JuliaDiff/BlueStyle#77
Still it is what we do else-where

Suggested change
const CommutativeMulNumber = Union{Real,Complex}
const CommutativeMulNumber = Union{Real, Complex}

src/rulesets/Base/base.jl Show resolved Hide resolved
@scalar_rule acosh(x) inv(sqrt(x - 1) * sqrt(x + 1))
@scalar_rule acoth(x) inv(1 - x ^ 2)
@scalar_rule acsch(x) -(inv(x ^ 2 * sqrt(1 + x ^ -2)))
@scalar_rule acosh(x::CommutativeMulNumber) inv(sqrt(x - 1) * sqrt(x + 1))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the logic used to determine if Multiplicative Commutative is needed for univariate functions?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the extension of a function to the complex numbers and matrices is in the form of a power series, then the non-commutativity becomes a problem for non-commutative numbers, and I restrict it.

@@ -140,7 +140,36 @@ end
),
(!(islow | ishigh), islow, ishigh),
)
@scalar_rule x \ y (-(Ω / x), one(y) / x)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we move the other scalar rule for muladd to be here also?


# product rule requires special care for arguments where `muladd` is non-commutative
function frule((_, Δx, Δy, Δz), ::typeof(muladd), x::Number, y::Number, z::Number)
∂xyz = muladd(Δx, y, muladd(x, Δy, Δz))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we use MulAddMacro.jl here?
Its a dependency of ChainRulesCore already.
I think then we can just write:

Suggested change
∂xyz = muladd(Δx, y, muladd(x, Δy, Δz))
@muladd ∂xyz = Δx*y + x*Δy + Δz

and then i think the macro takes care of rearranging.

idk if that really adds clarity or not. What do you thing?

Comment on lines +184 to +189
function FiniteDifferences.to_vec(q::Quaternion)
function Quaternion_from_vec(q_vec)
return Quaternion(q_vec[1], q_vec[2], q_vec[3], q_vec[4])
end
return [q.s, q.v1, q.v2, q.v3], Quaternion_from_vec
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should just have this defined in ChainRulesTestUtils.jl
Whole point of that package is to avoid defining re-usable stuff inside the tests.

Its still type piracy there but it makes sense for us to define proper testing functionality for testing this using ChainRulesTestUtils.
JuliaDiff/ChainRulesTestUtils.jl#61

Or we could move it to FiniteDifferences.jl that would also be acceptable, and not type-piracy.

end
@testset "/(::Quaternion, ::Real)" begin
x, ẋ = quatrand(), quatrand(), quatrand()
y, ẏ = randn(3)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

typo?

Suggested change
y, ẏ = randn(3)
y, ẏ = randn(2)

rrule_test(f, ΔΩ, (x, x̄), (y, ȳ))
end
@testset "/(::Quaternion, ::Real)" begin
x, ẋ = quatrand(), quatrand(), quatrand()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
x, ẋ = quatrand(), quatrand(), quatrand()
x, ẋ = quatrand(), quatrand()

Comment on lines +213 to +215
x, ẋ, x̄ = randn(3)
y, ẏ, ȳ = quatrand(), quatrand(), quatrand()
ΔΩ = quatrand()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

since we are not testing rrule don't need these/

Suggested change
x, ẋ, x̄ = randn(3)
y, ẏ, ȳ = quatrand(), quatrand(), quatrand()
ΔΩ = quatrand()
x, ẋ = randn(2)
y, ẏ = quatrand(), quatrand()

y, ẏ = randn(3)
frule_test(/, (x, ẋ), (y, ẏ))
# don't test rrule, because it doesn't project adjoint of y to the reals
# so fd won't agree
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we have a way to test these?

Why don't we run into this problem for Complex numbers?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants