Skip to content

Commit

Permalink
add multilevel for Tucker layer
Browse files Browse the repository at this point in the history
  • Loading branch information
DexuanZhou committed Nov 26, 2023
1 parent ae1a6d7 commit b006acd
Show file tree
Hide file tree
Showing 10 changed files with 106 additions and 75 deletions.
5 changes: 3 additions & 2 deletions examples/3D/Be.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,10 @@ bYlm = RYlmBasis(Ylmdegree)

totdegree = [30,30,30]
ν = [1,1,2]
MaxIters = [100,100,2000]
MaxIters = [100,100,200]
_spec = [spec[1:3], spec, spec]
wf_list, spec_list, spec1p_list, specAO_list, ps_list, st_list = wf_multilevel(Nel, Σ, nuclei, Dn, Pn, bYlm, _spec, totdegree, ν)
_TD = [ACEpsi.Tucker(5),ACEpsi.Tucker(6),ACEpsi.Tucker(7)]
wf_list, spec_list, spec1p_list, specAO_list, ps_list, st_list = wf_multilevel(Nel, Σ, nuclei, Dn, Pn, bYlm, _spec, totdegree, ν, _TD)

ham = SumH(nuclei)
sam = MHSampler(wf_list[1], Nel, nuclei, Δt = 0.5, burnin = 1000, nchains = 2000)
Expand Down
5 changes: 2 additions & 3 deletions examples/3D/He.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,8 @@ totdegree = [30,30,30]
ν = [1,1,2]
MaxIters = [100,100,500]
_spec = [spec[1:3], spec[1:4], spec]
#_spec = [spec[1:i] for i = 4:length(spec)]
#_spec = length(ν)>length(spec) ? reduce(vcat, [_spec, [spec[1:end] for i = 1:length(ν) - length(spec)]]) : _spec
wf_list, spec_list, spec1p_list, specAO_list, ps_list, st_list = wf_multilevel(Nel, Σ, nuclei, Dn, Pn, bYlm, _spec, totdegree, ν)
_TD = [ACEpsi.Tucker(5),ACEpsi.Tucker(6),ACEpsi.Tucker(7)]
wf_list, spec_list, spec1p_list, specAO_list, ps_list, st_list = wf_multilevel(Nel, Σ, nuclei, Dn, Pn, bYlm, _spec, totdegree, ν, _TD)

ham = SumH(nuclei)
sam = MHSampler(wf_list[1], Nel, nuclei, Δt = 0.5, burnin = 1000, nchains = 2000)
Expand Down
3 changes: 2 additions & 1 deletion examples/3D/ccpvdz_H10.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,11 @@ bYlm = RYlmBasis(Ylmdegree)
totdegree = [30, 30, 30]
ν = [1, 1, 1]
MaxIters = [150, 200, 200]
_TD = [ACEpsi.No_Decomposition(),ACEpsi.No_Decomposition(),ACEpsi.No_Decomposition()]
spec = [(n1 = 1, n2 = 1, l = 0), (n1 = 1, n2 = 2, l = 0), (n1 = 2, n2 = 1, l = 1)]
_spec = [spec[1:i] for i = 1:length(spec)]
_spec = length(ν)>length(spec) ? reduce(vcat, [_spec, [spec[1:end] for i = 1:length(ν) - length(spec)]]) : _spec
wf_list, spec_list, spec1p_list, specAO_list, ps_list, st_list = wf_multilevel(Nel, Σ, nuclei, Dn, Pn, bYlm, _spec, totdegree, ν)
wf_list, spec_list, spec1p_list, specAO_list, ps_list, st_list = wf_multilevel(Nel, Σ, nuclei, Dn, Pn, bYlm, _spec, totdegree, ν, _TD)

ham = SumH(nuclei)
sam = MHSampler(wf_list[1], Nel, nuclei, Δt = 0.5, burnin = 2000, nchains = 2000)
Expand Down
3 changes: 2 additions & 1 deletion examples/3D/slaterbasis_H10.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ totdegree = [30, 30, 30, 30, 30]
MaxIters = [50, 100, 100, 200,200]
_spec = [spec[1:i] for i = 1:length(spec)]
_spec = length(ν)>length(spec) ? reduce(vcat, [_spec, [spec[1:end] for i = 1:length(ν) - length(spec)]]) : _spec
wf_list, spec_list, spec1p_list, specAO_list, ps_list, st_list = wf_multilevel(Nel, Σ, nuclei, Dn, Pn, bYlm, _spec, totdegree, ν)
_TD = [ACEpsi.Tucker(5),ACEpsi.Tucker(6),ACEpsi.Tucker(7),ACEpsi.Tucker(7),ACEpsi.Tucker(7)]
wf_list, spec_list, spec1p_list, specAO_list, ps_list, st_list = wf_multilevel(Nel, Σ, nuclei, Dn, Pn, bYlm, _spec, totdegree, ν, _TD)

ham = SumH(nuclei)
sam = MHSampler(wf_list[1], Nel, nuclei, Δt = 0.5, burnin = 1000, nchains = 2000)
Expand Down
58 changes: 58 additions & 0 deletions src/experimental/mmultilevel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -102,3 +102,61 @@ function mwf_multilevel(Nel::Int, Σ::Vector{Char}, nuclei::Vector{Nuc{T}}, Nbf:
end
return wf, spec, spec1p, _spec, ps, st
end


"""
    mEmbeddingW!(ps, ps2, spec, spec2, spec1p, spec1p2, specAO, specAO2)

Copy the parameters stored in `ps` into the (at least as large) parameter set `ps2`,
in place, and return `ps2`.

The entries of `displayspec(spec, spec1p)` must all occur in
`displayspec(spec2, spec1p2)`, and every entry of `specAO` must occur in `specAO2`;
both conditions are asserted. The following parts of `ps2.branch.bf` are written:

- `Pds[i].hidden1.W`: for every layer `i` present in `ps`, the matrix is zeroed and
  column `idx` of the `ps` matrix is copied to the column of `ps2` that holds the
  same spec entry. Layers that exist only in `ps2` are overwritten with the first
  layer of `ps`, after which their first row is reset to `[1, 0, …, 0]`.
- `hidden2.W` (only if `ps2` has it): copied verbatim when the column counts match;
  when `ps2` has more columns it is zeroed and `ps`'s array is placed in its leading
  block. If `ps2` has fewer columns nothing is copied.
- `ϕnlm.ζ` (only if `ps` has it): filled with `1.0`, then the entries matching
  `specAO` are overwritten with the values from `ps`.
- `TK.W` (only if `ps` has it): zeroed, then `ps`'s tensor is placed in the leading
  block along dimensions 3 and 5.

NOTE(review): the layouts of `ps`/`ps2` and the semantics of `displayspec`, `_invmap`
and `_invmapAO` are defined elsewhere; the description above states only what this
function reads and writes.
"""
function mEmbeddingW!(ps, ps2, spec, spec2, spec1p, spec1p2, specAO, specAO2)
    readable_spec = displayspec(spec, spec1p)
    readable_spec2 = displayspec(spec2, spec1p2)
    bf, bf2 = ps.branch.bf, ps2.branch.bf

    @assert size(bf.Pds.layer_1.hidden1.W, 1) == size(bf2.Pds.layer_1.hidden1.W, 1)
    # The target needs at least as many columns as the source to embed into it.
    @assert size(bf.Pds.layer_1.hidden1.W, 2) <= size(bf2.Pds.layer_1.hidden1.W, 2)
    @assert all(t in readable_spec2 for t in readable_spec)
    @assert all(t in specAO2 for t in specAO)

    # Zero the layers shared with `ps`; layers that exist only in `ps2` are fully
    # overwritten further down.
    for i in keys(bf.Pds)
        bf2.Pds[i].hidden1.W .= 0.0
    end

    if :hidden2 in keys(bf2)
        Wold, Wnew = bf.hidden2.W, bf2.hidden2.W
        if size(Wold, 2) == size(Wnew, 2)
            Wnew .= Wold
        elseif size(Wold, 2) < size(Wnew, 2)
            Wnew .= 0.0
            # Place the old array in the leading block. A linear slice
            # `Wnew[1:n] .= Wold` raises DimensionMismatch when `Wold` is a matrix.
            Wnew[axes(Wold)...] .= Wold
        end
    end

    # _map[t] = index of spec entry t in readable_spec2
    _map = _invmap(readable_spec2)
    _mapAO = _invmapAO(specAO2)

    # Embed layer by layer. The `isempty` guard keeps the behaviour of the former
    # spec-outer loop, which did nothing at all for an empty spec.
    if !isempty(readable_spec) && :hidden1 in keys(bf2.Pds.layer_1)
        for i in keys(bf2.Pds)
            Wdst = bf2.Pds[i].hidden1.W
            if i in keys(bf.Pds)
                Wsrc = bf.Pds[i].hidden1.W
                for (idx, t) in enumerate(readable_spec)
                    Wdst[:, _map[t]] = Wsrc[:, idx]
                end
            else
                # Layer exists only in `ps2`: initialise it once from the first
                # layer of `ps`, then reset its first row to the unit vector e₁.
                @assert size(bf2.Pds.layer_1.hidden1.W, 2) == size(Wdst, 2)
                Wdst .= bf.Pds[1].hidden1.W
                Wdst[1, 1] = 1.0
                Wdst[1, 2:end] .= 0.0
            end
        end
    end

    if :ϕnlm in keys(bf) && :ζ in keys(bf.ϕnlm)
        bf2.ϕnlm.ζ .= 1.0
        for (idx, t) in enumerate(specAO)
            bf2.ϕnlm.ζ[_mapAO[t]] = bf.ϕnlm.ζ[idx]
        end
    end

    # `ps.branch` has already been dereferenced above, so no extra `:branch` check.
    if :TK in keys(bf)
        Told, Tnew = bf.TK.W, bf2.TK.W
        Tnew .= 0
        Tnew[:, :, 1:size(Told, 3), :, 1:size(Told, 5)] .= Told
    end
    return ps2
end
88 changes: 27 additions & 61 deletions src/vmc/multilevel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ using Polynomials4ML
using Random
using ACEpsi: BackflowPooling, BFwf_lux, setupBFState, Jastrow, displayspec, mBFwf, mBFwf_sto
using ACEpsi.AtomicOrbitals: _invmap
using ACEpsi.TD: Tensor_Decomposition, No_Decomposition, Tucker
using Plots

mutable struct VMC_multilevel
Expand All @@ -31,52 +32,20 @@ end
function EmbeddingW!(ps, ps2, spec, spec2, spec1p, spec1p2, specAO, specAO2)
readable_spec = displayspec(spec, spec1p)
readable_spec2 = displayspec(spec2, spec1p2)
if :hidden1 in keys(ps.branch.bf)
@assert size(ps.branch.bf.hidden1.W, 1) == size(ps2.branch.bf.hidden1.W, 1)
@assert size(ps.branch.bf.hidden1.W, 2) size(ps2.branch.bf.hidden1.W, 2)
elseif :hidden1 in keys(ps.branch.bf.Pds.layer_1)
@assert size(ps.branch.bf.Pds.layer_1.hidden1.W, 1) == size(ps2.branch.bf.Pds.layer_1.hidden1.W, 1)
@assert size(ps.branch.bf.Pds.layer_1.hidden1.W, 2) size(ps2.branch.bf.Pds.layer_1.hidden1.W, 2)
end
@assert size(ps.branch.bf.hidden1.W, 1) == size(ps2.branch.bf.hidden1.W, 1)
@assert size(ps.branch.bf.hidden1.W, 2) size(ps2.branch.bf.hidden1.W, 2)
@assert all(t in readable_spec2 for t in readable_spec)
@assert all(t in specAO2 for t in specAO)

# set all parameters to zero
if :hidden1 in keys(ps2.branch.bf)
ps2.branch.bf.hidden1.W .= 0.0
elseif :hidden1 in keys(ps2.branch.bf.Pds.layer_1)
for i in keys(ps.branch.bf.Pds)
ps2.branch.bf.Pds[i].hidden1.W .= 0.0
end
end
if :hidden2 in keys(ps2.branch.bf)
if size(ps.branch.bf.hidden2.W, 2) == size(ps2.branch.bf.hidden2.W, 2)
ps2.branch.bf.hidden2.W .= ps.branch.bf.hidden2.W
elseif size(ps.branch.bf.hidden2.W, 2) < size(ps2.branch.bf.hidden2.W, 2)
ps2.branch.bf.hidden2.W .= 0.0
ps2.branch.bf.hidden2.W[1:size(ps.branch.bf.hidden2.W, 2)] .= ps.branch.bf.hidden2.W
end
end

ps2.branch.bf.hidden1.W .= 0.0

# _map[spect] = index in readable_spec2
_map = _invmap(readable_spec2)
_mapAO = _invmapAO(specAO2)
# embed
for (idx, t) in enumerate(readable_spec)
if :hidden1 in keys(ps2.branch.bf)
ps2.branch.bf.hidden1.W[:, _map[t]] = ps.branch.bf.hidden1.W[:, idx]
elseif :hidden1 in keys(ps2.branch.bf.Pds.layer_1)
for i in keys(ps2.branch.bf.Pds)
if i in keys(ps.branch.bf.Pds)
ps2.branch.bf.Pds[i].hidden1.W[:, _map[t]] = ps.branch.bf.Pds[i].hidden1.W[:, idx]
else
@assert size(ps2.branch.bf.Pds.layer_1.hidden1.W, 2) == size(ps2.branch.bf.Pds[i].hidden1.W, 2)
ps2.branch.bf.Pds[i].hidden1.W .= ps.branch.bf.Pds[1].hidden1.W
ps2.branch.bf.Pds[i].hidden1.W[1,1] = 1.0
ps2.branch.bf.Pds[i].hidden1.W[1,2:end] .= 0.0
end
end
end
ps2.branch.bf.hidden1.W[:, _map[t]] = ps.branch.bf.hidden1.W[:, idx]
end
if :ϕnlm in keys(ps.branch.bf)
if in keys(ps.branch.bf.ϕnlm)
Expand All @@ -87,22 +56,15 @@ function EmbeddingW!(ps, ps2, spec, spec2, spec1p, spec1p2, specAO, specAO2)
end
end

if :branch in keys(ps)
if :js in keys(ps.branch)
if :b in keys(ps.branch.js)
ps2.branch.js.b .= ps.branch.js.b
end
end
if :TK in keys(ps.branch.bf)
ps2.branch.bf.TK.W .= 0
ps2.branch.bf.TK.W[:,:,1:size(ps.branch.bf.TK.W)[3],1:size(ps.branch.bf.TK.W)[4]] .= ps.branch.bf.TK.W
end
if :TK in keys(ps.branch.bf)
ps2.branch.bf.TK.W .= 0
ps2.branch.bf.TK.W[:,:,1:size(ps.branch.bf.TK.W)[3],:,1:size(ps.branch.bf.TK.W)[5]] .= ps.branch.bf.TK.W
end
return ps2
end

function gd_GradientByVMC_multilevel(opt_vmc::VMC_multilevel, sam::MHSampler, ham::SumH, wf_list, ps_list, st_list, spec_list, spec1p_list, specAO_list;
verbose = true,
verbose = true, density = false,
accMCMC = [10, [0.45, 0.55]],
batch_size = 1)

Expand All @@ -121,9 +83,10 @@ function gd_GradientByVMC_multilevel(opt_vmc::VMC_multilevel, sam::MHSampler, ha

x0, ~, acc = sampler_restart(sam, ps, st, batch_size = batch_size)

x = reduce(vcat,reduce(vcat,x0))
display(histogram(x, xlim = (-10,10), ylim = (0,1), normalize=:pdf))

density && begin
x = reduce(vcat,reduce(vcat,x0))
display(histogram(x, xlim = (-10,10), ylim = (0,1), normalize=:pdf))
end

acc_step, acc_range = accMCMC
acc_opt = zeros(acc_step)
Expand Down Expand Up @@ -167,10 +130,13 @@ function gd_GradientByVMC_multilevel(opt_vmc::VMC_multilevel, sam::MHSampler, ha

# optimization
ps, acc, λ₀, res, σ, x0 = Optimization(opt_vmc.type, wf, ps, st, sam, ham, α, batch_size = batch_size)
if k % 10 == 0
x = reduce(vcat,reduce(vcat,x0))
display(histogram(x, xlim = (-10,10), ylim = (0,1), normalize=:pdf))
end
density && begin
if k % 10 == 0
x = reduce(vcat,reduce(vcat,x0))
display(histogram(x, xlim = (-10,10), ylim = (0,1), normalize=:pdf))
end
end

# err
verbose && @printf(" %3.d | %.5f | %.5f | %.5f | %.5f | %.3f | %.3f \n", k, λ₀, σ, res, α, acc, sam.Δt)
err_opt[l][k] = λ₀
Expand All @@ -194,12 +160,12 @@ function wf_multilevel(Nel::Int, Σ::Vector{Char}, nuclei::Vector{Nuc{T}},
bYlm::Union{RYlmBasis, CYlmBasis, CRlmBasis},
_spec::Vector{Vector{NamedTuple{(:n1, :n2, :l), Tuple{Int64, Int64, Int64}}}},
totdegree::Vector{Int},
ν::Vector{Int}) where {T}
ν::Vector{Int}, TD::Vector{TT}) where {T, TT<:Tensor_Decomposition}
level = length(ν)
wf, spec, spec1p, ps, st = [], [], [], [], []
for i = 1:level
bRnl = AtomicOrbitalsRadials(Pn, Dn, _spec[i])
_wf, _spec1, _spec1p = BFwf_lux(Nel, bRnl, bYlm, nuclei; totdeg = totdegree[i], ν = ν[i])
_wf, _spec1, _spec1p = BFwf_lux(Nel, bRnl, bYlm, nuclei, TD[i]; totdeg = totdegree[i], ν = ν[i])
_ps, _st = setupBFState(MersenneTwister(1234), _wf, Σ)
push!(wf, _wf)
push!(spec, _spec1)
Expand All @@ -217,14 +183,14 @@ function wf_multilevel(Nel::Int, Σ::Vector{Char}, nuclei::Vector{Nuc{T}},
bYlm::Union{RYlmBasis, CYlmBasis, CRlmBasis},
_spec::Vector{Vector{NamedTuple{(:n1, :n2, :l), Tuple{Int64, Int64, Int64}}}},
totdegree::Vector{Int},
ν::Vector{Int}) where {T}
ν::Vector{Int}, TD::Vector{TT}) where {T, TT<:Tensor_Decomposition}
level = length(ν)
wf, spec, spec1p, ps, st = [], [], [], [], []
for i = 1:level
ζ = ones(Float64,length(_spec[i]))
Dn = GaussianBasis(ζ)
bRnl = AtomicOrbitalsRadials(Pn, Dn, _spec[i])
_wf, _spec1, _spec1p = BFwf_lux(Nel, bRnl, bYlm, nuclei; totdeg = totdegree[i], ν = ν[i])
_wf, _spec1, _spec1p = BFwf_lux(Nel, bRnl, bYlm, nuclei, TD[i]; totdeg = totdegree[i], ν = ν[i])
_ps, _st = setupBFState(MersenneTwister(1234), _wf, Σ)
push!(wf, _wf)
push!(spec, _spec1)
Expand All @@ -242,14 +208,14 @@ function wf_multilevel(Nel::Int, Σ::Vector{Char}, nuclei::Vector{Nuc{T}},
bYlm::Union{RYlmBasis, CYlmBasis, CRlmBasis},
_spec::Vector{Vector{NamedTuple{(:n1, :n2, :l), Tuple{Int64, Int64, Int64}}}},
totdegree::Vector{Int},
ν::Vector{Int}) where {T}
ν::Vector{Int}, TD::Vector{TT}) where {T, TT<:Tensor_Decomposition}
level = length(ν)
wf, spec, spec1p, ps, st = [], [], [], [], []
for i = 1:level
ζ = ones(Float64,length(_spec[i]))
Dn = SlaterBasis(ζ)
bRnl = AtomicOrbitalsRadials(Pn, Dn, _spec[i])
_wf, _spec1, _spec1p = BFwf_lux(Nel, bRnl, bYlm, nuclei; totdeg = totdegree[i], ν = ν[i])
_wf, _spec1, _spec1p = BFwf_lux(Nel, bRnl, bYlm, nuclei, TD[i]; totdeg = totdegree[i], ν = ν[i])
_ps, _st = setupBFState(MersenneTwister(1234), _wf, Σ)
push!(wf, _wf)
push!(spec, _spec1)
Expand Down
19 changes: 12 additions & 7 deletions src/vmc/vmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,17 @@ VMC(MaxIter::Int, lr::Float64, type; tol = 1.0e-3, lr_dc = 50.0) = VMC(tol, MaxI

function gd_GradientByVMC(opt_vmc::VMC, sam::MHSampler, ham::SumH,
wf, ps, st;
ν = 1, verbose = true, accMCMC = [10, [0.45, 0.55]], batch_size = 1)
ν = 1, verbose = true, density = false, accMCMC = [10, [0.45, 0.55]], batch_size = 1)

res, λ₀, α = 1.0, 0., opt_vmc.lr
err_opt = zeros(opt_vmc.MaxIter)

x0, ~, acc = sampler_restart(sam, ps, st, batch_size = batch_size)
x = reduce(vcat,reduce(vcat,x0))
display(histogram(x, xlim = (-10,10), ylim = (0,1), normalize=:pdf))

density && begin
x = reduce(vcat,reduce(vcat,x0))
display(histogram(x, xlim = (-10,10), ylim = (0,1), normalize=:pdf))
end

acc_step, acc_range = accMCMC
acc_opt = zeros(acc_step)

Expand All @@ -42,10 +44,13 @@ function gd_GradientByVMC(opt_vmc::VMC, sam::MHSampler, ham::SumH,

# optimization
ps, acc, λ₀, res, σ, x0 = Optimization(opt_vmc.type, wf, ps, st, sam, ham, α, batch_size = batch_size)
if k % 10 == 0
x = reduce(vcat,reduce(vcat,x0))
display(histogram!(x, xlim = (-10,10), ylim = (0,1), normalize=:pdf))
density && begin
if k % 10 == 0
x = reduce(vcat,reduce(vcat,x0))
display(histogram!(x, xlim = (-10,10), ylim = (0,1), normalize=:pdf))
end
end

# err
verbose && @printf(" %3.d | %.5f | %.5f | %.5f | %.5f | %.3f | %.3f \n", k, λ₀, σ, res, α, acc, sam.Δt)
err_opt[k] = λ₀
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.

0 comments on commit b006acd

Please sign in to comment.