Skip to content

Commit

Permalink
Reduce time spent on type inference - was 180s for a solve(prob) with…
Browse files Browse the repository at this point in the history
… mixed units.

modified:   Project.toml   0.2.2
modified:   README.md      Update on zero..
modified:   src/MechGlueDiffEqBase.jl Improve inferrability, move zero to Untifu.jl
modified:   test/runtests.jl  Add test_4
modified:   test/test_3.jl    Add test zero-inferrability
new file:   test/test_4.jl    Unit tests, inferrable where possible
  • Loading branch information
hustf committed May 10, 2021
1 parent 28cbd0c commit b26f560
Show file tree
Hide file tree
Showing 6 changed files with 111 additions and 21 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MechGlueDiffEqBase"
uuid = "2532746b-52b5-4539-9431-8bb183ab067f"
authors = ["hustf <[email protected]> and contributors"]
version = "0.2.1"
version = "0.2.2"

[deps]
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
Expand Down
7 changes: 2 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
# MechGlueDiffEqBase
Glue code for making [DiffEqBase](https://github.com/SciML/DiffEqBase.jl) work with units.

It also includes glue code for [RecursiveArrayTools](https://github.com/SciML/RecursiveArrayTools.jl), which enables type-stable solution of equations with mixed units. We define zero for mixed-dimension vectors.

This defines how to calculate the vector norm when the vector is given in units compatible with [Unitfu.jl](https://github.com/hustf/Unitfu.jl), from registy [M8](https://github.com/hustf/M8). The differential equation algorithms expects the norm to be unitless, as can be seen in e.g. step size estimators:

It also used to include glue code for [RecursiveArrayTools](https://github.com/SciML/RecursiveArrayTools.jl), which enables type-stable solution of equations with mixed units. This may not be needed after a change to Unitfu.jl v1.7.7, but the depencency is kept until further upstream testing.

err_scaled = **error** / (**abstol** + norm(u) * **reltol**)

where **bold** indicates unitful objects.

The functions are adaptions of corresponding code from [DiffEqBase](https://github.com/SciML/DiffEqBase.jl/blob/6bb8830711e729ef513f2b1beb95853e4a691375/src/init.jl).



27 changes: 18 additions & 9 deletions src/MechGlueDiffEqBase.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
module MechGlueDiffEqBase
import Unitfu: AbstractQuantity, Quantity, ustrip, norm
import Unitfu: AbstractQuantity, Quantity, ustrip, norm, unit
import DiffEqBase: value, ODE_DEFAULT_NORM, UNITLESS_ABS2, zero
import DiffEqBase: calculate_residuals, @muladd
using RecursiveArrayTools
export value, ODE_DEFAULT_NORM, UNITLESS_ABS2, Unitfu, AbstractQuantity, Quantity
export norm, ArrayPartition

Base.zero(A::ArrayPartition{<:AbstractQuantity{T},S}) where {T<:Number,S} = zero.(A)
export norm , ArrayPartition # Probably no longer necessary with changes in Unitfu 1.7.7. We could perhaps drop this depencency.

# This is identical to what DiffEqBase defines for Unitful
function value(x::Type{AbstractQuantity{T,D,U}}) where {T,D,U}
Expand All @@ -26,14 +24,25 @@ end
@inline function ODE_DEFAULT_NORM(u::Array{<:AbstractQuantity,N},t) where {N}
sqrt(sum(x->ODE_DEFAULT_NORM(x[1],x[2]),zip((value(x) for x in u),Iterators.repeated(t))) / length(u))
end

# This is identical to what DiffEqBase defines for Unitful

@inline function ODE_DEFAULT_NORM(u::AbstractQuantity, t)
abs(ustrip(u))
end
# This is slightly different from what DiffEqBase defines for Unitful
@inline function UNITLESS_ABS2(x::AbstractQuantity)
real(abs2(x) / (oneunit(x)^2))

@inline function UNITLESS_ABS2(u::AbstractArray{<:AbstractQuantity,N} where N)
map(UNITLESS_ABS2, u)
end
@inline function UNITLESS_ABS2(u::AbstractArray{Quantity{T},N}) where {N, T}
map(UNITLESS_ABS2, u)
end

@inline function UNITLESS_ABS2(x::T) where T <: AbstractQuantity
xul = x / oneunit(T)
abs2(xul)
end
@inline function UNITLESS_ABS2(x::Quantity{T, D, U}) where {T, D, U}
xul = x / oneunit(Quantity{T, D, U})
abs2(xul)::T
end

end
3 changes: 2 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using Test
include("test_1.jl")
include("test_2.jl")
include("test_3.jl")
include("test_3.jl")
include("test_4.jl")
18 changes: 13 additions & 5 deletions test/test_3.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@ using MechGlueDiffEqBase
using MechanicalUnits: @import_expand, dimension, NoDims,
import MechanicalUnits: g, g⁻¹
@import_expand(km, N, s, m, km, kg, °, inch)
using DiffEqBase, OrdinaryDiffEq # Unitfu,
using DiffEqBase, OrdinaryDiffEq
using OrdinaryDiffEq: OrdinaryDiffEqAdaptiveAlgorithm, OrdinaryDiffEqCompositeAlgorithm, DAEAlgorithm, FunctionMap,LinearExponential



@testset "Initial checks ArrayPartition" begin
r0 = [1131.340, -2282.343, 6672.423]km
Expand Down Expand Up @@ -57,9 +60,17 @@ end
print("\n")
end

@testset "zero of ArrayPartition with mixed units" begin
@testset "zero ArrayPartition" begin
@test zero.([0.0ms⁻¹, 0.0ms⁻¹, 909.3266739736605ms⁻², 525.0ms⁻²],) == [0.0ms⁻¹, 0.0ms⁻¹, 0.0ms⁻², 0.0ms⁻²]
@test zero(ArrayPartition([0.0ms⁻¹, 0.0ms⁻¹, 909.3266739736605ms⁻², 525.0ms⁻²],)) == [0.0ms⁻¹, 0.0ms⁻¹, 0.0ms⁻², 0.0ms⁻²]
@test @inferred(zero(ArrayPartition([0.0ms⁻¹, 0.0ms⁻¹, 909.3266739736605ms⁻², 525.0ms⁻²],))) == [0.0ms⁻¹, 0.0ms⁻¹, 0.0ms⁻², 0.0ms⁻²]
@test typeof(zero(ArrayPartition([0.0ms⁻¹, 0.0ms⁻¹, 909.3266739736605ms⁻², 525.0ms⁻²],))) ==
ArrayPartition{Quantity{Float64, D, U} where {D, U}, Tuple{Vector{Quantity{Float64, D, U} where {D, U}}}}
end



@testset "ArrayPartition with mixed units" begin
α₀() = 30°
x₀() = 0.0m
y₀() = 0.0m
Expand Down Expand Up @@ -97,9 +108,6 @@ end
AutoVern8(Rodas5(autodiff=false)),
AutoVern9(Rodas5(autodiff=false))])


using OrdinaryDiffEq: OrdinaryDiffEqAdaptiveAlgorithm, OrdinaryDiffEqCompositeAlgorithm, DAEAlgorithm, FunctionMap,LinearExponential

function requires_stepsize(alg)
adaptive = OrdinaryDiffEq.isadaptive(alg)
(((!(typeof(alg) <: OrdinaryDiffEqAdaptiveAlgorithm) && !(typeof(alg) <: OrdinaryDiffEqCompositeAlgorithm) && !(typeof(alg) <: DAEAlgorithm)) || !adaptive) ) && !(typeof(alg) <: Union{FunctionMap,LinearExponential})
Expand Down
75 changes: 75 additions & 0 deletions test/test_4.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# Check type stability of ODE_DEFAULT_NORM with mixed units and with ArrayPartition
using Test
using MechGlueDiffEqBase
import MechGlueDiffEqBase: zero_collection
using MechanicalUnits: @import_expand, dimension, NoDims,
import MechanicalUnits: g, g⁻¹
@import_expand(km, N, s, m, km, kg, °, inch)
using DiffEqBase, OrdinaryDiffEq

@testset "Inferrable norm Unitless ArrayPartition" begin
r0ul = [1131.340, -2282.343, 6672.423]
v0ul = [-5.64305, 4.30333, 2.42879]
rv0ul = @inferred ArrayPartition(r0ul,v0ul)
@test @inferred(ODE_DEFAULT_NORM(rv0ul, 0.0)) === 2915.770473301504
end
@testset "Inferrable norm ArrayPartition units" begin
r0 = [1131.340, -2282.343, 6672.423]km
v0 = [-5.64305, 4.30333, 2.42879]km/s
rv0 = @inferred ArrayPartition(r0, v0)
@test @inferred(ODE_DEFAULT_NORM(rv0, 0.0s)) === 41.02536038842316
end
@testset "Inferrable norm ArrayPartition mixed units" begin
r0 = [1.0km, 2.0km, 3.0m/s, 4.0m/s]
v0 = [1.0km/s, 2.0km/s, 3.0m/s², 4m/s²]
rv0 = @inferred ArrayPartition(r0, v0)
@test @inferred(ODE_DEFAULT_NORM(rv0, 0.0s)) === 1.5811388300841898
end


@testset "Inferrable zero Unitless ArrayPartition" begin
r0ul = [1131.340, -2282.343, 6672.423]
v0ul = [-5.64305, 4.30333, 2.42879]
rv0ul = ArrayPartition(r0ul, v0ul)
@test @inferred(zero(rv0ul)) == ArrayPartition([0.0, 0.0, 0.0], [0.0, 0.0, 0.0])
@test ArrayPartition{Float64, Tuple{Vector{Float64}, Vector{Float64}}} == typeof(zero(rv0ul))
end
@testset "Inferrable zero compatible units" begin
r0 = [1131.340, -2282.343, 6672.423]km
r1 = [1km, 2.0m]
@test @inferred(zero(r0)) == [0.0, 0.0, 0.0]km
@test @inferred(zero(r1)) == [0.0, 0.0]km
rv0 = ArrayPartition(r0)
@test @inferred(zero(rv0)) == [0.0, 0.0, 0.0]km
@test typeof(zero(rv0)) == typeof(rv0)
rv1 = ArrayPartition(r1)
@test @inferred(zero(rv1)) == [0.0, 0.0]km
@test typeof(zero(rv0)) == typeof(rv0)
end
@testset "(Inferrable) zero ArrayPartition mixed units" begin
r0 = [1.0km, 2km, 3m/s, 4m/s]
v0 = [1.0km/s, 2.0km/s, 3.0m/s², 4m/s²]
# Not inferrable, but inferred return type is now AbstractVector{var"#s831"} where var"#s831", not Any
@test zero(r0) == [0.0km, 0.0km, 0.0m/s, 0.0m/s]
# ArrayPartition fixes that
rv0 = @inferred ArrayPartition(r0)
zer = @inferred zero(rv0)
@test zer == [0.0km, 0.0km, 0.0m/s, 0.0m/s]
@test typeof(zer) === typeof(rv0)
end

@testset "Inferrable UNITLESS_ABS2 Unitless ArrayPartition" begin
r0ul = [1131.340, -2282.343, 6672.423]
v0ul = [-5.64305, 4.30333, 2.42879]
rv0ul = ArrayPartition(r0ul,v0ul)
@test @inferred(UNITLESS_ABS2(rv0ul)) == 5.101030471786125e7
end

@testset "Inferrable UNITLESS_ABS2 ArrayPartition mixed units" begin
r0 = [1.0km, 2.0km, 3.0m/s, 4.0m/s]
v0 = [1.0km/s, 2.0km/s, 3.0m/s², 4m/s²]
rv0 = ArrayPartition(r0, v0)
@test @inferred(UNITLESS_ABS2(1.0km)) == 1.0
@test @inferred(UNITLESS_ABS2(r0)) == [1.0, 4.0, 9.0, 16.0]
@test @inferred(UNITLESS_ABS2(rv0)) == 2 .* [1.0, 4.0, 9.0, 16.0]
end

0 comments on commit b26f560

Please sign in to comment.