-
Notifications
You must be signed in to change notification settings - Fork 89
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Assume commutative multiplication exactly when necessary #540
base: main
Are you sure you want to change the base?
Conversation
@@ -10,7 +10,7 @@ end | |||
function rrule(::typeof(inv), x::AbstractArray) | |||
Ω = inv(x) | |||
function inv_pullback(ΔΩ) | |||
return NoTangent(), -Ω' * ΔΩ * Ω' | |||
return NoTangent(), Ω' * -ΔΩ * Ω' |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can I ask why you moved the minus?
If it was -true * Ω' * ΔΩ * Ω'
then I think you'd save a copy (since this gets fused into mul!
).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So that if ΔΩ
is an AbstractZero
or a UniformScaling
, then the negation is cheaper.
If it was
-true * Ω' * ΔΩ * Ω'
then I think you'd save a copy (since this gets fused intomul!
).
I didn't follow this. How is this fused into the mul!
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(this was not an important change, and I'm happy to remove)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah, I didn't think about those. For dense matrices there's a 4-arg method which fuses this:
julia> f1(Ω, ΔΩ) = Ω' * -ΔΩ * Ω';
julia> f2(Ω, ΔΩ) = -true * Ω' * ΔΩ * Ω';
julia> @btime f1(Ω, ΔΩ) setup=(N=100; Ω=rand(N,N); ΔΩ=rand(N,N));
min 74.708 μs, mean 101.133 μs (6 allocations, 234.52 KiB. GC mean 6.51%)
julia> @btime f2(Ω, ΔΩ) setup=(N=100; Ω=rand(N,N); ΔΩ=rand(N,N));
min 73.125 μs, mean 92.756 μs (4 allocations, 156.34 KiB. GC mean 4.82%)
julia> @which -1 * ones(2,2) * ones(2,2) * ones(2,2)
*(α::Union{Real, Complex}, B::AbstractMatrix{<:Union{Real, Complex}}, C::AbstractMatrix{<:Union{Real, Complex}}, D::AbstractMatrix{<:Union{Real, Complex}}) in LinearAlgebra at /Users/me/.julia/dev/julia/usr/share/julia/stdlib/v1.8/LinearAlgebra/src/matmul.jl:1134
But with I
, no fusion, hence f2
is slower. Maybe *
should have some extra methods for cases with I
.
This reverts commit 0cb446f.
I will leave this to @mcabbott to review. |
They cause tests to fail and should be added in a separate PR
I propose we explicitly test rules with Thoughts, @mcabbott? |
Implementing them here (like Base) sounds fine, but is testing with Quaternions going to quadruple the time for tests to run? Would be nice to avoid that if possible. This closes #275 presume? Not deep, but I have notation comments:
(To my eye, |
function rrule(::typeof(muladd), x::Number, y::Number, z::Number) | ||
projectx, projecty, projectz = ProjectTo(x), ProjectTo(y), ProjectTo(z) | ||
muladd_pullback(ΔΩ) = NoTangent(), projectx(ΔΩ * y'), projecty(x' * ΔΩ), projectz(ΔΩ) | ||
muladd(x, y, z), muladd_pullback | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
E.g. here I think this is very clear, the pattern of where the Δ
s go is important:
function frule((_, Δx, Δy, Δz), ::typeof(muladd), x::Number, y::Number, z::Number)
return muladd(x, y, z), muladd(Δx, y, muladd(x, Δy, Δz))
end
but I'd like to make the corresponding rrule
more distinct:
function rrule(::typeof(muladd), x::Number, y::Number, z::Number) | |
projectx, projecty, projectz = ProjectTo(x), ProjectTo(y), ProjectTo(z) | |
muladd_pullback(ΔΩ) = NoTangent(), projectx(ΔΩ * y'), projecty(x' * ΔΩ), projectz(ΔΩ) | |
muladd(x, y, z), muladd_pullback | |
end | |
function rrule(::typeof(muladd), x::Number, y::Number, z::Number) | |
muladd_pullback(∇Ω) = NoTangent(), ProjectTo(x)(∇Ω * y'), ProjectTo(y)(x' * ∇Ω), ProjectTo(z)(∇Ω) | |
return muladd(x, y, z), muladd_pullback | |
end |
or perhaps ∂Ω
is a sort-of lower-case ∇Ω
for scalars?
And since it closes over x,y,z
already, there is nothing gained by constructing projectors outside.
I think a bunch of scalar rules should have Quaternion tests, but probably only a few array rules, so I don't expect testing time to increase by that much (but we can check).
It supersedes #275 (forgot about that one; it's pretty stale now) and closes #504.
I think though that
I agree, the dots and bars are not great (and they're unicode characters often missing from people's devices). I seem to recall at some point we expressly advised people to use |
As noted in #504, there are a number of cases where types of rules were constrained to
CommutativeMulNumber
where commutation of multiplication did not need to be assumed. Likewise, there were places where commutativity was assumed but not enforced by a type constraint.This PR fixes #504 by removing constraints where un-needed and adding others where needed. Because the trigonometric, hyperbolic, logarithmic, and exponential function rules all assume cummutativity, this puts constraints on a _large_number of rules. It's possible we don't want to do this, because there are certainly numeric types out there that are real (and therefore commutative) but do not subtype
Real
, and in this case these would not directly hit our rules. However, the approach taken here is much safer.