Commit 3290f9e

Implement Belief propagation as a new inference backend (#97)
* clean up
* update
* vars -> nvars
* update
* update
* update
* update
* implement marginals
* fix docs
* update
* Change julia version to 1.10 in the compat file
* add uai test
* format document and fix tests
* fix docstring
* clean up
1 parent f5aa742 commit 3290f9e

23 files changed: +440 -260 lines changed

Project.toml

+2-6
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,10 @@ authors = ["Jin-Guo Liu", "Martin Roa Villescas"]
44
version = "0.5.0"
55

66
[deps]
7-
Artifacts = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
87
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
98
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
109
OMEinsum = "ebe7aa44-baf0-506c-a96f-8464559b3922"
1110
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
12-
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
1311
PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
1412
ProblemReductions = "899c297d-f7d2-4ebf-8815-a35996def416"
1513
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
@@ -22,15 +20,13 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
2220
TensorInferenceCUDAExt = "CUDA"
2321

2422
[compat]
25-
Artifacts = "1"
2623
CUDA = "4, 5"
2724
DocStringExtensions = "0.8.6, 0.9"
2825
LinearAlgebra = "1"
29-
OMEinsum = "0.8"
26+
OMEinsum = "0.8.7"
3027
Pkg = "1"
31-
PrecompileTools = "1"
3228
PrettyTables = "2"
3329
ProblemReductions = "0.3"
3430
StatsBase = "0.34"
3531
TropicalNumbers = "0.5.4, 0.6"
36-
julia = "1.9"
32+
julia = "1.10"

docs/src/api/public.md

+4
@@ -43,6 +43,7 @@ RescaledArray
4343
TensorNetworkModel
4444
ArtifactProblemSpec
4545
UAIModel
46+
BeliefPropgation
4647
```
4748

4849
## Functions
@@ -56,6 +57,7 @@ marginals
5657
maximum_logp
5758
most_probable_config
5859
probability
60+
belief_propagate
5961
dataset_from_artifact
6062
problem_from_artifact
6163
read_model
@@ -69,4 +71,6 @@ sample
6971
update_evidence!
7072
update_temperature
7173
random_matrix_product_state
74+
random_matrix_product_uai
75+
random_tensor_train_uai
7276
```

docs/src/tensor-networks.md

+19
@@ -205,6 +205,13 @@ Some of these have been implemented in the
205205
[OMEinsum](https://github.com/under-Peter/OMEinsum.jl) package. Please check
206206
[Performance Tips](@ref) for more details.
207207

208+
## Belief propagation
209+
210+
Belief propagation[^Yedidia2003] is a message-passing algorithm for computing the marginals of a probabilistic graphical model. It is closely connected to tensor networks: it can be viewed as a way to gauge a tensor network[^Tindall2023], and it can be combined with tensor network methods to achieve better performance[^Wang2024].
211+
212+
Belief propagation is exact on tree-structured models but only approximate on graphs with loops. The quality of the approximation can be improved with a loop series expansion[^Evenbly2024].
213+
214+
208215
## References
209216

210217
[^Orus2014]:
@@ -227,3 +234,15 @@ Some of these have been implemented in the
227234

228235
[^Liu2023]:
229236
Liu J G, Gao X, Cain M, et al. Computing solution space properties of combinatorial optimization problems via generic tensor networks[J]. SIAM Journal on Scientific Computing, 2023, 45(3): A1239-A1270.
237+
238+
[^Yedidia2003]:
239+
Yedidia, J.S., Freeman, W.T., Weiss, Y., 2003. Understanding belief propagation and its generalizations, in: Exploring Artificial Intelligence in the New Millennium. Morgan Kaufmann Publishers Inc., San Francisco, CA, USA, pp. 239–269.
240+
241+
[^Wang2024]:
242+
Wang, Y., Zhang, Y.E., Pan, F., Zhang, P., 2024. Tensor Network Message Passing. Phys. Rev. Lett. 132, 117401. https://doi.org/10.1103/PhysRevLett.132.117401
243+
244+
[^Tindall2023]:
245+
Tindall, J., Fishman, M.T., 2023. Gauging tensor networks with belief propagation. SciPost Phys. 15, 222. https://doi.org/10.21468/SciPostPhys.15.6.222
246+
247+
[^Evenbly2024]:
248+
Evenbly, G., Pancotti, N., Milsted, A., Gray, J., Chan, G.K.-L., 2024. Loop Series Expansions for Tensor Networks. https://doi.org/10.48550/arXiv.2409.03108
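
The belief propagation paragraph added above describes message passing only in words. As a rough illustration, the sketch below runs sum-product message passing by hand on a three-variable chain and compares the result with brute-force marginals. It does not use the `belief_propagate` API from this commit; the factor values and all names (`f12`, `f23`, `normalize1`, and so on) are hypothetical. On a chain, which has no loops, the two results should agree.

```julia
# Sum-product belief propagation on the chain x1 - x2 - x3.
# Toy model with hypothetical factor values; every variable is binary.
f12 = [1.0 2.0; 3.0 1.0]   # factor over (x1, x2), indexed f12[x1, x2]
f23 = [2.0 1.0; 1.0 4.0]   # factor over (x2, x3), indexed f23[x2, x3]

normalize1(v) = v ./ sum(v)

# Leaf variables x1 and x3 send uniform (all-ones) messages, so the
# factor-to-variable messages into x2 are plain sums over the leaf index.
m_f12_to_x2 = vec(sum(f12; dims=1))   # sum over x1
m_f23_to_x2 = vec(sum(f23; dims=2))   # sum over x3

# x2 forwards to each factor the product of its other incoming messages,
# and each factor then sums x2 out to produce the message to its leaf.
m_f12_to_x1 = f12 * m_f23_to_x2       # sum over x2 of f12[x1, x2] * m(x2)
m_f23_to_x3 = f23' * m_f12_to_x2      # sum over x2 of f23[x2, x3] * m(x2)

# A belief is the normalized product of all messages arriving at a variable.
bp_marginals = [
    normalize1(m_f12_to_x1),
    normalize1(m_f12_to_x2 .* m_f23_to_x2),
    normalize1(m_f23_to_x3),
]

# Brute-force reference: build the full joint table and sum out variables.
joint = [f12[a, b] * f23[b, c] for a in 1:2, b in 1:2, c in 1:2]
Z = sum(joint)
exact = [
    vec(sum(joint; dims=(2, 3))) ./ Z,
    vec(sum(joint; dims=(1, 3))) ./ Z,
    vec(sum(joint; dims=(1, 2))) ./ Z,
]
```

On a graph with loops the same update rules would be iterated until the messages stop changing, and the beliefs would then only approximate the exact marginals.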

examples/hard-core-lattice-gas/main.jl

+1-1
@@ -62,7 +62,7 @@ mars = marginals(pmodel)
6262
show_graph(SimpleGraph(graph), sites; vertex_colors=[(b = mars[[i]][2]; (1-b, 1-b, 1-b)) for i in 1:nv(graph)], texts=fill("", nv(graph)))
6363
# One can see that the sites at the corners are more likely to be occupied.
6464
# To obtain two-site correlations, one can set the variables to query marginal probabilities manually.
65-
pmodel2 = TensorNetworkModel(problem, β; mars=[[e.src, e.dst] for e in edges(graph)])
65+
pmodel2 = TensorNetworkModel(problem, β; unity_tensors_labels = [[e.src, e.dst] for e in edges(graph)])
6666
mars = marginals(pmodel2);
6767

6868
# We show the probability that both sites on an edge are not occupied

ext/TensorInferenceCUDAExt.jl

+1-4
@@ -1,17 +1,14 @@
11
module TensorInferenceCUDAExt
22
using CUDA: CuArray
33
import CUDA
4-
import TensorInference: match_arraytype, keep_only!, onehot_like, togpu
4+
import TensorInference: keep_only!, onehot_like, togpu
55

66
function onehot_like(A::CuArray, j)
77
mask = zero(A)
88
CUDA.@allowscalar mask[j] = one(eltype(mask))
99
return mask
1010
end
1111

12-
# NOTE: this interface should be in OMEinsum
13-
match_arraytype(::Type{<:CuArray{T, N}}, target::AbstractArray{T, N}) where {T, N} = CuArray(target)
14-
1512
function keep_only!(x::CuArray{T}, j) where T
1613
CUDA.@allowscalar hotvalue = x[j]
1714
fill!(x, zero(T))

src/Core.jl

+19-82
@@ -45,18 +45,18 @@ $(TYPEDEF)
4545
Probabilistic modeling with a tensor network.
4646
4747
### Fields
48-
* `vars` are the degrees of freedom in the tensor network.
48+
* `nvars` is the number of variables in the tensor network.
4949
* `code` is the tensor network contraction pattern.
50-
* `tensors` are the tensors fed into the tensor network, the leading tensors are unity tensors associated with `mars`.
50+
* `tensors` are the tensors fed into the tensor network, the leading tensors are unity tensors associated with `unity_tensors_labels`.
5151
* `evidence` is a dictionary used to specify degrees of freedom that are fixed to certain values.
52-
* `mars` is a vector, each element is a vector of variables to compute marginal probabilities.
52+
* `unity_tensors_idx` is a vector of indices of the unity tensors in the `tensors` array. Unity tensors are dummy tensors used to obtain the marginal probabilities.
5353
"""
54-
struct TensorNetworkModel{LT, ET, MT <: AbstractArray}
55-
vars::Vector{LT}
54+
struct TensorNetworkModel{ET, MT <: AbstractArray}
55+
nvars::Int
5656
code::ET
5757
tensors::Vector{MT}
58-
evidence::Dict{LT, Int}
59-
mars::Vector{Vector{LT}}
58+
evidence::Dict{Int, Int}
59+
unity_tensors_idx::Vector{Int}
6060
end
6161

6262
"""
@@ -78,7 +78,7 @@ end
7878

7979
function Base.show(io::IO, tn::TensorNetworkModel)
8080
open = getiyv(tn.code)
81-
variables = join([string_var(var, open, tn.evidence) for var in tn.vars], ", ")
81+
variables = join([string_var(var, open, tn.evidence) for var in get_vars(tn)], ", ")
8282
tc, sc, rw = contraction_complexity(tn)
8383
println(io, "$(typeof(tn))")
8484
println(io, "variables: $variables")
@@ -110,102 +110,42 @@ $(TYPEDSIGNATURES)
110110
* `evidence` is a dictionary of evidence; the values are integers counting from 0.
111111
* `optimizer` is the tensor network contraction order optimizer, please check the package [`OMEinsumContractionOrders.jl`](https://github.com/TensorBFS/OMEinsumContractionOrders.jl) for available algorithms.
112112
* `simplifier` is a strategy for speeding up the `optimizer`; please refer to the same link above.
113-
* `mars` is a list of marginal probabilities. It is all single variables by default, i.e. `[[1], [2], ..., [n]]`. One can also specify multi-variables, which may increase the computational complexity.
113+
* `unity_tensors_labels` is a list of labels for the unity tensors. By default it contains each single variable, i.e. `[[1], [2], ..., [n]]`. One can also specify multi-variable labels, which may increase the computational complexity.
114114
"""
115115
function TensorNetworkModel(
116-
model::UAIModel;
116+
model::UAIModel{ET, FT};
117117
openvars = (),
118118
evidence = Dict{Int,Int}(),
119119
optimizer = GreedyMethod(),
120120
simplifier = nothing,
121-
mars = [[i] for i=1:model.nvars]
122-
)::TensorNetworkModel
123-
return TensorNetworkModel(
124-
1:(model.nvars),
125-
model.cards,
126-
model.factors;
127-
openvars,
128-
evidence,
129-
optimizer,
130-
simplifier,
131-
mars
132-
)
133-
end
134-
135-
"""
136-
$(TYPEDSIGNATURES)
137-
"""
138-
function TensorNetworkModel(
139-
vars::AbstractVector{LT},
140-
cards::AbstractVector{Int},
141-
factors::Vector{<:Factor{T}};
142-
openvars = (),
143-
evidence = Dict{LT, Int}(),
144-
optimizer = GreedyMethod(),
145-
simplifier = nothing,
146-
mars = [[v] for v in vars]
147-
)::TensorNetworkModel where {T, LT}
148-
# The 1st argument of `EinCode` is a vector of vector of labels for specifying the input tensors,
149-
# The 2nd argument of `EinCode` is a vector of labels for specifying the output tensor,
150-
# e.g.
151-
# `EinCode([[1, 2], [2, 3]], [1, 3])` is the EinCode for matrix multiplication.
152-
rawcode = EinCode([mars..., [[factor.vars...] for factor in factors]...], collect(LT, openvars)) # labels for vertex tensors (unity tensors) and edge tensors
153-
tensors = Array{T}[[ones(T, [cards[i] for i in mar]...) for mar in mars]..., [t.vals for t in factors]...]
154-
return TensorNetworkModel(collect(LT, vars), rawcode, tensors; evidence, optimizer, simplifier, mars)
155-
end
156-
157-
"""
158-
$(TYPEDSIGNATURES)
159-
"""
160-
function TensorNetworkModel(
161-
vars::AbstractVector{LT},
162-
rawcode::EinCode,
163-
tensors::Vector{<:AbstractArray};
164-
evidence = Dict{LT, Int}(),
165-
optimizer = GreedyMethod(),
166-
simplifier = nothing,
167-
mars = [[v] for v in vars]
168-
)::TensorNetworkModel where {LT}
121+
unity_tensors_labels = [[i] for i=1:model.nvars]
122+
) where {ET, FT}
169123
# `optimize_code` optimizes the contraction order of a raw tensor network without a contraction order specified.
170124
# The 1st argument is the contraction pattern to be optimized (without contraction order).
171125
# The 2nd argument is the size dictionary, which is a label-integer dictionary.
172126
# The 3rd and 4th arguments are the optimizer and simplifier that configures which algorithm to use and simplify.
127+
rawcode = EinCode([unity_tensors_labels..., [[factor.vars...] for factor in model.factors]...], collect(Int, openvars)) # labels for vertex tensors (unity tensors) and edge tensors
128+
tensors = Array{ET}[[ones(ET, [model.cards[i] for i in lb]...) for lb in unity_tensors_labels]..., [t.vals for t in model.factors]...]
173129
size_dict = OMEinsum.get_size_dict(getixsv(rawcode), tensors)
174130
code = optimize_code(rawcode, size_dict, optimizer, simplifier)
175-
TensorNetworkModel(collect(LT, vars), code, tensors, evidence, mars)
176-
end
177-
178-
"""
179-
$(TYPEDSIGNATURES)
180-
"""
181-
function TensorNetworkModel(
182-
model::UAIModel{T}, code;
183-
evidence = Dict{Int,Int}(),
184-
mars = [[i] for i=1:model.nvars],
185-
vars = [1:model.nvars...]
186-
)::TensorNetworkModel where{T}
187-
@debug "constructing tensor network model from code"
188-
tensors = Array{T}[[ones(T, [model.cards[i] for i in mar]...) for mar in mars]..., [t.vals for t in model.factors]...]
189-
190-
return TensorNetworkModel(vars, code, tensors, evidence, mars)
131+
return TensorNetworkModel(model.nvars, code, tensors, evidence, collect(Int, 1:length(unity_tensors_labels)))
191132
end
192133

193134
"""
194135
$(TYPEDSIGNATURES)
195136
196137
Get the variables in this tensor network; they are also known as legs, labels, or degrees of freedom.
197138
"""
198-
get_vars(tn::TensorNetworkModel)::Vector = tn.vars
139+
get_vars(tn::TensorNetworkModel)::Vector = 1:tn.nvars
199140

200141
"""
201142
$(TYPEDSIGNATURES)
202143
203-
Get the cardinalities of variables in this tensor network.
144+
Get the cardinalities of variables in this tensor network.
204145
"""
205146
function get_cards(tn::TensorNetworkModel; fixedisone = false)::Vector
206-
vars = get_vars(tn)
207147
size_dict = OMEinsum.get_size_dict(getixsv(tn.code), tn.tensors)
208-
[fixedisone && haskey(tn.evidence, vars[k]) ? 1 : size_dict[vars[k]] for k in eachindex(vars)]
148+
[fixedisone && haskey(tn.evidence, k) ? 1 : size_dict[k] for k in 1:tn.nvars]
209149
end
210150

211151
chevidence(tn::TensorNetworkModel, evidence) = TensorNetworkModel(tn.vars, tn.code, tn.tensors, evidence)
@@ -250,7 +190,4 @@ Returns the contraction complexity of a tensor network model.
250190
"""
251191
function OMEinsum.contraction_complexity(tn::TensorNetworkModel)
252192
return contraction_complexity(tn.code, Dict(zip(get_vars(tn), get_cards(tn; fixedisone = true))))
253-
end
254-
255-
# adapt array type with the target array type
256-
match_arraytype(::Type{<:Array{T, N}}, target::AbstractArray{T, N}) where {T, N} = Array(target)
193+
end
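
The `unity_tensors_idx` docstring in src/Core.jl above calls unity tensors "dummy tensors used to obtain the marginal probabilities" but the diff does not show how. One way to read it, assumed here and not confirmed by this diff, is that the contraction is linear in each all-ones tensor, so its derivative with respect to that tensor is the unnormalized marginal. The sketch below checks this on a two-variable toy network in plain Julia; `p1`, `p12`, `contract` and `onehot` are hypothetical names, not part of the package.

```julia
# Toy network over two binary variables (hypothetical values).
p1  = [0.3, 0.7]              # factor on x1
p12 = [0.9 0.1; 0.2 0.8]      # factor on (x1, x2), indexed p12[x1, x2]

# Contract the whole network to a scalar, with a dummy tensor `u` attached to x2.
contract(u) = sum(p1[a] * p12[a, b] * u[b] for a in 1:2, b in 1:2)

# With the all-ones (unity) tensor the contraction value is unchanged:
# it is just the partition function Z.
u = ones(2)
Z = contract(u)

# The contraction is linear in `u`, so its derivative with respect to u[b]
# equals the contraction with `u` replaced by a one-hot vector at b.
onehot(b) = [i == b ? 1.0 : 0.0 for i in 1:2]
grad = [contract(onehot(b)) for b in 1:2]

# Normalizing the gradient gives the marginal distribution of x2.
marginal_x2 = grad ./ Z
```

Under the same assumption, a label such as `[i, j]` in `unity_tensors_labels` would attach a rank-2 all-ones tensor and yield a joint marginal over the pair, which is consistent with the docstring's note that multi-variable labels may increase the computational complexity.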

src/RescaledArray.jl

+6-1
@@ -46,4 +46,9 @@ end
4646
Base.size(arr::RescaledArray) = size(arr.normalized_value)
4747
Base.size(arr::RescaledArray, i::Int) = size(arr.normalized_value, i)
4848

49-
match_arraytype(::Type{<:RescaledArray{T, N, AT}}, target::AbstractArray{T, N}) where {T, N, AT} = rescale_array(match_arraytype(AT, target))
49+
function OMEinsum.get_output_array(xs::NTuple{N, RescaledArray{T}}, size, fillzero::Bool) where {N, T}
50+
return RescaledArray(zero(T), OMEinsum.get_output_array(getfield.(xs, :normalized_value), size, fillzero))
51+
end
52+
# The following two APIs are required by OMEinsum
53+
Base.fill!(r::RescaledArray, x) = (fill!(r.normalized_value, x ./ exp(r.log_factor)); r)
54+
Base.conj(r::RescaledArray) = RescaledArray(conj(r.log_factor), conj(r.normalized_value))
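
The `fill!` method added above divides `x` by `exp(r.log_factor)` before storing it. The sketch below suggests why, using a hypothetical stand-in type `Scaled` (not the package's `RescaledArray`) that represents its values as `exp(log_factor) * normalized_value`, a representation inferred from the field names in the diff. The helpers `materialize`, `rescale`, `mul` and `fill_value` are illustrative only.

```julia
# Hypothetical stand-in for a rescaled array:
# represented value = exp(log_factor) * normalized_value
struct Scaled
    log_factor::Float64
    normalized_value::Vector{Float64}
end

# Recover the plain values; this can underflow when log_factor is very negative.
materialize(s::Scaled) = exp(s.log_factor) .* s.normalized_value

# Move the largest magnitude into log_factor so the stored entries are at most 1.
function rescale(v::Vector{Float64})
    m = maximum(abs, v)
    return Scaled(log(m), v ./ m)
end

# Element-wise product: log factors add, normalized parts multiply.
mul(a::Scaled, b::Scaled) =
    Scaled(a.log_factor + b.log_factor, a.normalized_value .* b.normalized_value)

# To make every represented entry equal to `x`, store x / exp(log_factor).
# This mirrors the division in the `fill!` method above.
fill_value(s::Scaled, x) =
    Scaled(s.log_factor, fill(x / exp(s.log_factor), length(s.normalized_value)))

a = rescale([1e-200, 2e-200])
b = rescale([3e-200, 4e-200])
c = mul(a, b)                  # plain Float64 products would underflow to 0.0
filled = fill_value(a, 5e-200)
```

Here `c` keeps the relative sizes of its entries in `normalized_value` and their overall scale in `log_factor`, even though `materialize(c)` underflows to zero in `Float64`.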

src/TensorInference.jl

+6-10
@@ -8,6 +8,7 @@ $(EXPORTS)
88
module TensorInference
99

1010
using OMEinsum, LinearAlgebra
11+
using OMEinsum: CacheTree, cached_einsum
1112
using DocStringExtensions, TropicalNumbers
1213
# The Tropical GEMM support
1314
using StatsBase
@@ -40,8 +41,11 @@ export MMAPModel
4041
# for ProblemReductions
4142
export update_temperature
4243

44+
# belief propagation
45+
export BeliefPropgation, belief_propagate
46+
4347
# utils
44-
export random_matrix_product_state
48+
export random_matrix_product_state, random_tensor_train_uai, random_matrix_product_uai
4549

4650
include("Core.jl")
4751
include("RescaledArray.jl")
@@ -51,14 +55,6 @@ include("map.jl")
5155
include("mmap.jl")
5256
include("sampling.jl")
5357
include("cspmodels.jl")
54-
55-
# import PrecompileTools
56-
# PrecompileTools.@setup_workload begin
57-
# # Putting some things in `@setup_workload` instead of `@compile_workload` can reduce the size of the
58-
# # precompile file and potentially make loading faster.
59-
# PrecompileTools.@compile_workload begin
60-
# include("../example/asia-network/main.jl")
61-
# end
62-
# end
58+
include("belief.jl")
6359

6460
end # module
