-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathPotts.jl
203 lines (142 loc) · 4.24 KB
/
Potts.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
using StatsBase
using Distributions
using DelimitedFiles
using Combinatorics: with_replacement_combinations
import LinearAlgebra
import LinearAlgebra: diag
import Statistics: mean
# Hypergraph description of a Potts model. Each key of `terms` is a tuple of variable
# indices (a hyperedge); its value is an array indexed by the colours of those variables.
mutable struct FactorGraph{}
    order_list::Vector{Int64}   # presumably the interaction orders present — not used in the functions shown
    variable_count::Int         # number of variables n
    n_alphabets::Int            # alphabet size q (colours are 1:q)
    terms::Dict{Tuple,Array}    # hyperedge => coupling array
    # TODO: give hyperedge keys a more specific tuple type.
    # TODO: add a sanity check that an order-k hyperedge maps to a rank-k tensor
    #       (and optionally store variable names).
end
"""
    permutations(items, order::Int; asymmetric::Bool = false)

Return all length-`order` tuples built from `items`, sorted lexicographically.

* `asymmetric = false`: only strictly increasing tuples, i.e. the `order`-element
  combinations of `items` (no repeated value, each set listed once).
* `asymmetric = true`: every ordered tuple with repetition allowed, i.e. all
  `length(items)^order` tuples of the Cartesian power.

`order == 0` yields `[()]`. Throws `ArgumentError` if `order` is negative.
"""
permutations(items, order::Int; asymmetric::Bool = false) =
    sort(permutations([], items, order, asymmetric))

"""
    permutations(partial_perm::Vector{Any}, items, order::Int, asymmetric::Bool)

Recursive worker for `permutations(items, order; asymmetric)`: return every tuple
obtained by appending `order` more elements of `items` to `partial_perm`, in the
iteration order of `items` (unsorted). When `asymmetric` is `false`, an element is only
appended if it is strictly greater than the current last element.
"""
function permutations(partial_perm::Array{Any,1}, items, order::Int, asymmetric::Bool)
    # A negative order would otherwise recurse until the stack overflows.
    order < 0 && throw(ArgumentError("order must be non-negative, got $order"))
    order == 0 && return [tuple(partial_perm...)]
    perms = []
    for item in items
        # Symmetric mode keeps tuples strictly increasing so each combination appears once.
        if !asymmetric && !isempty(partial_perm) && partial_perm[end] >= item
            continue
        end
        extended = vcat(partial_perm, item)
        append!(perms, permutations(extended, items, order - 1, asymmetric))
    end
    return perms
end
"""
    raw_sampler_potts(H::FactorGraph, n_samples::Int, centered::Bool)

Draw `n_samples` configurations from the Gibbs distribution of `H` by exact enumeration.

All `q^n` configurations (`q = H.n_alphabets`, `n = H.variable_count`) are listed as
vectors with entries in `1:q`. Each one is weighted by
`exp(Energy_Potts(K, H, centered))`; note that the energy enters with a positive sign.
Samples are drawn with `StatsBase.wsample`. The cost is exponential in `n`, so this is
only suitable for small models.

Returns a vector of `n_samples` configurations.
"""
function raw_sampler_potts(H::FactorGraph, n_samples::Int, centered::Bool)
    n = H.variable_count
    q = H.n_alphabets
    configs = [digits(i, base = q, pad = n) .+ 1 for i in 0:(q^n - 1)]
    energies = [Energy_Potts(K, H, centered) for K in configs]
    # Subtracting the maximum energy leaves the normalised distribution unchanged but
    # prevents `exp` from overflowing to Inf, which would corrupt the sampling weights.
    weights = exp.(energies .- maximum(energies))
    return wsample(configs, weights, n_samples)
end
"""
    Energy_Potts(state::Array{Int64,1}, H::FactorGraph, cent::Bool)

Return the energy of configuration `state` under the factor graph `H`.

`state[j]` is the colour (an index into `1:H.n_alphabets`) of variable `j`. For every
hyperedge `e => theta` in `H.terms`:

* `cent == false`: adds `theta[state[e[1]], ..., state[e[end]]]`.
* `cent == true`: adds `sum_c a^m * b^(r - m) * theta[c...]` over all colour tuples `c`
  of length `r = length(e)`. Here `m` is the number of positions where `c` agrees with
  the colours of `e`, `a = 1 - 1/q` and `b = -1/q`. Equivalently, each position
  contributes the factor `δ(state[e[k]], c[k]) - 1/q`.
"""
function Energy_Potts(state::Array{Int64, 1}, H::FactorGraph, cent::Bool)
    q = H.n_alphabets
    E = 0.0
    if !cent
        for (e, theta) in H.terms
            E += theta[map(j -> state[j], e)...]
        end
        return E
    end
    a = 1.0 - (1.0 / q)   # factor for a position whose colour matches c
    b = -1.0 / q          # factor for a position whose colour differs from c
    # The colour tuples depend only on the interaction order, so build them once per order.
    keys_by_order = Dict{Int,Any}()
    for (e, theta) in H.terms
        r = length(e)
        alphabet_keys = get!(keys_by_order, r) do
            permutations(Array(1:q), r, asymmetric = true)
        end
        ct = map(j -> state[j], e)
        for c in alphabet_keys
            m = sum(ct .== c)
            E += a^m * b^(r - m) * theta[c...]
        end
    end
    return E
end
"""
    TVD(truth::Dict, est::Dict, n_samples::Int)

Total variation distance between two tallies of outcomes.

Returns `sum_k |est[k] - truth[k]| / (2 * n_samples)`. An outcome present in only one
dictionary contributes its value from that dictionary. NOTE(review): this presumes both
dictionaries hold nonnegative counts over `n_samples` draws — confirm at call sites.
"""
function TVD(truth::Dict{}, est::Dict{}, n_samples::Int)
    # Outcomes in the estimate: absolute gap, or the whole value if truth lacks the outcome.
    total = foldl(est; init = 0.0) do acc, (outcome, val)
        haskey(truth, outcome) ? acc + abs(val - truth[outcome]) : acc + val
    end
    # Outcomes that only the ground truth contains add their whole value.
    total = foldl(truth; init = total) do acc, (outcome, val)
        haskey(est, outcome) ? acc : acc + val
    end
    return total / (2 * n_samples)
end
"""
    conditional_energy(u::Int, state::Array{Int64,1}, H::FactorGraph, cent::Bool)

Return the part of the energy of `state` contributed by the hyperedges that contain
variable `u`. Hyperedges `e` of `H.terms` with `u ∉ e` are skipped. Otherwise the
per-hyperedge terms are the same as in the full energy:

* `cent == false`: adds `theta[state[e[1]], ..., state[e[end]]]`.
* `cent == true`: adds `sum_c a^m * b^(r - m) * theta[c...]` over all colour tuples `c`
  of length `r = length(e)`. Here `m` is the number of positions where `c` agrees with
  the colours of `e`, `a = 1 - 1/q` and `b = -1/q`.
"""
function conditional_energy(u::Int, state::Array{Int64, 1}, H::FactorGraph, cent::Bool)
    q = H.n_alphabets
    E = 0.0
    if !cent
        for (e, theta) in H.terms
            u in e || continue
            E += theta[map(j -> state[j], e)...]
        end
        return E
    end
    a = 1.0 - (1.0 / q)   # factor for a position whose colour matches c
    b = -1.0 / q          # factor for a position whose colour differs from c
    # The colour tuples depend only on the interaction order, so build them once per order.
    keys_by_order = Dict{Int,Any}()
    for (e, theta) in H.terms
        u in e || continue
        r = length(e)
        alphabet_keys = get!(keys_by_order, r) do
            permutations(Array(1:q), r, asymmetric = true)
        end
        ct = map(j -> state[j], e)
        for c in alphabet_keys
            m = sum(ct .== c)
            E += a^m * b^(r - m) * theta[c...]
        end
    end
    return E
end
"""
    pth_order_tensor(r)

Return the rank-`r` `Float64` array of size `2 × … × 2` whose entry at `(i_1, …, i_r)`
is `prod(2 * i_k - 3)`. In other words, each index 1 contributes a factor `-1` and each
index 2 a factor `+1`, so the entry is `-1.0` when an odd number of indices equal 1 and
`1.0` otherwise. For `r = 0` the result is a 0-dimensional array holding `1.0`, the
empty product.
"""
function pth_order_tensor(r)
    dims = Tuple(2 for _ in 1:r)
    M = zeros(dims)
    for I in CartesianIndices(M)
        # Parity form of prod(2i - 3); unlike prod over an empty Any[] it also works for r = 0.
        M[I] = isodd(count(==(1), Tuple(I))) ? -1.0 : 1.0
    end
    return M
end