Skip to content

Commit

Permalink
Merge pull request #57 from TensorBFS/jg/port-gtn
Browse files Browse the repository at this point in the history
Port GenericTensorNetworks
  • Loading branch information
mroavi authored Aug 14, 2023
2 parents 2615430 + 3059164 commit b292dd4
Show file tree
Hide file tree
Showing 10 changed files with 154 additions and 14 deletions.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
name = "TensorInference"
uuid = "c2297e78-99bd-40ad-871d-f50e56b81012"
authors = ["Jin-Guo Liu", "Martin Roa Villescas"]
version = "0.2.0"
version = "0.2.1"

[deps]
Artifacts = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
GenericTensorNetworks = "3521c873-ad32-4bb4-b63d-f4f178f42b49"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
OMEinsum = "ebe7aa44-baf0-506c-a96f-8464559b3922"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Expand All @@ -18,6 +19,7 @@ TropicalNumbers = "b3a74e9c-7526-4576-a4eb-79c0d4c32334"
[compat]
CUDA = "4"
DocStringExtensions = "0.8.6, 0.9"
GenericTensorNetworks = "1"
OMEinsum = "0.7"
PrecompileTools = "1"
Requires = "1"
Expand Down
33 changes: 23 additions & 10 deletions src/Core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,16 @@ Probabilistic modeling with a tensor network.
### Fields
* `vars` are the degrees of freedom in the tensor network.
* `code` is the tensor network contraction pattern.
* `tensors` are the tensors fed into the tensor network.
* `tensors` are the tensors fed into the tensor network, the leading tensors are unity tensors associated with `mars`.
* `evidence` is a dictionary used to specify degrees of freedom that are fixed to certain values.
* `mars` is a vector, each element is a vector of variables to compute marginal probabilities.
"""
struct TensorNetworkModel{LT, ET, MT <: AbstractArray}
vars::Vector{LT}
code::ET
tensors::Vector{MT}
evidence::Dict{LT, Int}
mars::Vector{Vector{LT}}
end

function Base.show(io::IO, tn::TensorNetworkModel)
Expand Down Expand Up @@ -85,13 +87,21 @@ end

"""
$(TYPEDSIGNATURES)
### Keyword Arguments
* `openvars` is the list of variables that remains in the output. If it is not empty, the return value will be a nonzero ranked tensor.
* `evidence` is a dictionary of evidences, the values are integers start counting from 0.
* `optimizer` is the tensor network contraction order optimizer, please check the package [`OMEinsumContractionOrders.jl`](https://github.com/TensorBFS/OMEinsumContractionOrders.jl) for available algorithms.
* `simplifier` is some strategies for speeding up the `optimizer`, please refer the same link above.
* `mars` is a list of marginal probabilities. It is all single variables by default, i.e. `[[1], [2], ..., [n]]`. One can also specify multi-variables, which may increase the computational complexity.
"""
function TensorNetworkModel(
model::UAIModel;
openvars = (),
evidence = Dict{Int,Int}(),
optimizer = GreedyMethod(),
simplifier = nothing
simplifier = nothing,
mars = [[i] for i=1:model.nvars]
)::TensorNetworkModel
return TensorNetworkModel(
1:(model.nvars),
Expand All @@ -100,7 +110,8 @@ function TensorNetworkModel(
openvars,
evidence,
optimizer,
simplifier
simplifier,
mars
)
end

Expand All @@ -114,15 +125,16 @@ function TensorNetworkModel(
openvars = (),
evidence = Dict{LT, Int}(),
optimizer = GreedyMethod(),
simplifier = nothing
simplifier = nothing,
mars = [[v] for v in vars]
)::TensorNetworkModel where {T, LT}
# The 1st argument of `EinCode` is a vector of vector of labels for specifying the input tensors,
# The 2nd argument of `EinCode` is a vector of labels for specifying the output tensor,
# e.g.
# `EinCode([[1, 2], [2, 3]], [1, 3])` is the EinCode for matrix multiplication.
rawcode = EinCode([[[var] for var in vars]..., [[factor.vars...] for factor in factors]...], collect(LT, openvars)) # labels for vertex tensors (unity tensors) and edge tensors
tensors = Array{T}[[ones(T, cards[i]) for i in 1:length(vars)]..., [t.vals for t in factors]...]
return TensorNetworkModel(collect(LT, vars), rawcode, tensors; evidence, optimizer, simplifier)
rawcode = EinCode([mars..., [[factor.vars...] for factor in factors]...], collect(LT, openvars)) # labels for vertex tensors (unity tensors) and edge tensors
tensors = Array{T}[[ones(T, [cards[i] for i in mar]...) for mar in mars]..., [t.vals for t in factors]...]
return TensorNetworkModel(collect(LT, vars), rawcode, tensors; evidence, optimizer, simplifier, mars)
end

"""
Expand All @@ -134,15 +146,16 @@ function TensorNetworkModel(
tensors::Vector{<:AbstractArray};
evidence = Dict{LT, Int}(),
optimizer = GreedyMethod(),
simplifier = nothing
simplifier = nothing,
mars = [[v] for v in vars]
)::TensorNetworkModel where {LT}
# `optimize_code` optimizes the contraction order of a raw tensor network without a contraction order specified.
# The 1st argument is the contraction pattern to be optimized (without contraction order).
# The 2nd arugment is the size dictionary, which is a label-integer dictionary.
# The 3rd and 4th arguments are the optimizer and simplifier that configures which algorithm to use and simplify.
size_dict = OMEinsum.get_size_dict(getixsv(rawcode), tensors)
code = optimize_code(rawcode, size_dict, optimizer, simplifier)
TensorNetworkModel(collect(LT, vars), code, tensors, evidence)
TensorNetworkModel(collect(LT, vars), code, tensors, evidence, mars)
end

"""
Expand All @@ -159,7 +172,7 @@ Get the cardinalities of variables in this tensor network.
"""
function get_cards(tn::TensorNetworkModel; fixedisone = false)::Vector
vars = get_vars(tn)
[fixedisone && haskey(tn.evidence, vars[k]) ? 1 : length(tn.tensors[k]) for k in 1:length(vars)]
[fixedisone && haskey(tn.evidence, vars[k]) ? 1 : length(tn.tensors[k]) for k in eachindex(vars)]
end

chevidence(tn::TensorNetworkModel, evidence) = TensorNetworkModel(tn.vars, tn.code, tn.tensors, evidence)
Expand Down
1 change: 1 addition & 0 deletions src/TensorInference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ include("sampling.jl")
using Requires
function __init__()
@require CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" include("cuda.jl")
@require GenericTensorNetworks = "3521c873-ad32-4bb4-b63d-f4f178f42b49" include("generictensornetworks.jl")
end

# import PrecompileTools
Expand Down
44 changes: 44 additions & 0 deletions src/generictensornetworks.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
using .GenericTensorNetworks: generate_tensors, GraphProblem, flavors, labels

export probabilistic_model

"""
$TYPEDSIGNATURES
Convert a constraint satisfiability problem (or energy model) to a probabilistic model.
### Arguments
* `problem` is a `GraphProblem` instance in [`GenericTensorNetworks`](https://github.com/QuEraComputing/GenericTensorNetworks.jl).
* `β` is the inverse temperature.
"""
function TensorInference.TensorNetworkModel(problem::GraphProblem, β::Real; evidence::Dict=Dict{Int,Int}(),
optimizer=GreedyMethod(), simplifier=nothing, mars=[[l] for l in labels(problem)])
ixs = getixsv(problem.code)
iy = getiyv(problem.code)
lbs = labels(problem)
nflavors = length(flavors(problem))
# generate tensors for x = e^β
tensors = generate_tensors(exp(β), problem)
factors = [Factor((ix...,), t) for (ix, t) in zip(ixs, tensors)]
return TensorNetworkModel(lbs, fill(nflavors, length(lbs)), factors; openvars=iy, evidence, optimizer, simplifier, mars)
end
function TensorInference.MMAPModel(problem::GraphProblem, β::Real;
queryvars,
evidence = Dict{labeltype(problem.code), Int}(),
optimizer = GreedyMethod(), simplifier = nothing,
marginalize_optimizer = GreedyMethod(), marginalize_simplifier = nothing
)::MMAPModel
ixs = getixsv(problem.code)
iy = getiyv(problem.code)
nflavors = length(flavors(problem))
# generate tensors for x = e^β
tensors = generate_tensors(exp(β), problem)
factors = [Factor((ix...,), t) for (ix, t) in zip(ixs, tensors)]
lbs = labels(problem)
return MMAPModel(lbs, fill(nflavors, length(lbs)), factors; queryvars, openvars=iy, evidence,
optimizer, simplifier,
marginalize_optimizer, marginalize_simplifier)
end

@info "`TensorInference` loaded `GenericTensorNetworks` extension successfully,
`TensorNetworkModel` and `MMAPModel` can be used for converting a `GraphProblem` to a probabilistic model now."
2 changes: 2 additions & 0 deletions src/map.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ $(TYPEDSIGNATURES)
Returns the largest log-probability and the most probable configuration.
"""
function most_probable_config(tn::TensorNetworkModel; usecuda = false)::Tuple{Real, Vector}
expected_mars = [[l] for l in get_vars(tn)]
@assert tn.mars[1:length(expected_mars)] == expected_mars "To get the the most probable configuration, the leading elements of `tn.vars` must be `$expected_mars`"
vars = get_vars(tn)
tensors = map(t -> Tropical.(log.(t)), adapt_tensors(tn; usecuda, rescale = false))
logp, grads = cost_and_gradient(tn.code, tensors)
Expand Down
5 changes: 2 additions & 3 deletions src/mar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,13 +128,12 @@ Returns the marginal probability distribution of variables.
One can use `get_vars(tn)` to get the full list of variables in this tensor network.
"""
function marginals(tn::TensorNetworkModel; usecuda = false, rescale = true)::Vector
vars = get_vars(tn)
# sometimes, the cost can overflow, then we need to rescale the tensors during contraction.
cost, grads = cost_and_gradient(tn.code, adapt_tensors(tn; usecuda, rescale))
@debug "cost = $cost"
if rescale
return LinearAlgebra.normalize!.(getfield.(grads[1:length(vars)], :normalized_value), 1)
return LinearAlgebra.normalize!.(getfield.(grads[1:length(tn.mars)], :normalized_value), 1)
else
return LinearAlgebra.normalize!.(grads[1:length(vars)], 1)
return LinearAlgebra.normalize!.(grads[1:length(tn.mars)], 1)
end
end
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
[deps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
GenericTensorNetworks = "3521c873-ad32-4bb4-b63d-f4f178f42b49"
KaHyPar = "2a6221f6-aa48-11e9-3542-2d9e0ef01880"
OMEinsum = "ebe7aa44-baf0-506c-a96f-8464559b3922"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Expand Down
19 changes: 19 additions & 0 deletions test/generictensornetworks.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
using Test
using GenericTensorNetworks, TensorInference

@testset "marginals" begin
# compute the probability
β = 2.0
g = GenericTensorNetworks.Graphs.smallgraph(:petersen)
problem = IndependentSet(g)
model = TensorNetworkModel(problem, β; mars=[[2, 3]])
mars = marginals(model)[1]
problem2 = IndependentSet(g; openvertices=[2,3])
mars2 = TensorInference.normalize!(GenericTensorNetworks.solve(problem2, PartitionFunction(β)), 1)
@test mars mars2

# mmap
model = MMAPModel(problem, β; queryvars=[1,4])
logp, config = most_probable_config(model)
@test config == [0, 0]
end
55 changes: 55 additions & 0 deletions test/mar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,58 @@ end
end
end
end

@testset "joint marginal" begin
model = TensorInference.read_model_from_string("""MARKOV
8
2 2 2 2 2 2 2 2
8
1 0
2 1 0
1 2
2 3 2
2 4 2
3 5 3 1
2 6 5
3 7 5 4
2
0.01
0.99
4
0.05 0.01
0.95 0.99
2
0.5
0.5
4
0.1 0.01
0.9 0.99
4
0.6 0.3
0.4 0.7
8
1 1 1 0
0 0 0 1
4
0.98 0.05
0.02 0.95
8
0.9 0.7 0.8 0.1
0.1 0.3 0.2 0.9
""")
n = 10000
tnet = TensorNetworkModel(model; mars=[[2, 3], [3, 4]])
mars = marginals(tnet)
tnet23 = TensorNetworkModel(model; openvars=[2,3])
tnet34 = TensorNetworkModel(model; openvars=[3,4])
@test mars[1] probability(tnet23)
@test mars[2] probability(tnet34)
end
4 changes: 4 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ end
include("sampling.jl")
end

@testset "generic tensor networks" begin
include("generictensornetworks.jl")
end

using CUDA
if CUDA.functional()
include("cuda.jl")
Expand Down

0 comments on commit b292dd4

Please sign in to comment.