diff --git a/src/ITensorNetworks.jl b/src/ITensorNetworks.jl index 56ab768f..8db1ae36 100644 --- a/src/ITensorNetworks.jl +++ b/src/ITensorNetworks.jl @@ -29,6 +29,7 @@ include("caches/abstractbeliefpropagationcache.jl") include("caches/beliefpropagationcache.jl") include("formnetworks/abstractformnetwork.jl") include("formnetworks/bilinearformnetwork.jl") +include("formnetworks/maximise_bilinearformnetwork.jl") include("formnetworks/quadraticformnetwork.jl") include("contraction_tree_to_graph.jl") include("gauging.jl") diff --git a/src/formnetworks/maximise_bilinearformnetwork.jl b/src/formnetworks/maximise_bilinearformnetwork.jl new file mode 100644 index 00000000..d72b4471 --- /dev/null +++ b/src/formnetworks/maximise_bilinearformnetwork.jl @@ -0,0 +1,133 @@ +using NamedGraphs.NamedGraphGenerators: named_grid +using NamedGraphs: NamedEdge +using ITensors: ITensors, ITensor, contract, dag +using Graphs: is_tree +using NamedGraphs.PartitionedGraphs: partitioned_graph, partitionedges, partitionvertex, partitionvertices +using NamedGraphs.GraphsExtensions: bfs_tree, leaf_vertices, post_order_dfs_edges, src, dst, vertices +using NDTensors: Algorithm +using Dictionaries +using LinearAlgebra: norm_sqr + +default_solver_algorithm() = "orthogonalize" +default_solver_kwargs() = (; niters = 25, nsites = 1, tolerance = 1e-10, normalize = true, maxdim = nothing, cutoff = nothing) + +#TODO: Come up with reasonable sequence for non-trees +function blf_update_sequence(g::AbstractGraph; nsites::Int64=1) + @assert is_tree(g) + if nsites == 1 || nsites == 2 + es = post_order_dfs_edges(g, first(leaf_vertices(g))) + vs = [[src(e), dst(e)] for e in es] + regions = nsites == 2 ? vs : [[v] for v in unique(reduce(vcat, vs))] + return vcat(regions, reverse(reverse.(regions))) + else + error("Nsites > 2 sequences not currently supported") + end +end + +#TODO: biorthogonal updater and gauging +function blf_updater(alg::Algorithm"orthogonalize", xAy_bpc::AbstractBeliefPropagationCache, y::AbstractITensorNetwork, prev_region::Vector, region::Vector) + path = gauge_path(y, prev_region, region) + y = gauge_walk(alg, y, path) + verts = unique(vcat(src.(path), dst.(path))) + factors = [dag(y[v]) for v in verts] + xAy_bpc = update_factors(xAy_bpc, Dictionary([(v, "ket") for v in verts], factors)) + pe_path = partitionedges(partitioned_tensornetwork(xAy_bpc), [NamedEdge((src(e), "ket") => (dst(e), "ket")) for e in path]) + xAy_bpc = update(Algorithm("bp"), xAy_bpc, pe_path; message_update_function_kwargs = (; normalize = false)) + return xAy_bpc, y +end + +function blf_extracter(xAy_bpc::AbstractBeliefPropagationCache, region::Vector) + return environment(xAy_bpc, [(v, "ket") for v in region]) +end + +function blf_inserter(∂xAy_bpc_∂r::Vector{ITensor}, xAy_bpc::AbstractBeliefPropagationCache, y::AbstractITensorNetwork, region::Vector; normalize, maxdim, cutoff) + yr = contract(∂xAy_bpc_∂r; sequence = "automatic") + if length(region) == 1 + v = only(region) + if normalize + yr /= sqrt(norm_sqr(yr)) + end + y[v] = yr + elseif length(region) == 2 + v1, v2 = first(region), last(region) + linds, cind = uniqueinds(y[v1], y[v2]), commonind(y[v1], y[v2]) + yv1, yv2 = factorize(yr, linds; ortho = "left", tags=tags(cind), cutoff, maxdim) + if normalize + yv2 /= sqrt(norm_sqr(yv2)) + end + y[v1], y[v2] = yv1, yv2 + else + error("Updates with regions bigger than 2 not currently supported") + end + vertices = [(v, "ket") for v in region] + factors = [y[v] for v in region] + xAy_bpc = update_factors(xAy_bpc, Dictionary(vertices, factors)) + return y, xAy_bpc +end + +function blf_costfunction(xAy::AbstractBeliefPropagationCache, region) + verts = [(v, "ket") for v in region] + return contract([environment(xAy, verts); factors(xAy, verts)]; sequence = "automatic")[] +end + +#Optimize over y to maximize * / based on a designated partitioning of the bilinearform +#For now, y should be a tree tensor network and should be a tree under the partitioning +function maximize_bilinearform( + alg::Algorithm"orthogonalize", + xAy::BilinearFormNetwork, + y::ITensorNetwork = dag(ket_network(xAy)), + partition_verts = group(v -> first(v), vertices(xAy)); + updater = blf_updater, + extracter = blf_extracter, + inserter = blf_inserter, + costfunction = blf_costfunction, + sequence = blf_update_sequence, + normalize::Bool = true, + niters::Int64 = 25, + nsites::Int64 = 1, + tolerance = nothing, + maxdim = nothing, + cutoff = nothing) + + #These assertions can easily be lessened in the future + @assert is_tree(y) + xAy_bpc = BeliefPropagationCache(xAy, partition_verts) + @assert is_tree(partitioned_graph(xAy_bpc)) + seq = sequence(y; nsites) + + prev_region = collect(vertices(y)) + cs = zeros(ComplexF64, (niters, length(seq))) + for i in 1:niters + for (j, region) in enumerate(seq) + xAy_bpc, y = updater(alg, xAy_bpc, y, prev_region, region) + ∂xAy_bpc_∂r = extracter(xAy_bpc, region) + y, xAy_bpc = inserter(∂xAy_bpc_∂r, xAy_bpc, y, region; normalize, maxdim, cutoff) + cs[i, j] = costfunction(xAy_bpc, region) + prev_region = region + end + if i >= 2 && (abs(sum(cs[i, :]) - sum(cs[i-1, :]))) / length(seq) <= tolerance + return xAy_bpc, dag(y) + end + end + + return xAy_bpc, dag(y) +end + +function Base.truncate(x::AbstractITensorNetwork; maxdim_init::Int64, kwargs...) + y = ITensorNetwork(v -> inds -> delta(inds), siteinds(x); link_space = maxdim_init) + xIy = BilinearFormNetwork(x, y) + xIy_bpc, y_out = maximize_bilinearform(xIy, y; kwargs...) + return y_out +end + +function ITensors.apply(A::AbstractITensorNetwork, x::AbstractITensorNetwork; maxdim_init::Int64, kwargs...) + y = ITensorNetwork(v -> inds -> delta(inds), siteinds(x); link_space = maxdim_init) + xAy = BilinearFormNetwork(A, x, y) + xAy_bpc, y_out = maximize_bilinearform(xAy, y; kwargs...) + return y_out +end + +function maximize_bilinearform(xAy::BilinearFormNetwork, args...; alg = default_solver_algorithm(), solver_kwargs = default_solver_kwargs()) + return maximize_bilinearform(Algorithm(alg), xAy, args...; solver_kwargs...) +end + diff --git a/src/formnetworks/maximise_bilinearformnetwork_V2.jl b/src/formnetworks/maximise_bilinearformnetwork_V2.jl new file mode 100644 index 00000000..5bef1e34 --- /dev/null +++ b/src/formnetworks/maximise_bilinearformnetwork_V2.jl @@ -0,0 +1,23 @@ +@kwdef mutable struct FittingProblem{State, OverlapNetwork} + state::State + overlapnetwork::OverlapNetwork + squared_scalar::Number = 0 +end + +squared_scalar(F::FittingProblem) = F.squared_scalar +state(F::FittingProblem) = F.state +overlapnetwork(F::FittingProblem) = F.overlapnetwork + +function set(F::FittingProblem; state = state(F), overlapnetwork = overlapnetwork(F), squared_scalar = squared_scalar(F)) + return FittingProblem(; state, linearformnetwork, squared_scalar) +end + +function fit_tensornetwork(tn::AbstractITensorNetwork, init_state::AbstractITensorNetwork, vertex_partitioning) + overlap_bpc = BeliefPropagationCache(inner_network(tn, init_state), vertex_partitioning) + init_prob = FittingProblem(; state = copy(init_state), overlapnetwork = overlap_bpc) + common_sweep_kwargs = (; nsites, outputlevel, updater_kwargs, inserter_kwargs) + kwargs_array = [(; common_sweep_kwargs..., sweep = s) for s in 1:nsweeps] + sweep_iter = sweep_iterator(init_prob, kwargs_array) + converged_prob = alternating_update(sweep_iter; outputlevel, kws...) + return squared_scalar(converged_prob), state(converged_prob) +end \ No newline at end of file diff --git a/test/test_maximisebilinearform.jl b/test/test_maximisebilinearform.jl new file mode 100644 index 00000000..4730dd57 --- /dev/null +++ b/test/test_maximisebilinearform.jl @@ -0,0 +1,50 @@ +@eval module $(gensym()) +using ITensorNetworks: BilinearFormNetwork, ITensorNetwork, random_tensornetwork, siteinds, subgraph, ttn, inner, truncate, maximize_bilinearform, union_all_inds +using ITensorNetworks.ModelHamiltonians: heisenberg +using Graphs: vertices +using NamedGraphs.NamedGraphGenerators: named_grid, named_comb_tree +using SplitApplyCombine: group +using StableRNGs: StableRNG +using TensorOperations: TensorOperations +using Test: @test, @test_broken, @testset +using ITensors: apply, dag, delta, prime + + +@testset "Maximise BilinearForm" for elt in ( + Float32, Float64, Complex{Float32}, Complex{Float64} + ) + begin + + rng = StableRNG(1234) + + g = named_comb_tree((3,2)) + s = siteinds("S=1/2", g) + + #One-site truncation + a = random_tensornetwork(rng, elt, s; link_space = 3) + b = truncate(a; maxdim_init = 3) + f = inner(a, b; alg = "exact") / sqrt(inner(a, a; alg = "exact") * inner(b, b; alg = "exact")) + @test f * conj(f) ≈ 1.0 atol = 10*eps(real(elt)) + + #Two-site truncation + a = random_tensornetwork(rng, elt, s; link_space = 3) + b = truncate(a; maxdim_init = 1, solver_kwargs= (; maxdim = 3, cutoff = 1e-16, nsites = 2, tolerance = 1e-8)) + f = inner(a, b; alg = "exact") / sqrt(inner(a, a; alg = "exact") * inner(b, b; alg = "exact")) + @test f * conj(f) ≈ 1.0 atol = 10*eps(real(elt)) + + #One-site apply (no normalization) + a = random_tensornetwork(rng, elt, s; link_space = 2) + H = ITensorNetwork(ttn(heisenberg(g), s)) + Ha = apply(H, a; maxdim_init = 4, solver_kwargs = (; niters = 20, nsites = 1, tolerance = 1e-8, normalize = false)) + @test inner(Ha, a; alg = "exact") / inner(a, H, a; alg = "exact") ≈ 1.0 atol = 10*eps(real(elt)) + + #Two-site apply (no normalization) + a = random_tensornetwork(rng, elt, s; link_space = 2) + H = ITensorNetwork(ttn(heisenberg(g), s)) + Ha = apply(H, a; maxdim_init = 1, solver_kwargs= (; maxdim = 4, cutoff = 1e-16, nsites = 2, tolerance = 1e-8, normalize = false)) + @test inner(Ha, a; alg = "exact") / inner(a, H, a; alg = "exact") ≈ 1.0 atol = 10*eps(real(elt)) + + end +end + +end