diff --git a/benchmark/model_creation.jl b/benchmark/model_creation.jl index 61784278..3574ae8b 100644 --- a/benchmark/model_creation.jl +++ b/benchmark/model_creation.jl @@ -9,5 +9,9 @@ function benchmark_model_creation() for i in 10 .^ range(1, stop=3) SUITE["create HGF of depth $i"] = @benchmarkable create_hgf($i) end + for i in 10 .^ range(2, stop=6) + n_nodes = 10^i + SUITE["create model with array of length $i"] = @benchmarkable create_longarray($n_nodes) + end return SUITE end \ No newline at end of file diff --git a/src/graph_engine.jl b/src/graph_engine.jl index 497ea5b5..e40bcb44 100644 --- a/src/graph_engine.jl +++ b/src/graph_engine.jl @@ -70,6 +70,20 @@ to_symbol(label::NodeLabel) = Symbol(String(label.name) * "_" * string(label.glo Base.show(io::IO, label::NodeLabel) = print(io, label.name, "_", label.global_counter) +struct EdgeLabel + name::Symbol + index::Union{Int, Nothing} +end + +getname(label::EdgeLabel) = label.name +getname(labels::Tuple) = map(group -> getname(group), labels) + +to_symbol(label::EdgeLabel) = to_symbol(label, label.index) +to_symbol(label::EdgeLabel, ::Nothing) = label.name +to_symbol(label::EdgeLabel, ::Int64) = Symbol(string(label.name) * "[" * string(label.index) * "]") + +Base.show(io::IO, label::EdgeLabel) = print(io, to_symbol(label)) + mutable struct VariableNodeOptions value::Any functional_form::Any @@ -105,7 +119,7 @@ meta(options::VariableNodeOptions) = options.meta Data associated with a variable node in a probabilistic graphical model. """ -struct VariableNodeData +mutable struct VariableNodeData name::Symbol options::VariableNodeOptions index::Any @@ -167,6 +181,7 @@ mutable struct FactorNodeData context::Any factorization_constraint::Any options::FactorNodeOptions + neighbors::NTuple{N, Tuple{NodeLabel, EdgeLabel}} where {N} end fform(node::FactorNodeData) = node.fform @@ -226,20 +241,6 @@ Base.last(label::ProxyLabel) = last(label.proxied, label) Base.last(proxied::ProxyLabel, ::ProxyLabel) = last(proxied) Base.last(proxied, ::ProxyLabel) = proxied -struct EdgeLabel - name::Symbol - index::Union{Int, Nothing} -end - -getname(label::EdgeLabel) = label.name -getname(labels::Tuple) = map(group -> getname(group), labels) - -to_symbol(label::EdgeLabel) = to_symbol(label, label.index) -to_symbol(label::EdgeLabel, ::Nothing) = label.name -to_symbol(label::EdgeLabel, ::Int64) = Symbol(string(label.name) * "[" * string(label.index) * "]") - -Base.show(io::IO, label::EdgeLabel) = print(io, to_symbol(label)) - Model(graph::MetaGraph) = Model(graph, Base.RefValue(0)) Base.setindex!(model::Model, val::NodeData, key::NodeLabel) = Base.setindex!(model.graph, val, key) @@ -247,7 +248,8 @@ Base.setindex!(model::Model, val::EdgeLabel, src::NodeLabel, dst::NodeLabel) = B Base.getindex(model::Model) = Base.getindex(model.graph) Base.getindex(model::Model, key::NodeLabel) = Base.getindex(model.graph, key) Base.getindex(model::Model, src::NodeLabel, dst::NodeLabel) = Base.getindex(model.graph, src, dst) -Base.getindex(model::Model, keys::AbstractArray{NodeLabel}) = [model[key] for key in keys] +Base.getindex(model::Model, keys::AbstractArray{NodeLabel}) = map(key -> model[key], keys) +Base.getindex(model::Model, keys::NTuple{N, NodeLabel}) where {N} = collect(map(key -> model[key], keys)) Base.getindex(model::Model, keys::Base.Generator) = [model[key] for key in keys] @@ -272,59 +274,18 @@ increase_count(model::Model) = Base.setproperty!(model, :counter, model.counter Graphs.nv(model::Model) = Graphs.nv(model.graph) Graphs.ne(model::Model) = Graphs.ne(model.graph) Graphs.edges(model::Model) = Graphs.edges(model.graph) -MetaGraphsNext.label_for(model::Model, node_id::Int) = MetaGraphsNext.label_for(model.graph, node_id) - -function retrieve_interface_position(interfaces::StaticInterfaces{I}, x::EdgeLabel, max_length::Int) where {I} - index = x.index === nothing ? 0 : x.index - position = findfirst(isequal(x.name), I) - position = position === nothing ? begin - @warn(lazy"Interface $(x.name) not found in $I") - 0 - end : position - return max_length * findfirst(isequal(x.name), I) + index -end - -function __sortperm(model::Model, node::NodeLabel, edges::AbstractArray) - fform = model[node].fform - indices = [e.index for e in edges] - names = unique([e.name for e in edges]) - interfaces = GraphPPL.interfaces(fform, static(length(names))) - max_length = any(x -> x !== nothing, indices) ? maximum(indices[indices .!= nothing]) : 1 - perm = sortperm(edges, by = (x -> retrieve_interface_position(interfaces, x, max_length))) - return perm -end - -__get_neighbors(model::Model, node::NodeLabel) = Iterators.map(neighbor -> label_for(model, neighbor), MetaGraphsNext.neighbors(model.graph, code_for(model.graph, node))) - -__neighbors(model::Model, node::NodeLabel; sorted = false) = __neighbors(model, node, model[node]; sorted = sorted) -__neighbors(model::Model, node::NodeLabel, node_data::VariableNodeData; sorted = false) = __get_neighbors(model, node) -__neighbors(model::Model, node::NodeLabel, node_data::FactorNodeData; sorted = false) = __neighbors(model, node, static(sorted)) - -__neighbors(model::Model, node::NodeLabel, ::False) = __get_neighbors(model, node) -function __neighbors(model::Model, node::NodeLabel, ::True) - neighbors = collect(__get_neighbors(model, node)) - edges = __get_edges(model, node, neighbors) - perm = __sortperm(model, node, edges) - return neighbors[perm] -end - -Graphs.neighbors(model::Model, node::NodeLabel; sorted = false) = __neighbors(model, node; sorted = sorted) -Graphs.neighbors(model::Model, nodes::AbstractArray; sorted = false) = reduce(union, Graphs.neighbors.(Ref(model), nodes; sorted = sorted)) -Graphs.vertices(model::Model) = MetaGraphsNext.vertices(model.graph) -MetaGraphsNext.labels(model::Model) = MetaGraphsNext.labels(model.graph) +Graphs.neighbors(model::Model, node::NodeLabel) = Graphs.neighbors(model, node, model[node]) +Graphs.neighbors(model::Model, node::NodeLabel, nodedata::FactorNodeData) = map(neighbor -> neighbor[1], nodedata.neighbors) +Graphs.neighbors(model::Model, node::NodeLabel, nodedata::VariableNodeData) = MetaGraphsNext.neighbor_labels(model.graph, node) +Graphs.neighbors(model::Model, nodes::AbstractArray{<:NodeLabel}) = Iterators.flatten(map(node -> Graphs.neighbors(model, node), nodes)) -__get_edges(model::Model, node::NodeLabel, neighbors) = getindex.(Ref(model), Ref(node), neighbors) -__edges(model::Model, node::NodeLabel, node_data::VariableNodeData; sorted = false) = __get_edges(model, node, __get_neighbors(model, node)) -__edges(model::Model, node::NodeLabel, node_data::FactorNodeData; sorted = false) = __edges(model, node, static(sorted)) -__edges(model::Model, node::NodeLabel, ::False) = __get_edges(model, node, __get_neighbors(model, node)) -function __edges(model::Model, node::NodeLabel, ::True) - neighbors = __get_neighbors(model, node) - edges = __get_edges(model, node, neighbors) - perm = __sortperm(model, node, edges) - return edges[perm] +Graphs.edges(model::Model, node::NodeLabel) = Graphs.edges(model, node, model[node]) +Graphs.edges(model::Model, node::NodeLabel, nodedata::FactorNodeData) = map(neighbor -> neighbor[2], nodedata.neighbors) +function Graphs.edges(model::Model, node::NodeLabel, nodedata::VariableNodeData) + return Tuple(model[node, dst] for dst in MetaGraphsNext.neighbor_labels(model.graph, node)) end -Graphs.edges(model::Model, node::NodeLabel; sorted = false) = __edges(model, node, model[node]; sorted = sorted) +Graphs.edges(model::Model, nodes::AbstractArray{<:NodeLabel}) = Tuple(Iterators.flatten(map(node -> Graphs.edges(model, node), nodes))) abstract type AbstractModelFilterPredicate end @@ -777,7 +738,7 @@ Returns: function add_atomic_factor_node!(model::Model, context::Context, fform; __options__ = FactorNodeOptions()) factornode_id = generate_factor_nodelabel(context, fform) factornode_label = generate_nodelabel(model, fform) - model[factornode_label] = FactorNodeData(fform, context, nothing, __options__) + model[factornode_label] = FactorNodeData(fform, context, nothing, __options__, ()) context.factor_nodes[factornode_id] = factornode_label return factornode_label end @@ -811,7 +772,9 @@ end iterator(interfaces::NamedTuple) = zip(keys(interfaces), values(interfaces)) function add_edge!(model::Model, factor_node_id::NodeLabel, variable_node_id::Union{ProxyLabel, NodeLabel}, interface_name::Symbol; index = nothing) - model.graph[unroll(variable_node_id), factor_node_id] = EdgeLabel(interface_name, index) + label = EdgeLabel(interface_name, index) + model[factor_node_id].neighbors = (model[factor_node_id].neighbors..., (unroll(variable_node_id), label)) + model.graph[unroll(variable_node_id), factor_node_id] = label end function add_edge!(model::Model, factor_node_id::NodeLabel, variable_nodes::Union{AbstractArray, Tuple, NamedTuple}, interface_name::Symbol; index = 1) @@ -824,7 +787,7 @@ increase_index(any) = 1 increase_index(x::AbstractArray) = length(x) function add_factorization_constraint!(model::Model, factor_node_id::NodeLabel) - out_degree = outdegree(model.graph, code_for(model.graph, factor_node_id)) + out_degree = length(model[factor_node_id].neighbors) constraint = BitSetTuple(out_degree) set_factorization_constraint!(model[factor_node_id], constraint) end diff --git a/test/constraints_engine_tests.jl b/test/constraints_engine_tests.jl index 4a216810..024bc480 100644 --- a/test/constraints_engine_tests.jl +++ b/test/constraints_engine_tests.jl @@ -503,7 +503,7 @@ end ctx = GraphPPL.getcontext(model) node = ctx[NormalMeanVariance, 2] materialize_constraints!(model, node) - @test get_constraint_names(factorization_constraint(model[node])) == ((:μ, :σ, :out),) + @test get_constraint_names(factorization_constraint(model[node])) == ((:out, :μ, :σ),) materialize_constraints!(model, ctx[NormalMeanVariance, 1]) @test get_constraint_names(factorization_constraint(model[ctx[NormalMeanVariance, 1]])) == ((:out,), (:μ,), (:σ,)) @@ -513,7 +513,7 @@ end node = ctx[NormalMeanVariance, 2] GraphPPL.save_constraint!(model[node], BitSetTuple([[1], [2, 3], [2, 3]])) materialize_constraints!(model, node) - @test get_constraint_names(factorization_constraint(model[node])) == ((:μ,), (:σ, :out)) + @test get_constraint_names(factorization_constraint(model[node])) == ((:out,), (:μ, :σ)) # Test 3: Check that materialize_constraints! throws if the constraint is not a valid partition model = create_terminated_model(simple_model) @@ -653,7 +653,7 @@ end (ResolvedFactorizationConstraintEntry((ResolvedIndexedVariable(:w, 2, context),)), ResolvedFactorizationConstraintEntry((ResolvedIndexedVariable(:w, 3, context),))) ) @test GraphPPL.is_applicable(neighbors, constraint) - @test GraphPPL.convert_to_bitsets(model, normal_node, neighbors, constraint) == BitSetTuple([[1, 3], [2, 3], [1, 2, 3]]) + @test GraphPPL.convert_to_bitsets(model, normal_node, neighbors, constraint) == BitSetTuple([[1, 2, 3], [1, 2], [1, 3]]) end let constraint = ResolvedFactorizationConstraint( @@ -668,7 +668,7 @@ end (ResolvedFactorizationConstraintEntry((ResolvedIndexedVariable(:w, SplittedRange(2, 3), context),)),) ) @test GraphPPL.is_applicable(neighbors, constraint) - @test GraphPPL.convert_to_bitsets(model, normal_node, neighbors, constraint) == BitSetTuple([[1, 3], [2, 3], [1, 2, 3]]) + @test GraphPPL.convert_to_bitsets(model, normal_node, neighbors, constraint) == BitSetTuple([[1, 2, 3], [1, 2], [1, 3]]) end let constraint = ResolvedFactorizationConstraint( @@ -690,9 +690,9 @@ end ) ) @test GraphPPL.is_applicable(neighbors, constraint) - @test GraphPPL.convert_to_bitsets(model, normal_node, neighbors, constraint) == BitSetTuple([[1], [2, 3], [2, 3]]) + @test GraphPPL.convert_to_bitsets(model, normal_node, neighbors, constraint) == BitSetTuple([[1, 3], [2], [1, 3]]) apply!(model, normal_node, constraint) - @test GraphPPL.factorization_constraint(model[normal_node]) == BitSetTuple([[1], [2, 3], [2, 3]]) + @test GraphPPL.factorization_constraint(model[normal_node]) == BitSetTuple([[1, 3], [2], [1, 3]]) end let constraint = ResolvedFactorizationConstraint( @@ -703,7 +703,7 @@ end ) ) @test GraphPPL.is_applicable(neighbors, constraint) - @test GraphPPL.convert_to_bitsets(model, normal_node, neighbors, constraint) == BitSetTuple([[1, 2], [1, 2], [3]]) + @test GraphPPL.convert_to_bitsets(model, normal_node, neighbors, constraint) == BitSetTuple([[1], [2, 3], [2, 3]]) end model = create_terminated_model(multidim_array) @@ -776,7 +776,7 @@ end @test GraphPPL.is_applicable(neighbors, constraint) # This test should throw since we cannot resolve the constraint - @test_broken GraphPPL.convert_to_bitsets(model, normal_node, neighbors, constraint) + @test_broken (try GraphPPL.convert_to_bitsets(model, normal_node, neighbors, constraint) catch e; e; end) isa Exception end end @@ -841,7 +841,7 @@ end @test message_constraint(model[node]) === nothing end @test getname(factorization_constraint(model[ctx[NormalMeanVariance, 1]])) == ((:out,), (:μ,), (:σ,)) - @test getname(factorization_constraint(model[ctx[NormalMeanVariance, 2]])) == ((:μ, :out), (:σ,)) + @test getname(factorization_constraint(model[ctx[NormalMeanVariance, 2]])) == ((:out, :μ), (:σ,)) # Test constriants macro with nested model model = create_terminated_model(outer) @@ -868,7 +868,7 @@ end @test fform_constraint(model[ctx[:y]]) == NormalMeanVariance() for node in filter(GraphPPL.as_node(NormalMeanVariance) & GraphPPL.as_context(inner_inner), model) - @test getname(factorization_constraint(model[node])) == ((:μ, :σ), (:out,)) + @test getname(factorization_constraint(model[node])) == ((:out,), (:μ, :σ)) end # Test with specifying specific submodel @@ -881,9 +881,9 @@ end end apply!(model, constraints) - @test getname(factorization_constraint(model[ctx[child_model, 1][NormalMeanVariance, 1]])) == ((:μ, :out), (:σ,)) + @test getname(factorization_constraint(model[ctx[child_model, 1][NormalMeanVariance, 1]])) == ((:out, :μ), (:σ,)) for i in 2:99 - @test getname(factorization_constraint(model[ctx[child_model, i][NormalMeanVariance, 1]])) == ((:μ, :out, :σ),) + @test getname(factorization_constraint(model[ctx[child_model, i][NormalMeanVariance, 1]])) == ((:out, :μ, :σ),) end # Test with specifying general submodel @@ -896,9 +896,9 @@ end end apply!(model, constraints) - @test getname(factorization_constraint(model[ctx[child_model, 1][NormalMeanVariance, 1]])) == ((:μ, :out), (:σ,)) + @test getname(factorization_constraint(model[ctx[child_model, 1][NormalMeanVariance, 1]])) == ((:out, :μ), (:σ,)) for node in filter(GraphPPL.as_node(NormalMeanVariance) & GraphPPL.as_context(child_model), model) - @test getname(factorization_constraint(model[node])) == ((:μ, :out), (:σ,)) + @test getname(factorization_constraint(model[node])) == ((:out, :μ), (:σ,)) end # Test with ambiguous constraints diff --git a/test/graph_engine_tests.jl b/test/graph_engine_tests.jl index 4800e6e3..f738e4b5 100644 --- a/test/graph_engine_tests.jl +++ b/test/graph_engine_tests.jl @@ -210,6 +210,18 @@ end end end +@testitem "NodeLabel properties" begin + import GraphPPL: NodeLabel + + x = NodeLabel(:x, 1) + @test x[1] == x + @test length(x) === 1 + @test GraphPPL.to_symbol(x) === :x_1 + + y = NodeLabel(:y, 2) + @test x < y +end + @testitem "getname(::NodeLabel)" begin import GraphPPL: ResizableArray, NodeLabel, getname @@ -235,11 +247,11 @@ end @test_throws MethodError model[0] = 1 - @test_throws MethodError model["string"] = VariableNodeData(:x, VariableNodeOptions()) + @test_throws MethodError model["string"] = VariableNodeData(:x, VariableNodeOptions(), nothing, nothing, nothing) model[NodeLabel(:x, 2)] = VariableNodeData(:x, VariableNodeOptions(), nothing, nothing, nothing) @test nv(model) == 2 && ne(model) == 0 - model[NodeLabel(sum, 3)] = FactorNodeData(sum, getcontext(model), nothing, FactorNodeOptions()) + model[NodeLabel(sum, 3)] = FactorNodeData(sum, getcontext(model), nothing, FactorNodeOptions(), ()) @test nv(model) == 3 && ne(model) == 0 end @@ -313,34 +325,40 @@ end end @testitem "edges" begin - import GraphPPL: edges, create_model, VariableNodeData, NodeLabel, EdgeLabel, getname, VariableNodeOptions + import GraphPPL: edges, create_model, VariableNodeData, NodeLabel, EdgeLabel, getname, VariableNodeOptions, add_edge!, FactorNodeOptions, FactorNodeData # Test 1: Test getting all edges from a model model = create_model() - model[NodeLabel(:a, 1)] = VariableNodeData(:a, VariableNodeOptions(), nothing, nothing, nothing) - model[NodeLabel(:b, 2)] = VariableNodeData(:b, VariableNodeOptions(), nothing, nothing, nothing) - model[NodeLabel(:a, 1), NodeLabel(:b, 2)] = EdgeLabel(:edge, 1) + a = NodeLabel(:a, 1) + b = NodeLabel(:b, 2) + model[a] = VariableNodeData(:a, VariableNodeOptions(), nothing, nothing, nothing) + model[b] = FactorNodeData(sum, GraphPPL.Context(), nothing, FactorNodeOptions(), ()) + add_edge!(model, b, a, :edge; index = 1) @test length(edges(model)) == 1 - model[NodeLabel(:c, 2)] = VariableNodeData(:b, VariableNodeOptions(), nothing, nothing, nothing) - model[NodeLabel(:a, 1), NodeLabel(:c, 2)] = EdgeLabel(:edge, 2) + c = NodeLabel(:c, 2) + model[NodeLabel(:c, 2)] = FactorNodeData(sum, GraphPPL.Context(), nothing, FactorNodeOptions(), ()) + add_edge!(model, c, a, :edge; index = 2) @test length(edges(model)) == 2 # Test 2: Test getting all edges from a model with a specific node - @test getname.(edges(model, NodeLabel(:a, 1))) == [:edge, :edge] - @test getname.(edges(model, NodeLabel(:b, 2))) == [:edge] - @test getname.(edges(model, NodeLabel(:c, 2))) == [:edge] + @test getname.(edges(model, a)) == (:edge, :edge) + @test getname.(edges(model, b)) == (:edge,) + @test getname.(edges(model, c)) == (:edge,) + @test getname.(edges(model, [a, b])) == (:edge, :edge, :edge) end @testitem "neighbors(::Model, ::NodeData)" begin include("model_zoo.jl") - import GraphPPL: create_model, getcontext, neighbors, VariableNodeData, NodeLabel, EdgeLabel, getname, ResizableArray, VariableNodeOptions + import GraphPPL: create_model, getcontext, neighbors, VariableNodeData, NodeLabel, EdgeLabel, getname, ResizableArray, VariableNodeOptions, add_edge!, FactorNodeData, FactorNodeOptions model = create_model() __context__ = getcontext(model) - model[NodeLabel(:a, 1)] = VariableNodeData(:a, VariableNodeOptions(), nothing, nothing, __context__) - model[NodeLabel(:b, 2)] = VariableNodeData(:b, VariableNodeOptions(), nothing, nothing, __context__) - model[NodeLabel(:a, 1), NodeLabel(:b, 2)] = EdgeLabel(:edge, 1) + a = NodeLabel(:a, 1) + b = NodeLabel(:b, 2) + model[a] = FactorNodeData(sum, GraphPPL.Context(), nothing, FactorNodeOptions(), ()) + model[b] = VariableNodeData(:b, VariableNodeOptions(), nothing, nothing, __context__) + add_edge!(model, a, b, :edge; index = 1) @test collect(neighbors(model, NodeLabel(:a, 1))) == [NodeLabel(:b, 2)] model = create_model() @@ -349,24 +367,25 @@ end b = ResizableArray(NodeLabel, Val(1)) for i in 1:3 a[i] = NodeLabel(:a, i) - model[a[i]] = VariableNodeData(:a, VariableNodeOptions(), i, nothing, __context__) + model[a[i]] = FactorNodeData(sum, GraphPPL.Context(), nothing, FactorNodeOptions(), ()) b[i] = NodeLabel(:b, i) model[b[i]] = VariableNodeData(:b, VariableNodeOptions(), i, nothing, __context__) - model[a[i], b[i]] = EdgeLabel(:edge, i) + add_edge!(model, a[i], b[i], :edge; index = i) + end + for n in b + @test n ∈ neighbors(model, a) end - @test neighbors(model, a; sorted = true) == [b[1], b[2], b[3]] - # Test 2: Test getting sorted neighbors model = create_terminated_model(simple_model) ctx = getcontext(model) node = first(neighbors(model, ctx[:z])) # Normal node we're investigating is the only neighbor of `z` in the graph. - @test getname.(neighbors(model, node; sorted = true)) == [:z, :x, :y] + @test getname.(neighbors(model, node)) == (:z, :x, :y) # Test 3: Test getting sorted neighbors when one of the edge indices is nothing model = create_terminated_model(vector_model) ctx = getcontext(model) node = first(neighbors(model, ctx[:z][1])) - @test getname.(collect(neighbors(model, node; sorted = true))) == [:z, :x, :y] + @test getname.(collect(neighbors(model, node))) == [:z, :x, :y] end @testitem "filter(::Predicate, ::Model)" begin @@ -930,29 +949,27 @@ end model = create_model() ctx = getcontext(model) - x = getorcreate!(model, ctx, :x, nothing) + x = GraphPPL.add_atomic_factor_node!(model, ctx, sum) y = getorcreate!(model, ctx, :y, nothing) + add_edge!(model, x, y, :interface) @test ne(model) == 1 @test_throws MethodError add_edge!(model, x, y, 123) - - add_edge!(model, generate_nodelabel(model, :factor_node), generate_nodelabel(model, :factor_node2), :interface) - @test ne(model) == 1 end @testitem "add_edge!(::Model, ::NodeLabel, ::Vector{NodeLabel}, ::Symbol)" begin import GraphPPL: create_model, getcontext, nv, ne, NodeData, NodeLabel, EdgeLabel, add_edge!, getorcreate! model = create_model() ctx = getcontext(model) - x = getorcreate!(model, ctx, :x, nothing) + x = GraphPPL.add_atomic_factor_node!(model, ctx, sum) y = getorcreate!(model, ctx, :y, nothing) variable_nodes = [getorcreate!(model, ctx, i, nothing) for i in [:a, :b, :c]] - add_edge!(model, y, variable_nodes, :interface) + add_edge!(model, x, variable_nodes, :interface) - @test ne(model) == 3 && model[y, variable_nodes[1]] == EdgeLabel(:interface, 1) + @test ne(model) == 3 && model[x, variable_nodes[1]] == EdgeLabel(:interface, 1) end @testitem "default_parametrization" begin @@ -1004,8 +1021,8 @@ end # Test 2: Stochastic atomic call returns a new node node_id = make_node!(model, ctx, Normal, x, (μ = 0, σ = 1)) @test nv(model) == 4 - @test getname.(edges(model, label_for(model.graph, 2))) == [:out, :μ, :σ] - @test getname.(edges(model, label_for(model.graph, 2))) == [:out, :μ, :σ] + @test getname.(edges(model, label_for(model.graph, 2))) == (:out, :μ, :σ) + @test getname.(edges(model, label_for(model.graph, 2))) == (:out, :μ, :σ) # Test 3: Stochastic atomic call with an AbstractArray as rhs_interfaces model = create_model() @@ -1038,8 +1055,8 @@ end x = getorcreate!(model, ctx, :x, nothing) node_id = make_node!(model, ctx, Normal, x, [0, 1]) @test nv(model) == 4 - @test getname.(edges(model, label_for(model.graph, 2))) == [:out, :μ, :σ] - @test getname.(edges(model, label_for(model.graph, 2))) == [:out, :μ, :σ] + @test getname.(edges(model, label_for(model.graph, 2))) == (:out, :μ, :σ) + @test getname.(edges(model, label_for(model.graph, 2))) == (:out, :μ, :σ) @test factorization_constraint(model[label_for(model.graph, 2)]) == BitSetTuple(3) # Test 7: Stochastic node with instantiated object @@ -1072,8 +1089,8 @@ end out = getorcreate!(model, ctx, :out, nothing) make_node!(model, ctx, ArbitraryNode, out, [1, 1]; __debug__ = false) @test nv(model) == 4 - @test getname.(edges(model, label_for(model.graph, 2))) == [:out, :in, :in] - @test getname.(edges(model, label_for(model.graph, 2))) == [:out, :in, :in] + @test getname.(edges(model, label_for(model.graph, 2))) == (:out, :in, :in) + @test getname.(edges(model, label_for(model.graph, 2))) == (:out, :in, :in) #Test 10: Deterministic node with keyword arguments function abc(; a = 1, b = 2) @@ -1150,8 +1167,8 @@ end # Test 1: Stochastic atomic call returns a new node node_id = materialize_factor_node!(model, ctx, Normal, (out = x, μ = 0, σ = 1)) @test nv(model) == 4 - @test getname.(edges(model, label_for(model.graph, 2))) == [:out, :μ, :σ] - @test getname.(edges(model, label_for(model.graph, 2))) == [:out, :μ, :σ] + @test getname.(edges(model, label_for(model.graph, 2))) == (:out, :μ, :σ) + @test getname.(edges(model, label_for(model.graph, 2))) == (:out, :μ, :σ) # Test 3: Stochastic atomic call with an AbstractArray as rhs_interfaces model = create_model() @@ -1223,7 +1240,7 @@ end model = create_model() ctx = getcontext(model) x = getorcreate!(model, ctx, :x, nothing) - y = getorcreate!(model, ctx, :y, nothing) + y = GraphPPL.add_atomic_factor_node!(model, ctx, sum) z = getorcreate!(model, ctx, :z, nothing) w = getorcreate!(model, ctx, :w, nothing) @@ -1289,3 +1306,20 @@ end @test sinterfaces[i] === interface end end + +@testitem "sort_interfaces" begin + import GraphPPL: sort_interfaces + include("model_zoo.jl") + + # Test 1: Test that sort_interfaces sorts the interfaces in the correct order + @test sort_interfaces(NormalMeanVariance, (μ = 1, σ = 1, out = 1)) == (out = 1, μ = 1, σ = 1) + @test sort_interfaces(NormalMeanVariance, (out = 1, μ = 1, σ = 1)) == (out = 1, μ = 1, σ = 1) + @test sort_interfaces(NormalMeanVariance, (σ = 1, out = 1, μ = 1)) == (out = 1, μ = 1, σ = 1) + @test sort_interfaces(NormalMeanVariance, (σ = 1, μ = 1, out = 1)) == (out = 1, μ = 1, σ = 1) + @test sort_interfaces(NormalMeanPrecision, (μ = 1, τ = 1, out = 1)) == (out = 1, μ = 1, τ = 1) + @test sort_interfaces(NormalMeanPrecision, (out = 1, μ = 1, τ = 1)) == (out = 1, μ = 1, τ = 1) + @test sort_interfaces(NormalMeanPrecision, (τ = 1, out = 1, μ = 1)) == (out = 1, μ = 1, τ = 1) + @test sort_interfaces(NormalMeanPrecision, (τ = 1, μ = 1, out = 1)) == (out = 1, μ = 1, τ = 1) + + @test_throws ErrorException sort_interfaces(NormalMeanVariance, (σ = 1, μ = 1, τ = 1)) +end diff --git a/test/meta_engine_tests.jl b/test/meta_engine_tests.jl index 65a7a9eb..7d51e99b 100644 --- a/test/meta_engine_tests.jl +++ b/test/meta_engine_tests.jl @@ -144,6 +144,7 @@ end context = GraphPPL.getcontext(model) metadata = MetaObject(FactorMetaDescriptor(NormalMeanVariance, (IndexedVariable(:x, 1), IndexedVariable(:y, nothing))), (meta = SomeMeta(), other = 1)) apply!(model, context, metadata) + @show GraphPPL.neighbors(model, context[:x][1]), GraphPPL.neighbors(model, context[:y]) node = first(intersect(GraphPPL.neighbors(model, context[:x][1]), GraphPPL.neighbors(model, context[:y]))) @test meta(model[node]) == SomeMeta() @test options(model[node]).others[:other] == 1