Skip to content

Commit

Permalink
Merge pull request #259 from ReactiveBayes/add_id_to_nodedata
Browse files Browse the repository at this point in the history
Add id to NodeData
  • Loading branch information
wouterwln authored Dec 11, 2024
2 parents 85e8a31 + 9775fe1 commit a883281
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 30 deletions.
8 changes: 5 additions & 3 deletions src/graph_engine.jl
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,7 @@ Base.broadcastable(label::NodeLabel) = Ref(label)

getname(label::NodeLabel) = label.name
getname(labels::ResizableArray{T, V, N} where {T <: NodeLabel, V, N}) = getname(first(labels))
getid(label::NodeLabel) = label.global_counter
iterate(label::NodeLabel) = (label, nothing)
iterate(label::NodeLabel, any) = nothing

Expand Down Expand Up @@ -765,9 +766,10 @@ mutable struct NodeData
const context :: Context
const properties :: Union{VariableNodeProperties, FactorNodeProperties{NodeData}}
const extra :: UnorderedDictionary{Symbol, Any}
const id :: Int
end

NodeData(context, properties) = NodeData(context, properties, UnorderedDictionary{Symbol, Any}())
NodeData(context, properties, id) = NodeData(context, properties, UnorderedDictionary{Symbol, Any}(), id)

function Base.show(io::IO, nodedata::NodeData)
context = getcontext(nodedata)
Expand Down Expand Up @@ -1529,7 +1531,7 @@ end
function __add_variable_node!(model::Model, context::Context, options::NodeCreationOptions, name::Symbol, index)
# In theory plugins are able to overwrite this
potential_label = generate_nodelabel(model, name)
potential_nodedata = NodeData(context, convert(VariableNodeProperties, name, index, options))
potential_nodedata = NodeData(context, convert(VariableNodeProperties, name, index, options), getid(potential_label))
label, nodedata = preprocess_plugins(
UnionPluginType(VariableNodePlugin(), FactorAndVariableNodesPlugin()), model, context, potential_label, potential_nodedata, options
)
Expand Down Expand Up @@ -1643,7 +1645,7 @@ function add_atomic_factor_node!(model::Model, context::Context, options::NodeCr
factornode_id = generate_factor_nodelabel(context, fform)

potential_label = generate_nodelabel(model, fform)
potential_nodedata = NodeData(context, convert(FactorNodeProperties, fform, options))
potential_nodedata = NodeData(context, convert(FactorNodeProperties, fform, options), getid(potential_label))

label, nodedata = preprocess_plugins(
UnionPluginType(FactorNodePlugin(), FactorAndVariableNodesPlugin()), model, context, potential_label, potential_nodedata, options
Expand Down
41 changes: 22 additions & 19 deletions test/graph_engine_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ end

@testset "FactorNodeProperties" begin
properties = FactorNodeProperties(fform = String)
nodedata = NodeData(context, properties)
nodedata = NodeData(context, properties, 1)

@test getcontext(nodedata) === context
@test getproperties(nodedata) === properties
Expand All @@ -135,7 +135,7 @@ end

@testset "VariableNodeProperties" begin
properties = VariableNodeProperties(name = :x, index = 1)
nodedata = NodeData(context, properties)
nodedata = NodeData(context, properties, 1)

@test getcontext(nodedata) === context
@test getproperties(nodedata) === properties
Expand Down Expand Up @@ -183,7 +183,7 @@ end
context = getcontext(model)

@testset for properties in (FactorNodeProperties(fform = String), VariableNodeProperties(name = :x, index = 1))
nodedata = NodeData(context, properties)
nodedata = NodeData(context, properties, 1)

@test !hasextra(nodedata, :a)
@test getextra(nodedata, :a, 2) === 2
Expand Down Expand Up @@ -552,7 +552,10 @@ end

function GraphPPL.preprocess_plugin(::AnArbitraryPluginForChangingOptions, model, context, label, nodedata, options)
# Here we replace the original options entirely
return label, NodeData(context, convert(VariableNodeProperties, :x, nothing, NodeCreationOptions(kind = :constant, value = 1.0)))
return label,
NodeData(
context, convert(VariableNodeProperties, :x, nothing, NodeCreationOptions(kind = :constant, value = 1.0)), GraphPPL.getid(label)
)
end

for model_fn in ModelsInTheZooWithoutArguments
Expand Down Expand Up @@ -933,13 +936,13 @@ end

model = create_test_model()
ctx = getcontext(model)
model[NodeLabel(, 1)] = NodeData(ctx, VariableNodeProperties(name = , index = nothing))
model[NodeLabel(, 1)] = NodeData(ctx, VariableNodeProperties(name = , index = nothing), 1)
@test nv(model) == 1 && ne(model) == 0

model[NodeLabel(:x, 2)] = NodeData(ctx, VariableNodeProperties(name = :x, index = nothing))
model[NodeLabel(:x, 2)] = NodeData(ctx, VariableNodeProperties(name = :x, index = nothing), 2)
@test nv(model) == 2 && ne(model) == 0

model[NodeLabel(sum, 3)] = NodeData(ctx, FactorNodeProperties(fform = sum))
model[NodeLabel(sum, 3)] = NodeData(ctx, FactorNodeProperties(fform = sum), 3)
@test nv(model) == 3 && ne(model) == 0

@test_throws MethodError model[0] = 1
Expand All @@ -959,8 +962,8 @@ end
μ = NodeLabel(, 1)
xref = NodeLabel(:x, 2)

model[μ] = NodeData(ctx, VariableNodeProperties(name = , index = nothing))
model[xref] = NodeData(ctx, VariableNodeProperties(name = :x, index = nothing))
model[μ] = NodeData(ctx, VariableNodeProperties(name = , index = nothing), 1)
model[xref] = NodeData(ctx, VariableNodeProperties(name = :x, index = nothing), 2)
model[μ, xref] = EdgeLabel(:interface, 1)

@test ne(model) == 1
Expand Down Expand Up @@ -990,7 +993,7 @@ end
model = create_test_model()
ctx = getcontext(model)
label = NodeLabel(:x, 1)
model[label] = NodeData(ctx, VariableNodeProperties(name = :x, index = nothing))
model[label] = NodeData(ctx, VariableNodeProperties(name = :x, index = nothing), 1)
@test isa(model[label], NodeData)
@test isa(getproperties(model[label]), VariableNodeProperties)
@test_throws KeyError model[NodeLabel(:x, 10)]
Expand Down Expand Up @@ -1024,8 +1027,8 @@ end
@test nv(model) == 0
@test ne(model) == 0

model[NodeLabel(:a, 1)] = NodeData(ctx, VariableNodeProperties(name = :a, index = nothing))
model[NodeLabel(:b, 2)] = NodeData(ctx, VariableNodeProperties(name = :b, index = nothing))
model[NodeLabel(:a, 1)] = NodeData(ctx, VariableNodeProperties(name = :a, index = nothing), 1)
model[NodeLabel(:b, 2)] = NodeData(ctx, VariableNodeProperties(name = :b, index = nothing), 2)
@test !isempty(model)
@test nv(model) == 2
@test ne(model) == 0
Expand Down Expand Up @@ -1059,8 +1062,8 @@ end
ctx = getcontext(model)
a = NodeLabel(:a, 1)
b = NodeLabel(:b, 2)
model[a] = NodeData(ctx, VariableNodeProperties(name = :a, index = nothing))
model[b] = NodeData(ctx, FactorNodeProperties(fform = sum))
model[a] = NodeData(ctx, VariableNodeProperties(name = :a, index = nothing), 1)
model[b] = NodeData(ctx, FactorNodeProperties(fform = sum), 2)
@test !has_edge(model, a, b)
@test !has_edge(model, b, a)
add_edge!(model, b, getproperties(model[b]), a, :edge, 1)
Expand All @@ -1069,7 +1072,7 @@ end
@test length(edges(model)) == 1

c = NodeLabel(:c, 2)
model[c] = NodeData(ctx, FactorNodeProperties(fform = sum))
model[c] = NodeData(ctx, FactorNodeProperties(fform = sum), 2)
@test !has_edge(model, a, c)
@test !has_edge(model, c, a)
add_edge!(model, c, getproperties(model[c]), a, :edge, 2)
Expand Down Expand Up @@ -1109,8 +1112,8 @@ end

a = NodeLabel(:a, 1)
b = NodeLabel(:b, 2)
model[a] = NodeData(ctx, FactorNodeProperties(fform = sum))
model[b] = NodeData(ctx, VariableNodeProperties(name = :b, index = nothing))
model[a] = NodeData(ctx, FactorNodeProperties(fform = sum), 1)
model[b] = NodeData(ctx, VariableNodeProperties(name = :b, index = nothing), 2)
add_edge!(model, a, getproperties(model[a]), b, :edge, 1)
@test collect(neighbors(model, NodeLabel(:a, 1))) == [NodeLabel(:b, 2)]

Expand All @@ -1120,9 +1123,9 @@ end
b = ResizableArray(NodeLabel, Val(1))
for i in 1:3
a[i] = NodeLabel(:a, i)
model[a[i]] = NodeData(ctx, FactorNodeProperties(fform = sum))
model[a[i]] = NodeData(ctx, FactorNodeProperties(fform = sum), i)
b[i] = NodeLabel(:b, i)
model[b[i]] = NodeData(ctx, VariableNodeProperties(name = :b, index = i))
model[b[i]] = NodeData(ctx, VariableNodeProperties(name = :b, index = i), i)
add_edge!(model, a[i], getproperties(model[a[i]]), b[i], :edge, i)
end
for n in b
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -862,35 +862,35 @@ end
])

variable = ResolvedIndexedVariable(:w, 2:3, context)
node_data = GraphPPL.NodeData(context, VariableNodeProperties(name = :w, index = 2))
node_data = GraphPPL.NodeData(context, VariableNodeProperties(name = :w, index = 2), 2)
@test node_data variable

variable = ResolvedIndexedVariable(:w, 2:3, context)
node_data = GraphPPL.NodeData(GraphPPL.Context(), VariableNodeProperties(name = :w, index = 2))
node_data = GraphPPL.NodeData(GraphPPL.Context(), VariableNodeProperties(name = :w, index = 2), 2)
@test !(node_data variable)

variable = ResolvedIndexedVariable(:w, 2, context)
node_data = GraphPPL.NodeData(context, VariableNodeProperties(name = :w, index = 2))
node_data = GraphPPL.NodeData(context, VariableNodeProperties(name = :w, index = 2), 2)
@test node_data variable

variable = ResolvedIndexedVariable(:w, SplittedRange(2, 3), context)
node_data = GraphPPL.NodeData(context, VariableNodeProperties(name = :w, index = 2))
node_data = GraphPPL.NodeData(context, VariableNodeProperties(name = :w, index = 2), 2)
@test node_data variable

variable = ResolvedIndexedVariable(:w, SplittedRange(10, 15), context)
node_data = GraphPPL.NodeData(context, VariableNodeProperties(name = :w, index = 2))
node_data = GraphPPL.NodeData(context, VariableNodeProperties(name = :w, index = 2), 2)
@test !(node_data variable)

variable = ResolvedIndexedVariable(:x, nothing, context)
node_data = GraphPPL.NodeData(context, VariableNodeProperties(name = :x, index = 2))
node_data = GraphPPL.NodeData(context, VariableNodeProperties(name = :x, index = 2), 2)
@test node_data variable

variable = ResolvedIndexedVariable(:x, nothing, context)
node_data = GraphPPL.NodeData(context, VariableNodeProperties(name = :x, index = nothing))
node_data = GraphPPL.NodeData(context, VariableNodeProperties(name = :x, index = nothing), 1)
@test node_data variable

variable = ResolvedIndexedVariable(:prec, 3, context)
node_data = GraphPPL.NodeData(context, VariableNodeProperties(name = :prec, index = (1, 3)))
node_data = GraphPPL.NodeData(context, VariableNodeProperties(name = :prec, index = (1, 3)), 2)
@test node_data variable
end

Expand Down

0 comments on commit a883281

Please sign in to comment.