From 367fa02bb809d4ae815f69e80325d17b4f4c183a Mon Sep 17 00:00:00 2001 From: Bagaev Dmitry Date: Tue, 14 Nov 2023 14:08:29 +0100 Subject: [PATCH 1/4] improve perf of resizable_array --- src/resizable_array.jl | 22 ++++++++++++ test/resizable_array_tests.jl | 66 +++++++++++++++++++++++++++++++++++ 2 files changed, 88 insertions(+) diff --git a/src/resizable_array.jl b/src/resizable_array.jl index d2b2aed5..309f3c4c 100644 --- a/src/resizable_array.jl +++ b/src/resizable_array.jl @@ -90,6 +90,28 @@ function recursive_setindex!( return nothing end +function Base.isassigned(array::ResizableArray{T,V,N}, index::Integer...) where {T, V, N} + if length(index) !== N + return false + else + return recursive_isassigned(Val(N), array.data, index) + end +end + +function recursive_isassigned(::Val{N}, array, indices) where {N} + findex = first(indices) + tindices = Base.tail(indices) + if isassigned(array, findex) + return recursive_isassigned(Val(N - 1), @inbounds(array[findex]), tindices) + else + return false + end +end + +function recursive_isassigned(::Val{1}, array, index::Tuple{Integer}) + return isassigned(array, first(index)) +end + function getindex(array::ResizableArray{T,V,N}, index::UnitRange) where {T,V,N} return ResizableArray(array.data[index]) end diff --git a/test/resizable_array_tests.jl b/test/resizable_array_tests.jl index 523012f0..3f7a2b66 100644 --- a/test/resizable_array_tests.jl +++ b/test/resizable_array_tests.jl @@ -191,3 +191,69 @@ end end + +@testitem "isassigned" begin + import GraphPPL: ResizableArray + + @testset begin + s = ResizableArray(Ref, Val(1)) + + # In the beginning everything is not assigned + for i in 1:100 + @test !isassigned(s, i) + end + + # Assign some random indices + rindex = rand(1:100, 10) + + for i in rindex + s[i] = Ref(1) + end + + for i in 1:100 + if i ∉ rindex + @test !isassigned(s, i) + else + @test isassigned(s, i) + end + end + end + + @testset begin + for N in 1:5 + s = ResizableArray(Ref, Val(N)) + + for j in 1:N + @test !isassigned(s, ones(Int, j)...) + end + + s[ones(Int, N)...] = Ref(1) + + @test isassigned(s, ones(Int, N)...) + + s[10ones(Int, N)...] = Ref(1) + + @test isassigned(s, 10ones(Int, N)...) + + for k in 2:9 + @test !isassigned(s, k * ones(Int, N)...) + end + + end + end + + @testset begin + for N in 1:5, M in 1:5 + s = ResizableArray(Ref, Val(N)) + indices = CartesianIndex(ones(Int, N)...):CartesianIndex(M * ones(Int, N)...) + + for index in indices + @test !isassigned(s, index.I...) + s[index.I...] = Ref(1) + @test isassigned(s, index.I...) + end + + end + end + +end \ No newline at end of file From 0ec72cde375b44b745ea23052507cc6616aa496e Mon Sep 17 00:00:00 2001 From: Bagaev Dmitry Date: Tue, 14 Nov 2023 14:52:33 +0100 Subject: [PATCH 2/4] attempt to fix missing_interfaces performance --- ext/GraphPPLDistributionsExt.jl | 32 +++++++++++++++++++++++++++----- src/graph_engine.jl | 31 +++++++++++++++++++------------ 2 files changed, 46 insertions(+), 17 deletions(-) diff --git a/ext/GraphPPLDistributionsExt.jl b/ext/GraphPPLDistributionsExt.jl index 5f04f8d4..31a24fc7 100644 --- a/ext/GraphPPLDistributionsExt.jl +++ b/ext/GraphPPLDistributionsExt.jl @@ -1,12 +1,34 @@ module GraphPPLDistributionsExt -using GraphPPL, Distributions + +using GraphPPL, Distributions, Static GraphPPL.NodeBehaviour(::Type{<:Distributions.Distribution}) = GraphPPL.Stochastic() + function GraphPPL.default_parametrization(::GraphPPL.Atomic, t::Type{<:Distributions.Distribution}, interface_values) - field_names = fieldnames(t) - @assert length(interface_values) == length(field_names) "Distribution $t has $(length(field_names)) fields $(field_names) but $(length(interface_values)) values were provided." - return NamedTuple{fieldnames(t)}(interface_values) + return distributions_ext_default_parametrization(t, distributions_ext_input_interfaces(t), interface_values) +end + +function distributions_ext_default_parametrization(t::Type{<:Distributions.Distribution}, ::GraphPPL.StaticInterfaces{interfaces}, interface_values) where {interfaces} + @assert length(interface_values) == length(interfaces) "Distribution $t has $(length(interfaces)) fields $(interfaces) but $(length(interface_values)) values were provided." + return NamedTuple{interfaces}(interface_values) +end + +function GraphPPL.interfaces(T::Type{<:Distributions.Distribution}, _) + return distributions_ext_interfaces(T) +end + +@generated function distributions_ext_input_interfaces(::Type{T}) where {T} + fnames = fieldnames(T) + return quote + GraphPPL.StaticInterfaces(($(map(QuoteNode, fnames)...), )) + end +end + +@generated function distributions_ext_interfaces(::Type{T}) where {T} + fnames = fieldnames(T) + return quote + GraphPPL.StaticInterfaces((:out, $(map(QuoteNode, fnames)...))) + end end -GraphPPL.interfaces(t::Type{<:Distributions.Distribution}, val) = (:out, fieldnames(t)...) end \ No newline at end of file diff --git a/src/graph_engine.jl b/src/graph_engine.jl index 21032614..e8e852ed 100644 --- a/src/graph_engine.jl +++ b/src/graph_engine.jl @@ -1013,6 +1013,13 @@ struct MixedArguments kwargs::NamedTuple end +""" +A structure that holds interfaces of a node in the type argument `I`. Used for dispatch. +""" +struct StaticInterfaces{I} end + +StaticInterfaces(I::Tuple) = StaticInterfaces{I}() + """ Placeholder function that is defined for all Composite nodes and is invoked when inferring what interfaces are missing when a node is called """ @@ -1032,22 +1039,22 @@ Returns the interfaces that are missing for a node. This is used when inferring # Returns - `missing_interfaces`: A `Vector` of the missing interfaces. """ -function missing_interfaces(node_type, val::StaticInt{N} where {N}, known_interfaces) - all_interfaces = GraphPPL.interfaces(node_type, val) - missing_interfaces = Base.setdiff(all_interfaces, keys(known_interfaces)) - return missing_interfaces +function missing_interfaces(fform, val, known_interfaces) + return missing_interfaces(interfaces(fform, val), StaticInterfaces(keys(known_interfaces))) end +function missing_interfaces(::StaticInterfaces{all_interfaces}, ::StaticInterfaces{present_interfaces}) where {all_interfaces, present_interfaces} + return StaticInterfaces(filter(interface -> interface ∉ present_interfaces, all_interfaces)) +end function prepare_interfaces(fform, lhs_interface, rhs_interfaces::NamedTuple) - missing_interface = GraphPPL.missing_interfaces( - fform, - static(length(rhs_interfaces) + 1), - rhs_interfaces, - ) - @assert length(missing_interface) == 1 lazy"Expected only one missing interface, got $missing_interface of length $(length(missing_interface)) (node $fform with interfaces $(keys(rhs_interfaces)))))" - missing_interface = first(missing_interface) - # TODO check if we can construct NamedTuples a bit faster somewhere. + missing_interface = missing_interfaces(fform, static(length(rhs_interfaces)) + static(1), rhs_interfaces) + return prepare_interfaces(missing_interface, lhs_interface, rhs_interfaces) +end + +function prepare_interfaces(::StaticInterfaces{I}, lhs_interface, rhs_interfaces::NamedTuple) where {I} + @assert length(I) == 1 lazy"Expected only one missing interface, got $missing_interface of length $(length(missing_interface)) (node $fform with interfaces $(keys(rhs_interfaces)))))" + missing_interface = first(I) return NamedTuple{(missing_interface, keys(rhs_interfaces)...)}(( lhs_interface, values(rhs_interfaces)..., From e4ca25540727abc8bf64c4f602097505fa4c2481 Mon Sep 17 00:00:00 2001 From: Bagaev Dmitry Date: Tue, 14 Nov 2023 15:17:39 +0100 Subject: [PATCH 3/4] fix 2prev --- src/graph_engine.jl | 6 +++--- src/model_macro.jl | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/graph_engine.jl b/src/graph_engine.jl index e8e852ed..b6441a34 100644 --- a/src/graph_engine.jl +++ b/src/graph_engine.jl @@ -1023,8 +1023,8 @@ StaticInterfaces(I::Tuple) = StaticInterfaces{I}() """ Placeholder function that is defined for all Composite nodes and is invoked when inferring what interfaces are missing when a node is called """ -interfaces(any_f, ::StaticInt{1}) = (:out,) -interfaces(any_f, any_val) = (:out, :in) +interfaces(any_f, ::StaticInt{1}) = StaticInterfaces((:out,)) +interfaces(any_f, any_val) = StaticInterfaces((:out, :in)) """ missing_interfaces(node_type, val, known_interfaces) @@ -1053,7 +1053,7 @@ function prepare_interfaces(fform, lhs_interface, rhs_interfaces::NamedTuple) end function prepare_interfaces(::StaticInterfaces{I}, lhs_interface, rhs_interfaces::NamedTuple) where {I} - @assert length(I) == 1 lazy"Expected only one missing interface, got $missing_interface of length $(length(missing_interface)) (node $fform with interfaces $(keys(rhs_interfaces)))))" + @assert length(I) == 1 lazy"Expected only one missing interface, got $I of length $(length(I)) (node $fform with interfaces $(keys(rhs_interfaces)))))" missing_interface = first(I) return NamedTuple{(missing_interface, keys(rhs_interfaces)...)}(( lhs_interface, diff --git a/src/model_macro.jl b/src/model_macro.jl index 354d2674..e7f2e9ed 100644 --- a/src/model_macro.jl +++ b/src/model_macro.jl @@ -731,7 +731,7 @@ function get_boilerplate_functions(ms_name, ms_args, num_interfaces) function $ms_name end GraphPPL.interfaces(::typeof($ms_name), val) = error($error_msg * " $val keywords") GraphPPL.interfaces(::typeof($ms_name), ::GraphPPL.StaticInt{$num_interfaces}) = - Tuple($ms_args) + GraphPPL.StaticInterfaces(Tuple($ms_args)) GraphPPL.NodeType(::typeof($ms_name)) = GraphPPL.Composite() GraphPPL.NodeBehaviour(::typeof($ms_name)) = GraphPPL.Stochastic() end From 30a6988f48fc98b7d278371794bf86ce57ae4954 Mon Sep 17 00:00:00 2001 From: Wouter Nuijten Date: Tue, 14 Nov 2023 15:38:31 +0100 Subject: [PATCH 4/4] Fix test --- src/graph_engine.jl | 26 ++++++++++++++------------ test/model_macro_tests.jl | 20 ++++++++++---------- test/model_zoo.jl | 8 ++++---- 3 files changed, 28 insertions(+), 26 deletions(-) diff --git a/src/graph_engine.jl b/src/graph_engine.jl index b6441a34..b35bc08b 100644 --- a/src/graph_engine.jl +++ b/src/graph_engine.jl @@ -206,6 +206,15 @@ factor_nodes(model::Model) = Iterators.filter(node -> is_factor(model[node]), la variable_nodes(model::Model) = Iterators.filter(node -> is_variable(model[node]), labels(model)) +""" +A structure that holds interfaces of a node in the type argument `I`. Used for dispatch. +""" +struct StaticInterfaces{I} end + +StaticInterfaces(I::Tuple) = StaticInterfaces{I}() + + + struct ProxyLabel{T} name::Symbol index::T @@ -294,16 +303,16 @@ Graphs.edges(model::Model) = collect(Graphs.edges(model.graph)) MetaGraphsNext.label_for(model::Model, node_id::Int) = MetaGraphsNext.label_for(model.graph, node_id) -function retrieve_interface_position(interfaces, x::EdgeLabel, max_length::Int) +function retrieve_interface_position(interfaces::StaticInterfaces{I}, x::EdgeLabel, max_length::Int) where {I} index = x.index === nothing ? 0 : x.index - position = findfirst(isequal(x.name), interfaces) + position = findfirst(isequal(x.name), I) position = position === nothing ? begin - @warn(lazy"Interface $(x.name) not found in $interfaces") + @warn(lazy"Interface $(x.name) not found in $I") 0 end : position - return max_length * findfirst(isequal(x.name), interfaces) + index + return max_length * findfirst(isequal(x.name), I) + index end function __sortperm(model::Model, node::NodeLabel, edges::AbstractArray) @@ -1013,13 +1022,6 @@ struct MixedArguments kwargs::NamedTuple end -""" -A structure that holds interfaces of a node in the type argument `I`. Used for dispatch. -""" -struct StaticInterfaces{I} end - -StaticInterfaces(I::Tuple) = StaticInterfaces{I}() - """ Placeholder function that is defined for all Composite nodes and is invoked when inferring what interfaces are missing when a node is called """ @@ -1039,7 +1041,7 @@ Returns the interfaces that are missing for a node. This is used when inferring # Returns - `missing_interfaces`: A `Vector` of the missing interfaces. """ -function missing_interfaces(fform, val, known_interfaces) +function missing_interfaces(fform, val, known_interfaces::NamedTuple) return missing_interfaces(interfaces(fform, val), StaticInterfaces(keys(known_interfaces))) end diff --git a/test/model_macro_tests.jl b/test/model_macro_tests.jl index 8f3855df..e6f4e62b 100644 --- a/test/model_macro_tests.jl +++ b/test/model_macro_tests.jl @@ -1362,25 +1362,25 @@ end import GraphPPL: missing_interfaces, interfaces function abc end - GraphPPL.interfaces(::typeof(abc), ::StaticInt{3}) = [:in1, :in2, :out] + GraphPPL.interfaces(::typeof(abc), ::StaticInt{3}) = GraphPPL.StaticInterfaces((:in1, :in2, :out)) - @test missing_interfaces(abc, static(3), (in1 = :x, in2 = :y)) == [:out] - @test missing_interfaces(abc, static(3), (out = :y,)) == [:in1, :in2] - @test missing_interfaces(abc, static(3), Dict()) == [:in1, :in2, :out] + @test missing_interfaces(abc, static(3), (in1 = :x, in2 = :y)) == GraphPPL.StaticInterfaces((:out,)) + @test missing_interfaces(abc, static(3), (out = :y,)) == GraphPPL.StaticInterfaces((:in1, :in2)) + @test missing_interfaces(abc, static(3), NamedTuple()) == GraphPPL.StaticInterfaces((:in1, :in2, :out)) function xyz end - GraphPPL.interfaces(::typeof(xyz), ::StaticInt{0}) = [] - @test missing_interfaces(xyz, static(0), (in1 = :x, in2 = :y)) == [] + GraphPPL.interfaces(::typeof(xyz), ::StaticInt{0}) = GraphPPL.StaticInterfaces(()) + @test missing_interfaces(xyz, static(0), (in1 = :x, in2 = :y)) == GraphPPL.StaticInterfaces(()) function foo end - GraphPPL.interfaces(::typeof(foo), ::StaticInt{2}) = (:a, :b) - @test missing_interfaces(foo, static(2), (a = 1, b = 2)) == [] + GraphPPL.interfaces(::typeof(foo), ::StaticInt{2}) = GraphPPL.StaticInterfaces((:a, :b)) + @test missing_interfaces(foo, static(2), (a = 1, b = 2)) == GraphPPL.StaticInterfaces(()) function bar end - GraphPPL.interfaces(::typeof(bar), ::StaticInt{2}) = (:in1, :in2, :out) - @test missing_interfaces(bar, static(2), (in1 = 1, in2 = 2, out = 3, test = 4)) == [] + GraphPPL.interfaces(::typeof(bar), ::StaticInt{2}) = GraphPPL.StaticInterfaces((:in1, :in2, :out)) + @test missing_interfaces(bar, static(2), (in1 = 1, in2 = 2, out = 3, test = 4)) == GraphPPL.StaticInterfaces(()) end @testitem "keyword_expressions_to_named_tuple" begin diff --git a/test/model_zoo.jl b/test/model_zoo.jl index fad3ff26..01d8d4a3 100644 --- a/test/model_zoo.jl +++ b/test/model_zoo.jl @@ -26,8 +26,8 @@ GraphPPL.NodeBehaviour(::Type{NormalMeanPrecision}) = GraphPPL.Stochastic() GraphPPL.aliases(::Type{Normal}) = (Normal, NormalMeanVariance, NormalMeanPrecision) -GraphPPL.interfaces(::Type{NormalMeanVariance}, ::StaticInt{3}) = (:out, :μ, :σ) -GraphPPL.interfaces(::Type{NormalMeanPrecision}, ::StaticInt{3}) = (:out, :μ, :τ) +GraphPPL.interfaces(::Type{NormalMeanVariance}, ::StaticInt{3}) = GraphPPL.StaticInterfaces((:out, :μ, :σ)) +GraphPPL.interfaces(::Type{NormalMeanPrecision}, ::StaticInt{3}) = GraphPPL.StaticInterfaces((:out, :μ, :τ)) GraphPPL.factor_alias(::Type{Normal}, ::Val{(:μ, :σ)}) = NormalMeanVariance GraphPPL.factor_alias(::Type{Normal}, ::Val{(:μ, :τ)}) = NormalMeanPrecision @@ -36,8 +36,8 @@ struct GammaShapeScale end GraphPPL.aliases(::Type{Gamma}) = (Gamma, GammaShapeRate, GammaShapeScale) -GraphPPL.interfaces(::Type{GammaShapeRate}, ::StaticInt{3}) = (:out, :α, :β) -GraphPPL.interfaces(::Type{GammaShapeScale}, ::StaticInt{3}) = (:out, :α, :θ) +GraphPPL.interfaces(::Type{GammaShapeRate}, ::StaticInt{3}) = GraphPPL.StaticInterfaces((:out, :α, :β)) +GraphPPL.interfaces(::Type{GammaShapeScale}, ::StaticInt{3}) = GraphPPL.StaticInterfaces((:out, :α, :θ)) GraphPPL.factor_alias(::Type{Gamma}, ::Val{(:α, :β)}) = GammaShapeRate GraphPPL.factor_alias(::Type{Gamma}, ::Val{(:α, :θ)}) = GammaShapeScale