Skip to content

Commit

Permalink
Merge pull request #130 from biaslab/dev-performance
Browse files Browse the repository at this point in the history
Performance improvements
  • Loading branch information
bvdmitri authored Nov 16, 2023
2 parents 73e4ef6 + 30a6988 commit 7d59ed5
Show file tree
Hide file tree
Showing 7 changed files with 157 additions and 38 deletions.
32 changes: 27 additions & 5 deletions ext/GraphPPLDistributionsExt.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,34 @@
module GraphPPLDistributionsExt
using GraphPPL, Distributions

using GraphPPL, Distributions, Static

GraphPPL.NodeBehaviour(::Type{<:Distributions.Distribution}) = GraphPPL.Stochastic()

function GraphPPL.default_parametrization(::GraphPPL.Atomic, t::Type{<:Distributions.Distribution}, interface_values)
field_names = fieldnames(t)
@assert length(interface_values) == length(field_names) "Distribution $t has $(length(field_names)) fields $(field_names) but $(length(interface_values)) values were provided."
return NamedTuple{fieldnames(t)}(interface_values)
return distributions_ext_default_parametrization(t, distributions_ext_input_interfaces(t), interface_values)
end

function distributions_ext_default_parametrization(t::Type{<:Distributions.Distribution}, ::GraphPPL.StaticInterfaces{interfaces}, interface_values) where {interfaces}
@assert length(interface_values) == length(interfaces) "Distribution $t has $(length(interfaces)) fields $(interfaces) but $(length(interface_values)) values were provided."
return NamedTuple{interfaces}(interface_values)
end

function GraphPPL.interfaces(T::Type{<:Distributions.Distribution}, _)
return distributions_ext_interfaces(T)
end

@generated function distributions_ext_input_interfaces(::Type{T}) where {T}
fnames = fieldnames(T)
return quote
GraphPPL.StaticInterfaces(($(map(QuoteNode, fnames)...), ))
end
end

@generated function distributions_ext_interfaces(::Type{T}) where {T}
fnames = fieldnames(T)
return quote
GraphPPL.StaticInterfaces((:out, $(map(QuoteNode, fnames)...)))
end
end
GraphPPL.interfaces(t::Type{<:Distributions.Distribution}, val) = (:out, fieldnames(t)...)

end
45 changes: 27 additions & 18 deletions src/graph_engine.jl
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,15 @@ factor_nodes(model::Model) = Iterators.filter(node -> is_factor(model[node]), la
variable_nodes(model::Model) =
Iterators.filter(node -> is_variable(model[node]), labels(model))

"""
A structure that holds interfaces of a node in the type argument `I`. Used for dispatch.
"""
struct StaticInterfaces{I} end

StaticInterfaces(I::Tuple) = StaticInterfaces{I}()



struct ProxyLabel{T}
name::Symbol
index::T
Expand Down Expand Up @@ -294,16 +303,16 @@ Graphs.edges(model::Model) = collect(Graphs.edges(model.graph))
MetaGraphsNext.label_for(model::Model, node_id::Int) =
MetaGraphsNext.label_for(model.graph, node_id)

function retrieve_interface_position(interfaces, x::EdgeLabel, max_length::Int)
function retrieve_interface_position(interfaces::StaticInterfaces{I}, x::EdgeLabel, max_length::Int) where {I}
index = x.index === nothing ? 0 : x.index
position = findfirst(isequal(x.name), interfaces)
position = findfirst(isequal(x.name), I)
position =
position === nothing ?
begin
@warn(lazy"Interface $(x.name) not found in $interfaces")
@warn(lazy"Interface $(x.name) not found in $I")
0
end : position
return max_length * findfirst(isequal(x.name), interfaces) + index
return max_length * findfirst(isequal(x.name), I) + index
end

function __sortperm(model::Model, node::NodeLabel, edges::AbstractArray)
Expand Down Expand Up @@ -1016,8 +1025,8 @@ end
"""
Placeholder function that is defined for all Composite nodes and is invoked when inferring what interfaces are missing when a node is called
"""
interfaces(any_f, ::StaticInt{1}) = (:out,)
interfaces(any_f, any_val) = (:out, :in)
interfaces(any_f, ::StaticInt{1}) = StaticInterfaces((:out,))
interfaces(any_f, any_val) = StaticInterfaces((:out, :in))

"""
missing_interfaces(node_type, val, known_interfaces)
Expand All @@ -1032,22 +1041,22 @@ Returns the interfaces that are missing for a node. This is used when inferring
# Returns
- `missing_interfaces`: A `Vector` of the missing interfaces.
"""
function missing_interfaces(node_type, val::StaticInt{N} where {N}, known_interfaces)
all_interfaces = GraphPPL.interfaces(node_type, val)
missing_interfaces = Base.setdiff(all_interfaces, keys(known_interfaces))
return missing_interfaces
function missing_interfaces(fform, val, known_interfaces::NamedTuple)
return missing_interfaces(interfaces(fform, val), StaticInterfaces(keys(known_interfaces)))
end

function missing_interfaces(::StaticInterfaces{all_interfaces}, ::StaticInterfaces{present_interfaces}) where {all_interfaces, present_interfaces}
return StaticInterfaces(filter(interface -> interface present_interfaces, all_interfaces))
end

function prepare_interfaces(fform, lhs_interface, rhs_interfaces::NamedTuple)
missing_interface = GraphPPL.missing_interfaces(
fform,
static(length(rhs_interfaces) + 1),
rhs_interfaces,
)
@assert length(missing_interface) == 1 lazy"Expected only one missing interface, got $missing_interface of length $(length(missing_interface)) (node $fform with interfaces $(keys(rhs_interfaces)))))"
missing_interface = first(missing_interface)
# TODO check if we can construct NamedTuples a bit faster somewhere.
missing_interface = missing_interfaces(fform, static(length(rhs_interfaces)) + static(1), rhs_interfaces)
return prepare_interfaces(missing_interface, lhs_interface, rhs_interfaces)
end

function prepare_interfaces(::StaticInterfaces{I}, lhs_interface, rhs_interfaces::NamedTuple) where {I}
@assert length(I) == 1 lazy"Expected only one missing interface, got $I of length $(length(I)) (node $fform with interfaces $(keys(rhs_interfaces)))))"
missing_interface = first(I)
return NamedTuple{(missing_interface, keys(rhs_interfaces)...)}((
lhs_interface,
values(rhs_interfaces)...,
Expand Down
2 changes: 1 addition & 1 deletion src/model_macro.jl
Original file line number Diff line number Diff line change
Expand Up @@ -731,7 +731,7 @@ function get_boilerplate_functions(ms_name, ms_args, num_interfaces)
function $ms_name end
GraphPPL.interfaces(::typeof($ms_name), val) = error($error_msg * " $val keywords")
GraphPPL.interfaces(::typeof($ms_name), ::GraphPPL.StaticInt{$num_interfaces}) =
Tuple($ms_args)
GraphPPL.StaticInterfaces(Tuple($ms_args))
GraphPPL.NodeType(::typeof($ms_name)) = GraphPPL.Composite()
GraphPPL.NodeBehaviour(::typeof($ms_name)) = GraphPPL.Stochastic()
end
Expand Down
22 changes: 22 additions & 0 deletions src/resizable_array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,28 @@ function recursive_setindex!(
return nothing
end

function Base.isassigned(array::ResizableArray{T,V,N}, index::Integer...) where {T, V, N}
if length(index) !== N
return false
else
return recursive_isassigned(Val(N), array.data, index)
end
end

function recursive_isassigned(::Val{N}, array, indices) where {N}
findex = first(indices)
tindices = Base.tail(indices)
if isassigned(array, findex)
return recursive_isassigned(Val(N - 1), @inbounds(array[findex]), tindices)
else
return false
end
end

function recursive_isassigned(::Val{1}, array, index::Tuple{Integer})
return isassigned(array, first(index))
end

function getindex(array::ResizableArray{T,V,N}, index::UnitRange) where {T,V,N}
return ResizableArray(array.data[index])
end
Expand Down
20 changes: 10 additions & 10 deletions test/model_macro_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1362,25 +1362,25 @@ end
import GraphPPL: missing_interfaces, interfaces
function abc end

GraphPPL.interfaces(::typeof(abc), ::StaticInt{3}) = [:in1, :in2, :out]
GraphPPL.interfaces(::typeof(abc), ::StaticInt{3}) = GraphPPL.StaticInterfaces((:in1, :in2, :out))

@test missing_interfaces(abc, static(3), (in1 = :x, in2 = :y)) == [:out]
@test missing_interfaces(abc, static(3), (out = :y,)) == [:in1, :in2]
@test missing_interfaces(abc, static(3), Dict()) == [:in1, :in2, :out]
@test missing_interfaces(abc, static(3), (in1 = :x, in2 = :y)) == GraphPPL.StaticInterfaces((:out,))
@test missing_interfaces(abc, static(3), (out = :y,)) == GraphPPL.StaticInterfaces((:in1, :in2))
@test missing_interfaces(abc, static(3), NamedTuple()) == GraphPPL.StaticInterfaces((:in1, :in2, :out))

function xyz end

GraphPPL.interfaces(::typeof(xyz), ::StaticInt{0}) = []
@test missing_interfaces(xyz, static(0), (in1 = :x, in2 = :y)) == []
GraphPPL.interfaces(::typeof(xyz), ::StaticInt{0}) = GraphPPL.StaticInterfaces(())
@test missing_interfaces(xyz, static(0), (in1 = :x, in2 = :y)) == GraphPPL.StaticInterfaces(())

function foo end

GraphPPL.interfaces(::typeof(foo), ::StaticInt{2}) = (:a, :b)
@test missing_interfaces(foo, static(2), (a = 1, b = 2)) == []
GraphPPL.interfaces(::typeof(foo), ::StaticInt{2}) = GraphPPL.StaticInterfaces((:a, :b))
@test missing_interfaces(foo, static(2), (a = 1, b = 2)) == GraphPPL.StaticInterfaces(())

function bar end
GraphPPL.interfaces(::typeof(bar), ::StaticInt{2}) = (:in1, :in2, :out)
@test missing_interfaces(bar, static(2), (in1 = 1, in2 = 2, out = 3, test = 4)) == []
GraphPPL.interfaces(::typeof(bar), ::StaticInt{2}) = GraphPPL.StaticInterfaces((:in1, :in2, :out))
@test missing_interfaces(bar, static(2), (in1 = 1, in2 = 2, out = 3, test = 4)) == GraphPPL.StaticInterfaces(())
end

@testitem "keyword_expressions_to_named_tuple" begin
Expand Down
8 changes: 4 additions & 4 deletions test/model_zoo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ GraphPPL.NodeBehaviour(::Type{NormalMeanPrecision}) = GraphPPL.Stochastic()

GraphPPL.aliases(::Type{Normal}) = (Normal, NormalMeanVariance, NormalMeanPrecision)

GraphPPL.interfaces(::Type{NormalMeanVariance}, ::StaticInt{3}) = (:out, , )
GraphPPL.interfaces(::Type{NormalMeanPrecision}, ::StaticInt{3}) = (:out, , )
GraphPPL.interfaces(::Type{NormalMeanVariance}, ::StaticInt{3}) = GraphPPL.StaticInterfaces((:out, , ))
GraphPPL.interfaces(::Type{NormalMeanPrecision}, ::StaticInt{3}) = GraphPPL.StaticInterfaces((:out, , ))
GraphPPL.factor_alias(::Type{Normal}, ::Val{(:μ, :σ)}) = NormalMeanVariance
GraphPPL.factor_alias(::Type{Normal}, ::Val{(:μ, :τ)}) = NormalMeanPrecision

Expand All @@ -36,8 +36,8 @@ struct GammaShapeScale end

GraphPPL.aliases(::Type{Gamma}) = (Gamma, GammaShapeRate, GammaShapeScale)

GraphPPL.interfaces(::Type{GammaShapeRate}, ::StaticInt{3}) = (:out, , )
GraphPPL.interfaces(::Type{GammaShapeScale}, ::StaticInt{3}) = (:out, , )
GraphPPL.interfaces(::Type{GammaShapeRate}, ::StaticInt{3}) = GraphPPL.StaticInterfaces((:out, , ))
GraphPPL.interfaces(::Type{GammaShapeScale}, ::StaticInt{3}) = GraphPPL.StaticInterfaces((:out, , ))
GraphPPL.factor_alias(::Type{Gamma}, ::Val{(:α, :β)}) = GammaShapeRate
GraphPPL.factor_alias(::Type{Gamma}, ::Val{(:α, :θ)}) = GammaShapeScale

Expand Down
66 changes: 66 additions & 0 deletions test/resizable_array_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -191,3 +191,69 @@
end

end

@testitem "isassigned" begin
import GraphPPL: ResizableArray

@testset begin
s = ResizableArray(Ref, Val(1))

# In the beginning everything is not assigned
for i in 1:100
@test !isassigned(s, i)
end

# Assign some random indices
rindex = rand(1:100, 10)

for i in rindex
s[i] = Ref(1)
end

for i in 1:100
if i rindex
@test !isassigned(s, i)
else
@test isassigned(s, i)
end
end
end

@testset begin
for N in 1:5
s = ResizableArray(Ref, Val(N))

for j in 1:N
@test !isassigned(s, ones(Int, j)...)
end

s[ones(Int, N)...] = Ref(1)

@test isassigned(s, ones(Int, N)...)

s[10ones(Int, N)...] = Ref(1)

@test isassigned(s, 10ones(Int, N)...)

for k in 2:9
@test !isassigned(s, k * ones(Int, N)...)
end

end
end

@testset begin
for N in 1:5, M in 1:5
s = ResizableArray(Ref, Val(N))
indices = CartesianIndex(ones(Int, N)...):CartesianIndex(M * ones(Int, N)...)

for index in indices
@test !isassigned(s, index.I...)
s[index.I...] = Ref(1)
@test isassigned(s, index.I...)
end

end
end

end

0 comments on commit 7d59ed5

Please sign in to comment.