From 74c4bf6c82f39d9074d170c91d51b70d0ebcdd75 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Mon, 12 Apr 2021 20:55:27 -0400 Subject: [PATCH 1/6] print ambiguities during testing --- test/runtests.jl | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/test/runtests.jl b/test/runtests.jl index aec5a66f..338b9835 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -41,6 +41,16 @@ using Dates: const colon = Base.:(:) +ambig = sort(detect_ambiguities(Unitful), by = a -> [string(a[1].name), string(a[2].module)]) +if length(ambig) > 0 + println(stdout, "detect_ambiguities(Unitful) found $(length(ambig)) issues:") + for i in 1:length(ambig) + println(stdout, "[",i, "]:") + println(stdout, " ", ambig[i][1]) + println(stdout, " ", ambig[i][2]) + end + println(stdout) +end @testset "Construction" begin @test isa(NoUnits, FreeUnits) @test typeof(𝐋) === Unitful.Dimensions{(Unitful.Dimension{:Length}(1),)} From e95f3e83cfb29f014c253638935435cfa9fcd2f3 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Mon, 12 Apr 2021 21:30:22 -0400 Subject: [PATCH 2/6] take 1 --- src/Unitful.jl | 1 + src/linearalgebra.jl | 42 ++++++++++++++++++++++++++++++++++++++++++ src/utils.jl | 2 +- test/runtests.jl | 34 ++++++++++++++++++++++++++++++++++ 4 files changed, 78 insertions(+), 1 deletion(-) create mode 100644 src/linearalgebra.jl diff --git a/src/Unitful.jl b/src/Unitful.jl index 343f49b6..4da15cd7 100644 --- a/src/Unitful.jl +++ b/src/Unitful.jl @@ -69,5 +69,6 @@ include("logarithm.jl") include("complex.jl") include("pkgdefaults.jl") include("dates.jl") +include("linearalgebra.jl") end diff --git a/src/linearalgebra.jl b/src/linearalgebra.jl new file mode 100644 index 00000000..fb5fd010 --- /dev/null +++ b/src/linearalgebra.jl @@ -0,0 +1,42 @@ +using LinearAlgebra + +# This function is re-defined during testing, to check we hit the fast path: +linearalgebra_count() = nothing + +function LinearAlgebra.mul!(C::StridedVecOrMat{<:AbstractQuantity{T}}, + A::StridedMatrix{<:AbstractQuantity{T}}, + B::StridedVecOrMat{<:AbstractQuantity{T}}, + alpha::Bool, beta::Bool) where {T<:Base.HWNumber} + # This is exactly how A * B creates C = similar(B, T, ...) + eltype(C) == Base.promote_op(LinearAlgebra.matprod, eltype(A), eltype(B)) || error("bad eltypes") + C0 = ustrip(C) + A0 = ustrip(A) + B0 = ustrip(B) + mul!(C0, A0, B0) + linearalgebra_count() + return C +end + +function LinearAlgebra.mul!(C::StridedVecOrMat{<:AbstractQuantity{T}}, + A::LinearAlgebra.AdjOrTransAbsMat{<:AbstractQuantity{T}, <:StridedMatrix}, + B::StridedVecOrMat{<:AbstractQuantity{T}}, + alpha::Bool, beta::Bool) where {T<:Base.HWNumber} + + eltype(C) == Base.promote_op(LinearAlgebra.matprod, eltype(A), eltype(B)) || error("bad eltypes") + C0 = ustrip(C) + A0 = A isa Adjoint ? adjoint(ustrip(parent(A))) : transpose(ustrip(parent(A))) + B0 = ustrip(B) + mul!(C0, A0, B0) + linearalgebra_count() + return C +end + +function LinearAlgebra.dot(A::StridedArray{<:AbstractQuantity{T}}, + B::StridedArray{<:AbstractQuantity{T}}) where {T<:Base.HWNumber} + A0 = ustrip(A) + B0 = ustrip(B) + C0 = dot(A0, B0) + linearalgebra_count() + C = C0 * oneunit(eltype(A)) * oneunit(eltype(B)) # surely there is an official way + return C +end diff --git a/src/utils.jl b/src/utils.jl index 53f12b95..36a25b07 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -75,7 +75,7 @@ julia> a[1] = 3u"m"; b 2 ``` """ -@inline ustrip(A::Array{Q}) where {Q <: Quantity} = reinterpret(numtype(Q), A) +@inline ustrip(A::StridedArray{Q}) where {Q <: Quantity} = reinterpret(numtype(Q), A) @deprecate(ustrip(A::AbstractArray{T}) where {T<:Number}, ustrip.(A)) diff --git a/test/runtests.jl b/test/runtests.jl index 338b9835..f9967c58 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -94,6 +94,40 @@ end @test ConstructionBase.constructorof(typeof(1.0m))(2) === 2m end +@testset "LinearAlgebra functions" begin + CNT = Ref(0) + Unitful.linearalgebra_count() = (CNT[] += 1; nothing) + @testset "> Matrix multiplication: *" begin + M = rand(3,3) .* u"m" + M_ = view(M,:,1:3) + v = rand(3) .* u"V" + v_ = view(v, 1:3) + + CNT[] = 0 + + @test unit(first(M * M)) == u"m*m" + @test M * M == M_ * M == M * M_ == M_ * M_ + + @test unit(first(M * v)) == u"m*V" + @test M * v == M_ * v == M * v_ == M_ * v_ + + @test CNT[] == 10 + + @test unit(first(v' * M)) == u"m*V" + @test v' * M == v_' * M == v_' * M == v_' * M_ + + @test CNT[] == 15 + + @test unit(v' * v) == u"V*V" + @test v' * v == v_' * v == v_' * v == v_' * v_ + + @test CNT[] == 20 + end + @testset "> Matrix multiplication: mul!" begin + + end +end + @testset "Types" begin @test Base.complex(Quantity{Float64,NoDims,NoUnits}) == Quantity{Complex{Float64},NoDims,NoUnits} From 9515ad2d496e6bfac30c0cd1d885529256bca2f2 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Mon, 12 Apr 2021 23:07:11 -0400 Subject: [PATCH 3/6] skip tests on 1.0, and add a few --- test/runtests.jl | 32 +++++++++++++++++++++++++++++--- 1 file changed, 29 insertions(+), 3 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index f9967c58..1c7875b4 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -111,20 +111,46 @@ end @test unit(first(M * v)) == u"m*V" @test M * v == M_ * v == M * v_ == M_ * v_ - @test CNT[] == 10 + VERSION >= v"1.3" && @test CNT[] == 10 @test unit(first(v' * M)) == u"m*V" @test v' * M == v_' * M == v_' * M == v_' * M_ - @test CNT[] == 15 + VERSION >= v"1.3" && @test CNT[] == 15 @test unit(v' * v) == u"V*V" @test v' * v == v_' * v == v_' * v == v_' * v_ - @test CNT[] == 20 + VERSION >= v"1.3" && @test CNT[] == 20 + + # Mixed with & without units + N = rand(3,3) + w = rand(3) + + CNT[] = 0 + + @test unit(first(M * N)) == u"m" + @test unit(first(N * M)) == u"m" + + @test unit(first(M * w)) == u"m" + @test unit(first(N * v)) == u"V" + + @show CNT[] # not specialised yet + end @testset "> Matrix multiplication: mul!" begin + A = rand(3,3) .* u"m" + B = rand(3,3) .* u"m" + C = fill(zero(eltype(A*B)), 3, 3) + CNT[] = 0 + mul!(C, A, B) + if VERSION >= v"1.3" # the 5-arm mul! exists + mul!(C, A, B, true, true) + mul!(C, A, B, 3, 7) # not specialised yet + + @show CNT[] + end end end From 8432c616aff3f109c693889616ea3ddabf5ab0e2 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sat, 17 Apr 2021 00:38:33 -0400 Subject: [PATCH 4/6] extend to \, /, inv --- src/Unitful.jl | 4 +- src/linearalgebra.jl | 96 +++++++++++++++++++++++++++++++------------- src/utils.jl | 6 ++- test/runtests.jl | 2 +- 4 files changed, 77 insertions(+), 31 deletions(-) diff --git a/src/Unitful.jl b/src/Unitful.jl index 4da15cd7..0b9a42ce 100644 --- a/src/Unitful.jl +++ b/src/Unitful.jl @@ -21,8 +21,8 @@ import Base: steprange_last, unsigned end import Dates -import LinearAlgebra: Diagonal, Bidiagonal, Tridiagonal, SymTridiagonal -import LinearAlgebra: istril, istriu, norm +import LinearAlgebra: Diagonal, Bidiagonal, Tridiagonal, SymTridiagonal, Adjoint, Transpose, AdjOrTransAbsMat +import LinearAlgebra: istril, istriu, norm, mul!, dot, /, \, inv, pinv import Random import ConstructionBase: constructorof diff --git a/src/linearalgebra.jl b/src/linearalgebra.jl index fb5fd010..77433605 100644 --- a/src/linearalgebra.jl +++ b/src/linearalgebra.jl @@ -1,42 +1,84 @@ -using LinearAlgebra -# This function is re-defined during testing, to check we hit the fast path: -linearalgebra_count() = nothing - -function LinearAlgebra.mul!(C::StridedVecOrMat{<:AbstractQuantity{T}}, - A::StridedMatrix{<:AbstractQuantity{T}}, - B::StridedVecOrMat{<:AbstractQuantity{T}}, - alpha::Bool, beta::Bool) where {T<:Base.HWNumber} - # This is exactly how A * B creates C = similar(B, T, ...) - eltype(C) == Base.promote_op(LinearAlgebra.matprod, eltype(A), eltype(B)) || error("bad eltypes") - C0 = ustrip(C) - A0 = ustrip(A) - B0 = ustrip(B) - mul!(C0, A0, B0) - linearalgebra_count() - return C +# Multiplication + +function mul!(C::StridedVecOrMat{<:AbstractQuantity{T}}, + A::StridedMatrix{<:AbstractQuantity{T}}, + B::StridedVecOrMat{<:AbstractQuantity{T}}, + alpha::Number, beta::Number) where {T<:Base.HWNumber} + _mul!(C, A, B, alpha, beta) end -function LinearAlgebra.mul!(C::StridedVecOrMat{<:AbstractQuantity{T}}, - A::LinearAlgebra.AdjOrTransAbsMat{<:AbstractQuantity{T}, <:StridedMatrix}, - B::StridedVecOrMat{<:AbstractQuantity{T}}, - alpha::Bool, beta::Bool) where {T<:Base.HWNumber} +function mul!(C::StridedVecOrMat{<:AbstractQuantity{T}}, + A::AdjOrTransAbsMat{<:AbstractQuantity{T}, <:StridedMatrix}, + B::StridedVecOrMat{<:AbstractQuantity{T}}, + alpha::Number, beta::Number) where {T<:Base.HWNumber} + _mul!(C, A, B, alpha, beta) +end - eltype(C) == Base.promote_op(LinearAlgebra.matprod, eltype(A), eltype(B)) || error("bad eltypes") +function _mul!(C, A, B, alpha, beta) + if unit(beta) != NoUnits + throw(DimensionError("beta", 1.0)) + elseif unit(eltype(C)) != unit(eltype(A)) * unit(eltype(B)) * unit(alpha) + throw(DimensionError("A * B .* α", "C")) + end C0 = ustrip(C) - A0 = A isa Adjoint ? adjoint(ustrip(parent(A))) : transpose(ustrip(parent(A))) + A0 = ustrip(A) B0 = ustrip(B) mul!(C0, A0, B0) - linearalgebra_count() + _linearalgebra_count() return C end -function LinearAlgebra.dot(A::StridedArray{<:AbstractQuantity{T}}, - B::StridedArray{<:AbstractQuantity{T}}) where {T<:Base.HWNumber} +function dot(A::StridedArray{<:AbstractQuantity{T}}, + B::StridedArray{<:AbstractQuantity{T}}) where {T<:Base.HWNumber} A0 = ustrip(A) B0 = ustrip(B) C0 = dot(A0, B0) - linearalgebra_count() - C = C0 * oneunit(eltype(A)) * oneunit(eltype(B)) # surely there is an official way + _linearalgebra_count() + C = C0 * unit(eltype(A)) * unit(eltype(B)) return C end + +# Division + +function (\)(A::StridedMatrix{<:AbstractQuantity{T}}, + B::StridedVecOrMat{<:AbstractQuantity{T}}) where {T<:Base.HWNumber} + A0 = ustrip(A) + B0 = ustrip(B) + C0 = A0 \ B0 + _linearalgebra_count() + u = unit(eltype(B)) / unit(eltype(A)) + Tu = typeof(one(eltype(C0)) * u) + return reinterpret(Tu, C0) +end + +function (/)(A::StridedVecOrMat{<:AbstractQuantity{T}}, + B::StridedVecOrMat{<:AbstractQuantity{T}}) where {T<:Base.HWNumber} + A0 = ustrip(A) + B0 = ustrip(B) + C0 = A0 / B0 + _linearalgebra_count() + u = unit(eltype(A)) / unit(eltype(B)) + Tu = typeof(one(eltype(C0)) * u) + return reinterpret(Tu, C0) +end + +function inv(A::StridedMatrix{<:AbstractQuantity{T}}) where {T<:Base.HWNumber} + C0 = inv(ustrip(A)) + _linearalgebra_count() + u = inv(unit(eltype(A))) + Tu = typeof(one(eltype(C0)) * u) + return reinterpret(Tu, C0) +end + +function pinv(A::StridedMatrix{<:AbstractQuantity{T}}; kw...) where {T<:Base.HWNumber} + C0 = pinv(ustrip(A); kw...) + _linearalgebra_count() + u = inv(unit(eltype(A))) + Tu = typeof(one(eltype(C0)) * u) + return reinterpret(Tu, C0) +end + +# This function is re-defined during testing, to check we hit the fast path: +_linearalgebra_count() = nothing + diff --git a/src/utils.jl b/src/utils.jl index 36a25b07..0effccc0 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -50,7 +50,8 @@ true @inline ustrip(x::Missing) = missing """ - ustrip(x::Array{Q}) where {Q <: Quantity} + ustrip(x::Array{Q}) where {Q <: Quantity{T}}} + Strip units from an `Array` by reinterpreting to type `T`. The resulting `Array` is a not a copy, but rather a unit-stripped view into array `x`. Because the units are removed, information may be lost and this should be used with some care. @@ -91,6 +92,9 @@ ustrip(A::Bidiagonal) = Bidiagonal(ustrip(A.dv), ustrip(A.ev), ifelse(istriu(A), ustrip(A::Tridiagonal) = Tridiagonal(ustrip(A.dl), ustrip(A.d), ustrip(A.du)) ustrip(A::SymTridiagonal) = SymTridiagonal(ustrip(A.dv), ustrip(A.ev)) +ustrip(A::Adjoint) = adjoint(ustrip(parent(A))) +ustrip(A::Transpose) = transpose(ustrip(parent(A))) + """ unit(x::Quantity{T,D,U}) where {T,D,U} unit(x::Type{Quantity{T,D,U}}) where {T,D,U} diff --git a/test/runtests.jl b/test/runtests.jl index 1c7875b4..5ad85549 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -96,7 +96,7 @@ end @testset "LinearAlgebra functions" begin CNT = Ref(0) - Unitful.linearalgebra_count() = (CNT[] += 1; nothing) + Unitful._linearalgebra_count() = (CNT[] += 1; nothing) @testset "> Matrix multiplication: *" begin M = rand(3,3) .* u"m" M_ = view(M,:,1:3) From 9f11673338e989023fb6910e6c25ac728bd3821a Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sat, 17 Apr 2021 00:48:19 -0400 Subject: [PATCH 5/6] rm ambiguity printing --- test/runtests.jl | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 5ad85549..4e2824b9 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -41,16 +41,6 @@ using Dates: const colon = Base.:(:) -ambig = sort(detect_ambiguities(Unitful), by = a -> [string(a[1].name), string(a[2].module)]) -if length(ambig) > 0 - println(stdout, "detect_ambiguities(Unitful) found $(length(ambig)) issues:") - for i in 1:length(ambig) - println(stdout, "[",i, "]:") - println(stdout, " ", ambig[i][1]) - println(stdout, " ", ambig[i][2]) - end - println(stdout) -end @testset "Construction" begin @test isa(NoUnits, FreeUnits) @test typeof(𝐋) === Unitful.Dimensions{(Unitful.Dimension{:Length}(1),)} From c47347a2a027fc2dac3701d61f6d96769c8a5eba Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sat, 17 Apr 2021 16:33:35 -0400 Subject: [PATCH 6/6] two-arg ustrip for arrays, safely reinterpret when possible --- src/utils.jl | 29 ++++++++++++++++++++++++++++- test/runtests.jl | 13 +++++++++++++ 2 files changed, 41 insertions(+), 1 deletion(-) diff --git a/src/utils.jl b/src/utils.jl index 0effccc0..85b1bda7 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -28,6 +28,27 @@ true @inline ustrip(u::Units, x) = ustrip(uconvert(u, x)) @inline ustrip(T::Type, u::Units, x) = convert(T, ustrip(u, x)) +""" + ustrip(u::Units, xs::AbstractArray{<:Quantity}) + +This broadcasts `ustrip.(u, xs)`, unless `xs isa StridedArray` whose units match `u`, +in which case it reinterprets, which saves making a copy. + +```jldoctest +julia> ustrip(u"m", [1, 2, 3]u"m") isa Base.ReinterpretArray{Int} # fast path +true + +julia> ustrip(u"m", [1, 2, 3]u"mm") == [1//1000, 2//1000, 3//1000] # mismatch requires slow path +true +``` +""" +ustrip(u::Units, xs::AbstractArray) = ustrip.(u, xs) +function ustrip(u::Units, xs::StridedArray{T}) where {T} + dimension(u) == dimension(T) || return ustrip.(u, xs) + isequal(promote(true * u, oneunit(T))...) || return ustrip.(u, xs) + return reinterpret(numtype(T), xs) +end + """ ustrip(x::Number) ustrip(x::Quantity) @@ -91,10 +112,16 @@ ustrip(A::Diagonal) = Diagonal(ustrip(A.diag)) ustrip(A::Bidiagonal) = Bidiagonal(ustrip(A.dv), ustrip(A.ev), ifelse(istriu(A), :U, :L)) ustrip(A::Tridiagonal) = Tridiagonal(ustrip(A.dl), ustrip(A.d), ustrip(A.du)) ustrip(A::SymTridiagonal) = SymTridiagonal(ustrip(A.dv), ustrip(A.ev)) - ustrip(A::Adjoint) = adjoint(ustrip(parent(A))) ustrip(A::Transpose) = transpose(ustrip(parent(A))) +ustrip(u::Units, A::Diagonal) = Diagonal(ustrip(u, A.diag)) +ustrip(u::Units, A::Bidiagonal) = Bidiagonal(ustrip(u, A.dv), ustrip(u, A.ev), ifelse(istriu(A), :U, :L)) +ustrip(u::Units, A::Tridiagonal) = Tridiagonal(ustrip(u, A.dl), ustrip(u, A.d), ustrip(u, A.du)) +ustrip(u::Units, A::SymTridiagonal) = SymTridiagonal(ustrip(u, A.dv), ustrip(u, A.ev)) +ustrip(u::Units, A::Adjoint) = adjoint(ustrip(u, parent(A))) +ustrip(u::Units, A::Transpose) = transpose(ustrip(u, parent(A))) + """ unit(x::Quantity{T,D,U}) where {T,D,U} unit(x::Type{Quantity{T,D,U}}) where {T,D,U} diff --git a/test/runtests.jl b/test/runtests.jl index 4e2824b9..321e35c0 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1311,6 +1311,16 @@ end @test_deprecated ustrip([1,2]) @test ustrip.([1,2]) == [1,2] @test typeof(ustrip([1u"m", 2u"m"])) <: Base.ReinterpretArray{Int,1} + + # With target type + @test @inferred(ustrip(u"m", [1, 2]u"m")) == [1,2] + @test @inferred(ustrip(u"km", [1, 2]u"m")) == [1//1000, 2//1000] + @test typeof(ustrip(u"m", [1, 2]u"m")) <: Base.ReinterpretArray{Int,1} + @test typeof(ustrip(u"m/ms", [1, 2]*(u"km/s"))) <: Base.ReinterpretArray{Int,1} + + # Structured matrices + @test typeof(ustrip(adjoint([1,2]u"m"))) <: Adjoint{Int} + @test typeof(ustrip(transpose([1 2; 3 4]u"m"))) <: Transpose{Int} @test typeof(ustrip(Diagonal([1,2]u"m"))) <: Diagonal{Int} @test typeof(ustrip(Bidiagonal([1,2,3]u"m", [1,2]u"m", :U))) <: Bidiagonal{Int} @@ -1318,6 +1328,9 @@ end Tridiagonal{Int} @test typeof(ustrip(SymTridiagonal([1,2,3]u"m", [4,5]u"m"))) <: SymTridiagonal{Int} + + @test typeof(ustrip(u"m", adjoint([1,2]u"m"))) <: Adjoint{Int} + @test typeof(ustrip(u"m", Diagonal([1,2]u"m"))) <: Diagonal{Int} end @testset ">> Linear algebra" begin @test istril([1 1; 0 1]u"m") == false