Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Factorization constraint outputs in the correct ordering #144

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions benchmark/model_creation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,9 @@ function benchmark_model_creation()
for i in 10 .^ range(1, stop=3)
SUITE["create HGF of depth $i"] = @benchmarkable create_hgf($i)
end
for i in 10 .^ range(2, stop=6)
n_nodes = 10^i
SUITE["create model with array of length $i"] = @benchmarkable create_longarray($n_nodes)
end
return SUITE
end
101 changes: 32 additions & 69 deletions src/graph_engine.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,20 @@

Base.show(io::IO, label::NodeLabel) = print(io, label.name, "_", label.global_counter)

struct EdgeLabel
name::Symbol
index::Union{Int, Nothing}
end

getname(label::EdgeLabel) = label.name
getname(labels::Tuple) = map(group -> getname(group), labels)

to_symbol(label::EdgeLabel) = to_symbol(label, label.index)
to_symbol(label::EdgeLabel, ::Nothing) = label.name
to_symbol(label::EdgeLabel, ::Int64) = Symbol(string(label.name) * "[" * string(label.index) * "]")

Check warning on line 83 in src/graph_engine.jl

View check run for this annotation

Codecov / codecov/patch

src/graph_engine.jl#L81-L83

Added lines #L81 - L83 were not covered by tests

Base.show(io::IO, label::EdgeLabel) = print(io, to_symbol(label))

Check warning on line 85 in src/graph_engine.jl

View check run for this annotation

Codecov / codecov/patch

src/graph_engine.jl#L85

Added line #L85 was not covered by tests

mutable struct VariableNodeOptions
value::Any
functional_form::Any
Expand Down Expand Up @@ -105,7 +119,7 @@

Data associated with a variable node in a probabilistic graphical model.
"""
struct VariableNodeData
mutable struct VariableNodeData
name::Symbol
options::VariableNodeOptions
index::Any
Expand Down Expand Up @@ -167,6 +181,7 @@
context::Any
factorization_constraint::Any
options::FactorNodeOptions
neighbors::NTuple{N, Tuple{NodeLabel, EdgeLabel}} where {N}
end

fform(node::FactorNodeData) = node.fform
Expand Down Expand Up @@ -226,28 +241,15 @@
Base.last(proxied::ProxyLabel, ::ProxyLabel) = last(proxied)
Base.last(proxied, ::ProxyLabel) = proxied

struct EdgeLabel
name::Symbol
index::Union{Int, Nothing}
end

getname(label::EdgeLabel) = label.name
getname(labels::Tuple) = map(group -> getname(group), labels)

to_symbol(label::EdgeLabel) = to_symbol(label, label.index)
to_symbol(label::EdgeLabel, ::Nothing) = label.name
to_symbol(label::EdgeLabel, ::Int64) = Symbol(string(label.name) * "[" * string(label.index) * "]")

Base.show(io::IO, label::EdgeLabel) = print(io, to_symbol(label))

Model(graph::MetaGraph) = Model(graph, Base.RefValue(0))

Base.setindex!(model::Model, val::NodeData, key::NodeLabel) = Base.setindex!(model.graph, val, key)
Base.setindex!(model::Model, val::EdgeLabel, src::NodeLabel, dst::NodeLabel) = Base.setindex!(model.graph, val, src, dst)
Base.getindex(model::Model) = Base.getindex(model.graph)
Base.getindex(model::Model, key::NodeLabel) = Base.getindex(model.graph, key)
Base.getindex(model::Model, src::NodeLabel, dst::NodeLabel) = Base.getindex(model.graph, src, dst)
Base.getindex(model::Model, keys::AbstractArray{NodeLabel}) = [model[key] for key in keys]
Base.getindex(model::Model, keys::AbstractArray{NodeLabel}) = map(key -> model[key], keys)
Base.getindex(model::Model, keys::NTuple{N, NodeLabel}) where {N} = collect(map(key -> model[key], keys))

Base.getindex(model::Model, keys::Base.Generator) = [model[key] for key in keys]

Expand All @@ -272,59 +274,18 @@
Graphs.nv(model::Model) = Graphs.nv(model.graph)
Graphs.ne(model::Model) = Graphs.ne(model.graph)
Graphs.edges(model::Model) = Graphs.edges(model.graph)
MetaGraphsNext.label_for(model::Model, node_id::Int) = MetaGraphsNext.label_for(model.graph, node_id)

function retrieve_interface_position(interfaces::StaticInterfaces{I}, x::EdgeLabel, max_length::Int) where {I}
index = x.index === nothing ? 0 : x.index
position = findfirst(isequal(x.name), I)
position = position === nothing ? begin
@warn(lazy"Interface $(x.name) not found in $I")
0
end : position
return max_length * findfirst(isequal(x.name), I) + index
end

function __sortperm(model::Model, node::NodeLabel, edges::AbstractArray)
fform = model[node].fform
indices = [e.index for e in edges]
names = unique([e.name for e in edges])
interfaces = GraphPPL.interfaces(fform, static(length(names)))
max_length = any(x -> x !== nothing, indices) ? maximum(indices[indices .!= nothing]) : 1
perm = sortperm(edges, by = (x -> retrieve_interface_position(interfaces, x, max_length)))
return perm
end

__get_neighbors(model::Model, node::NodeLabel) = Iterators.map(neighbor -> label_for(model, neighbor), MetaGraphsNext.neighbors(model.graph, code_for(model.graph, node)))

__neighbors(model::Model, node::NodeLabel; sorted = false) = __neighbors(model, node, model[node]; sorted = sorted)
__neighbors(model::Model, node::NodeLabel, node_data::VariableNodeData; sorted = false) = __get_neighbors(model, node)
__neighbors(model::Model, node::NodeLabel, node_data::FactorNodeData; sorted = false) = __neighbors(model, node, static(sorted))

__neighbors(model::Model, node::NodeLabel, ::False) = __get_neighbors(model, node)
function __neighbors(model::Model, node::NodeLabel, ::True)
neighbors = collect(__get_neighbors(model, node))
edges = __get_edges(model, node, neighbors)
perm = __sortperm(model, node, edges)
return neighbors[perm]
end

Graphs.neighbors(model::Model, node::NodeLabel; sorted = false) = __neighbors(model, node; sorted = sorted)
Graphs.neighbors(model::Model, nodes::AbstractArray; sorted = false) = reduce(union, Graphs.neighbors.(Ref(model), nodes; sorted = sorted))

Graphs.vertices(model::Model) = MetaGraphsNext.vertices(model.graph)
MetaGraphsNext.labels(model::Model) = MetaGraphsNext.labels(model.graph)
Graphs.neighbors(model::Model, node::NodeLabel) = Graphs.neighbors(model, node, model[node])
Graphs.neighbors(model::Model, node::NodeLabel, nodedata::FactorNodeData) = map(neighbor -> neighbor[1], nodedata.neighbors)
Graphs.neighbors(model::Model, node::NodeLabel, nodedata::VariableNodeData) = MetaGraphsNext.neighbor_labels(model.graph, node)
Graphs.neighbors(model::Model, nodes::AbstractArray{<:NodeLabel}) = Iterators.flatten(map(node -> Graphs.neighbors(model, node), nodes))

__get_edges(model::Model, node::NodeLabel, neighbors) = getindex.(Ref(model), Ref(node), neighbors)
__edges(model::Model, node::NodeLabel, node_data::VariableNodeData; sorted = false) = __get_edges(model, node, __get_neighbors(model, node))
__edges(model::Model, node::NodeLabel, node_data::FactorNodeData; sorted = false) = __edges(model, node, static(sorted))
__edges(model::Model, node::NodeLabel, ::False) = __get_edges(model, node, __get_neighbors(model, node))
function __edges(model::Model, node::NodeLabel, ::True)
neighbors = __get_neighbors(model, node)
edges = __get_edges(model, node, neighbors)
perm = __sortperm(model, node, edges)
return edges[perm]
Graphs.edges(model::Model, node::NodeLabel) = Graphs.edges(model, node, model[node])
Graphs.edges(model::Model, node::NodeLabel, nodedata::FactorNodeData) = map(neighbor -> neighbor[2], nodedata.neighbors)
function Graphs.edges(model::Model, node::NodeLabel, nodedata::VariableNodeData)
return Tuple(model[node, dst] for dst in MetaGraphsNext.neighbor_labels(model.graph, node))
end
Graphs.edges(model::Model, node::NodeLabel; sorted = false) = __edges(model, node, model[node]; sorted = sorted)
Graphs.edges(model::Model, nodes::AbstractArray{<:NodeLabel}) = Tuple(Iterators.flatten(map(node -> Graphs.edges(model, node), nodes)))

abstract type AbstractModelFilterPredicate end

Expand Down Expand Up @@ -777,7 +738,7 @@
function add_atomic_factor_node!(model::Model, context::Context, fform; __options__ = FactorNodeOptions())
factornode_id = generate_factor_nodelabel(context, fform)
factornode_label = generate_nodelabel(model, fform)
model[factornode_label] = FactorNodeData(fform, context, nothing, __options__)
model[factornode_label] = FactorNodeData(fform, context, nothing, __options__, ())
context.factor_nodes[factornode_id] = factornode_label
return factornode_label
end
Expand Down Expand Up @@ -811,7 +772,9 @@
iterator(interfaces::NamedTuple) = zip(keys(interfaces), values(interfaces))

function add_edge!(model::Model, factor_node_id::NodeLabel, variable_node_id::Union{ProxyLabel, NodeLabel}, interface_name::Symbol; index = nothing)
model.graph[unroll(variable_node_id), factor_node_id] = EdgeLabel(interface_name, index)
label = EdgeLabel(interface_name, index)
model[factor_node_id].neighbors = (model[factor_node_id].neighbors..., (unroll(variable_node_id), label))
model.graph[unroll(variable_node_id), factor_node_id] = label
end

function add_edge!(model::Model, factor_node_id::NodeLabel, variable_nodes::Union{AbstractArray, Tuple, NamedTuple}, interface_name::Symbol; index = 1)
Expand All @@ -824,7 +787,7 @@
increase_index(x::AbstractArray) = length(x)

function add_factorization_constraint!(model::Model, factor_node_id::NodeLabel)
out_degree = outdegree(model.graph, code_for(model.graph, factor_node_id))
out_degree = length(model[factor_node_id].neighbors)
constraint = BitSetTuple(out_degree)
set_factorization_constraint!(model[factor_node_id], constraint)
end
Expand Down
28 changes: 14 additions & 14 deletions test/constraints_engine_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,7 @@ end
ctx = GraphPPL.getcontext(model)
node = ctx[NormalMeanVariance, 2]
materialize_constraints!(model, node)
@test get_constraint_names(factorization_constraint(model[node])) == ((:μ, :σ, :out),)
@test get_constraint_names(factorization_constraint(model[node])) == ((:out, :μ, :σ),)
materialize_constraints!(model, ctx[NormalMeanVariance, 1])
@test get_constraint_names(factorization_constraint(model[ctx[NormalMeanVariance, 1]])) == ((:out,), (:μ,), (:σ,))

Expand All @@ -513,7 +513,7 @@ end
node = ctx[NormalMeanVariance, 2]
GraphPPL.save_constraint!(model[node], BitSetTuple([[1], [2, 3], [2, 3]]))
materialize_constraints!(model, node)
@test get_constraint_names(factorization_constraint(model[node])) == ((:μ,), (:σ, :out))
@test get_constraint_names(factorization_constraint(model[node])) == ((:out,), (:μ, :σ))

# Test 3: Check that materialize_constraints! throws if the constraint is not a valid partition
model = create_terminated_model(simple_model)
Expand Down Expand Up @@ -653,7 +653,7 @@ end
(ResolvedFactorizationConstraintEntry((ResolvedIndexedVariable(:w, 2, context),)), ResolvedFactorizationConstraintEntry((ResolvedIndexedVariable(:w, 3, context),)))
)
@test GraphPPL.is_applicable(neighbors, constraint)
@test GraphPPL.convert_to_bitsets(model, normal_node, neighbors, constraint) == BitSetTuple([[1, 3], [2, 3], [1, 2, 3]])
@test GraphPPL.convert_to_bitsets(model, normal_node, neighbors, constraint) == BitSetTuple([[1, 2, 3], [1, 2], [1, 3]])
end

let constraint = ResolvedFactorizationConstraint(
Expand All @@ -668,7 +668,7 @@ end
(ResolvedFactorizationConstraintEntry((ResolvedIndexedVariable(:w, SplittedRange(2, 3), context),)),)
)
@test GraphPPL.is_applicable(neighbors, constraint)
@test GraphPPL.convert_to_bitsets(model, normal_node, neighbors, constraint) == BitSetTuple([[1, 3], [2, 3], [1, 2, 3]])
@test GraphPPL.convert_to_bitsets(model, normal_node, neighbors, constraint) == BitSetTuple([[1, 2, 3], [1, 2], [1, 3]])
end

let constraint = ResolvedFactorizationConstraint(
Expand All @@ -690,9 +690,9 @@ end
)
)
@test GraphPPL.is_applicable(neighbors, constraint)
@test GraphPPL.convert_to_bitsets(model, normal_node, neighbors, constraint) == BitSetTuple([[1], [2, 3], [2, 3]])
@test GraphPPL.convert_to_bitsets(model, normal_node, neighbors, constraint) == BitSetTuple([[1, 3], [2], [1, 3]])
apply!(model, normal_node, constraint)
@test GraphPPL.factorization_constraint(model[normal_node]) == BitSetTuple([[1], [2, 3], [2, 3]])
@test GraphPPL.factorization_constraint(model[normal_node]) == BitSetTuple([[1, 3], [2], [1, 3]])
end

let constraint = ResolvedFactorizationConstraint(
Expand All @@ -703,7 +703,7 @@ end
)
)
@test GraphPPL.is_applicable(neighbors, constraint)
@test GraphPPL.convert_to_bitsets(model, normal_node, neighbors, constraint) == BitSetTuple([[1, 2], [1, 2], [3]])
@test GraphPPL.convert_to_bitsets(model, normal_node, neighbors, constraint) == BitSetTuple([[1], [2, 3], [2, 3]])
end

model = create_terminated_model(multidim_array)
Expand Down Expand Up @@ -776,7 +776,7 @@ end
@test GraphPPL.is_applicable(neighbors, constraint)

# This test should throw since we cannot resolve the constraint
@test_broken GraphPPL.convert_to_bitsets(model, normal_node, neighbors, constraint)
@test_broken (try GraphPPL.convert_to_bitsets(model, normal_node, neighbors, constraint) catch e; e; end) isa Exception
end
end

Expand Down Expand Up @@ -841,7 +841,7 @@ end
@test message_constraint(model[node]) === nothing
end
@test getname(factorization_constraint(model[ctx[NormalMeanVariance, 1]])) == ((:out,), (:μ,), (:σ,))
@test getname(factorization_constraint(model[ctx[NormalMeanVariance, 2]])) == ((:μ, :out), (:σ,))
@test getname(factorization_constraint(model[ctx[NormalMeanVariance, 2]])) == ((:out, :μ), (:σ,))

# Test constriants macro with nested model
model = create_terminated_model(outer)
Expand All @@ -868,7 +868,7 @@ end

@test fform_constraint(model[ctx[:y]]) == NormalMeanVariance()
for node in filter(GraphPPL.as_node(NormalMeanVariance) & GraphPPL.as_context(inner_inner), model)
@test getname(factorization_constraint(model[node])) == ((:μ, :σ), (:out,))
@test getname(factorization_constraint(model[node])) == ((:out,), (:μ, :σ))
end

# Test with specifying specific submodel
Expand All @@ -881,9 +881,9 @@ end
end

apply!(model, constraints)
@test getname(factorization_constraint(model[ctx[child_model, 1][NormalMeanVariance, 1]])) == ((:μ, :out), (:σ,))
@test getname(factorization_constraint(model[ctx[child_model, 1][NormalMeanVariance, 1]])) == ((:out, :μ), (:σ,))
for i in 2:99
@test getname(factorization_constraint(model[ctx[child_model, i][NormalMeanVariance, 1]])) == ((:μ, :out, :σ),)
@test getname(factorization_constraint(model[ctx[child_model, i][NormalMeanVariance, 1]])) == ((:out, :μ, :σ),)
end

# Test with specifying general submodel
Expand All @@ -896,9 +896,9 @@ end
end

apply!(model, constraints)
@test getname(factorization_constraint(model[ctx[child_model, 1][NormalMeanVariance, 1]])) == ((:μ, :out), (:σ,))
@test getname(factorization_constraint(model[ctx[child_model, 1][NormalMeanVariance, 1]])) == ((:out, :μ), (:σ,))
for node in filter(GraphPPL.as_node(NormalMeanVariance) & GraphPPL.as_context(child_model), model)
@test getname(factorization_constraint(model[node])) == ((:μ, :out), (:σ,))
@test getname(factorization_constraint(model[node])) == ((:out, :μ), (:σ,))
end

# Test with ambiguous constraints
Expand Down
Loading
Loading