Skip to content

Commit

Permalink
enable dynamic AA
Browse files Browse the repository at this point in the history
  • Loading branch information
ACEsuit committed Jun 19, 2024
1 parent 0482883 commit 193a184
Show file tree
Hide file tree
Showing 8 changed files with 208 additions and 136 deletions.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ version = "0.0.2"
ACE1 = "e3f9bc04-086e-409a-ba78-e9769fe067bb"
ACE1x = "5cc4c08c-8782-4a30-af6d-550b302e9707"
ACEbase = "14bae519-eb20-449c-a949-9c58ed33163e"
Bumper = "8ce10254-0962-460f-a3d8-1f77fea1446e"
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
ChunkSplitters = "ae650224-84b6-46f8-82ea-d812ca08434e"
DynamicPolynomials = "7c1d4256-1411-5781-91ec-d7bc3513ac07"
Expand All @@ -25,14 +26,15 @@ SpheriCart = "5caf2b29-02d9-47a3-9434-5931c85ba645"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StaticPolynomials = "62e018b1-6e46-5407-a5a7-97d4fbcae734"
StrideArrays = "d1fa6d79-ef01-42a6-86c9-f7c551f8593b"
WithAlloc = "fb1aa66a-603c-4c1d-9bc4-66947c7b08dd"

[compat]
ChunkSplitters = "2"
Folds = "0.2"
LoopVectorization = "0.12"
NamedTupleTools = "0.14"
ObjectPools = "0.3"
Polynomials4ML = "0.2.7"
Polynomials4ML = "0.2.11"
julia = "1"

[extras]
Expand Down
2 changes: 1 addition & 1 deletion profile/profile_evaluate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ pot = model.potential
mbpot = pot.components[2]

# convert to UFACE format
uf_ace = UltraFastACE.uface_from_ace1(mbpot; n_spl_points = 100)
uf_ace = UltraFastACE.uface_from_ace1(pot; n_spl_points = 100)

## ------------------------------------

Expand Down
4 changes: 4 additions & 0 deletions src/UltraFastACE.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ import ACEbase
import ACEbase: evaluate, evaluate!,
evaluate_ed, evaluate_ed!

using Bumper, WithAlloc, StrideArrays
import WithAlloc: whatalloc

_i2z(obj, i::Integer) = obj._i2z[i]

function _z2i(obj, Z)
Expand All @@ -28,6 +31,7 @@ include("auxiliary.jl")
include("pair.jl")

include("uface.jl")
include("uface_eval.jl")

include("julip_calculator.jl")

Expand Down
49 changes: 45 additions & 4 deletions src/auxiliary.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@


#
# some auxiliary functions for UF_ACE evaluation
#
Expand Down Expand Up @@ -35,11 +37,21 @@ _len_ylm(ybasis) = (_get_L(ybasis) + 1)^2

function evaluate_ylm(ace, Rs)
TF = eltype(eltype(Rs))
Zlm = acquire!(ace.pool, :Zlm, (length(Rs), _len_ylm(ace.ybasis)), TF)
Zlm = zeros(TF, length(Rs), _len_ylm(ace.ybasis))
evaluate_ylm!(Zlm, ace, Rs)
return Zlm
end

function evaluate_ylm!(Zlm, ace, Rs)
compute!(Zlm, ace.ybasis, Rs)
return Zlm
end

function whatalloc(::typeof(evaluate_ylm!), ace, Rs)
TF = eltype(eltype(Rs))
return (TF, length(Rs), _len_ylm(ace.ybasis))
end

function evaluate_ylm_ed(ace, Rs)
TF = eltype(eltype(Rs))
Zlm = acquire!(ace.pool, :Zlm, (length(Rs), _len_ylm(ace.ybasis)), TF)
Expand All @@ -52,11 +64,15 @@ end
# ------------------------------
# element embedding



function embed_z(ace, Rs, Zs)
TF = eltype(eltype(Rs))
Ez = acquire!(ace.pool, :Ez, (length(Zs), length(ace.rbasis)), TF)
# Ez = acquire!(ace.pool, :Ez, (length(Zs), length(ace.rbasis)), TF)
Ez = zeros(TF, length(Zs), length(ace.rbasis))
return embed_z!(Ez, ace, Rs, Zs)
end


function embed_z!(Ez, ace, Rs, Zs)
fill!(Ez, 0)
for (j, z) in enumerate(Zs)
iz = _z2i(ace.rbasis, z)
Expand All @@ -66,3 +82,28 @@ function embed_z(ace, Rs, Zs)
end


function whatalloc(::typeof(embed_z!), ace, Rs, Zs)
TF = eltype(eltype(Rs))
return (TF, length(Zs), length(ace.rbasis))
end


# ------------------------------
# aadot via P4ML

using LinearAlgebra: dot

struct AADot{T, TAA}
cc::Vector{T}
aabasis::TAA
end

function (aadot::AADot)(A)
@no_escape begin
AA = @alloc(eltype(A), length(aadot.aabasis))
P4ML.evaluate!(AA, aadot.aabasis, A)
out = dot(aadot.cc, AA)
end
return out
end

10 changes: 9 additions & 1 deletion src/splines.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,17 @@ end
function evaluate(ace, basis::SplineRadialsZ,
Rs::AbstractVector{<: SVector}, Zs::AbstractVector)
TF = eltype(eltype(Rs))
Rn = acquire!(ace.pool, :Rn, (length(Rs), length(basis)), TF)
Rn = zeros(TF, length(Rs), length(basis))
evaluate!(Rn, basis, Rs, Zs)
return Rn
end

function whatalloc(::typeof(ACEbase.evaluate!),
basis::SplineRadialsZ, Rs, Zs)
TF = eltype(eltype(Rs))
return (TF, length(Rs), length(basis))
end


function evaluate!(out, basis::SplineRadialsZ, Rs, Zs)
nX = length(Rs)
Expand All @@ -49,6 +55,8 @@ function evaluate!(out, basis::SplineRadialsZ, Rs, Zs)
end




function evaluate_ed(ace, rbasis, Rs, Zs)
TF = eltype(eltype(Rs))
Rn = acquire!(ace.pool, :Rn, (length(Rs), length(rbasis)), TF)
Expand Down
146 changes: 18 additions & 128 deletions src/uface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@ struct UFACE_inner{TR, TY, TA, TAA}
abasis::TA
aadot::TAA
# ---------- admin and meta-data
pool::TSafe{ArrayPool{FlexArrayCache}}
# pool::TSafe{ArrayPool{FlexArrayCache}}
meta::Dict
end

UFACE_inner(rbasis, ybasis, abasis, aadot) =
UFACE_inner(rbasis, ybasis, abasis, aadot,
TSafe(ArrayPool(FlexArrayCache)),
# TSafe(ArrayPool(FlexArrayCache)),
Dict())

struct UFACE{NZ, INNER, PAIR}
Expand All @@ -30,139 +30,17 @@ struct UFACE{NZ, INNER, PAIR}
pairpot::PAIR
E0s::Dict{Int, Float64}
# ----------
pool::TSafe{ArrayPool{FlexArrayCache}}
# pool::TSafe{ArrayPool{FlexArrayCache}}
meta::Dict
end

UFACE(_i2z, ace_inner, pairpot, E0s) =
UFACE(_i2z, ace_inner, pairpot, E0s,
TSafe(ArrayPool(FlexArrayCache)),
# TSafe(ArrayPool(FlexArrayCache)),
Dict())



function ACEbase.evaluate(ace::UFACE, Rs, Zs, zi)
i_zi = _z2i(ace, zi)
ace_inner = ace.ace_inner[i_zi]
Ei = ( evaluate(ace_inner, Rs, Zs) +
evaluate(ace.pairpot, Rs, Zs, zi) +
ace.E0s[zi] )
end

# --------------------------------------------------------
# UF_ACE evaluation code.

function ACEbase.evaluate(ace::UFACE_inner, Rs, Zs)
TF = eltype(eltype(Rs))
rbasis = ace.rbasis
NZ = length(rbasis._i2z)

# embeddings
# element embedding
Ez = embed_z(ace, Rs, Zs)
# radial embedding
Rn = evaluate(ace, rbasis, Rs, Zs)
# angular embedding
Zlm = evaluate_ylm(ace, Rs)

# pooling
A = ace.abasis((unwrap(Ez), unwrap(Rn), unwrap(Zlm)))

# n correlations
φ = ace.aadot(A)

# release the borrowed arrays
release!(Zlm)
release!(Rn)
release!(Ez)
release!(A)

return φ
end

function evaluate_ed(ace::UFACE, Rs, Zs, z0)
∇φ = acquire!(ace.pool, :out_dEs, (length(Rs),), eltype(Rs))
return evaluate_ed!(∇φ, ace, Rs, Zs, z0)
end

function evaluate_ed!(∇φ, ace::UFACE, Rs, Zs, z0)
i_z0 = _z2i(ace, z0)
ace_inner = ace.ace_inner[i_z0]
φ, _ = evaluate_ed!(∇φ, ace_inner, Rs, Zs)
add_evaluate_d!(∇φ, ace.pairpot, Rs, Zs, z0)
φ += ace.E0s[z0] + evaluate(ace.pairpot, Rs, Zs, z0)
return φ, ∇φ
end


function ACEbase.evaluate_ed!(∇φ, ace::UFACE_inner, Rs, Zs)
TF = eltype(eltype(Rs))
rbasis = ace.rbasis
NZ = length(rbasis._i2z)

# embeddings
# element embedding (there is no gradient)
Ez = embed_z(ace, Rs, Zs)

# radial embedding
Rn, dRn = evaluate_ed(ace, rbasis, Rs, Zs)

# angular embedding
Zlm, dZlm = evaluate_ylm_ed(ace, Rs)

# pooling
A = ace.abasis((Ez, Rn, Zlm))

# n correlations - compute with gradient, do it in-place
∂φ_∂A = acquire!(ace.pool, :∂A, size(A), TF)
φ = evaluate_and_gradient!(∂φ_∂A, ace.aadot, A)

# backprop through A => this part could be done more nicely I think
∂φ_∂Ez = BlackHole(TF)
# ∂φ_∂Ez = zeros(TF, size(Ez))
∂φ_∂Rn = acquire!(ace.pool, :∂Rn, size(Rn), TF)
∂φ_∂Zlm = acquire!(ace.pool, :∂Zlm, size(Zlm), TF)
fill!(∂φ_∂Rn, zero(TF))
fill!(∂φ_∂Zlm, zero(TF))
P4ML._pullback_evaluate!((∂φ_∂Ez, unwrap(∂φ_∂Rn), unwrap(∂φ_∂Zlm)),
unwrap(∂φ_∂A),
ace.abasis,
(unwrap(Ez), unwrap(Rn), unwrap(Zlm));
sizecheck=false)

# backprop through the embeddings
# depending on whether there is a bottleneck here, this can be
# potentially implemented more efficiently without needing writing/reading
# (to be investigated where the bottleneck is)

# we just ignore Ez (hence the black hole)

# backprop through Rn
# We already computed the gradients in the forward pass
fill!(∇φ, zero(SVector{3, TF}))
@inbounds for n = 1:size(Rn, 2)
@simd ivdep for j = 1:length(Rs)
∇φ[j] += ∂φ_∂Rn[j, n] * dRn[j, n]
end
end

# ... and Ylm
@inbounds for i_lm = 1:size(Zlm, 2)
@simd ivdep for j = 1:length(Rs)
∇φ[j] += ∂φ_∂Zlm[j, i_lm] * dZlm[j, i_lm]
end
end

# release the borrowed arrays
release!(Zlm); release!(dZlm)
release!(Rn); release!(dRn)
release!(Ez)
release!(A)
release!(∂φ_∂Rn); release!(∂φ_∂Zlm); release!(∂φ_∂A)

return φ, ∇φ
end


# ------------------------------------------------------
# transformation code : ACE1 -> UF_ACE models
Expand Down Expand Up @@ -196,12 +74,18 @@ end



function uface_from_ace1_inner(mbpot, iz; n_spl_points = 100)
function uface_from_ace1_inner(mbpot, iz;
n_spl_points = 100,
aa_static = :auto)
b1p = mbpot.pibasis.basis1p
zlist = b1p.zlist.list
z0 = zlist[iz]
t_zlist = tuple(zlist...)

if aa_static == :auto
aa_static = (length(mbpot.pibasis.inner[1]) < 1_200)
end

# radial embedding
Rn_basis = mbpot.pibasis.basis1p.J
LEN_Rn = length(Rn_basis.J)
Expand Down Expand Up @@ -258,7 +142,13 @@ function uface_from_ace1_inner(mbpot, iz; n_spl_points = 100)

# AA_basis = P4ML.SparseSymmProd(spec_AA_inds)
c_r_iz = AA_transform[:T]' * mbpot.coeffs[iz]
aadot = generate_AA_dot(spec_AA_inds, c_r_iz)

if aa_static
aadot = generate_AA_dot(spec_AA_inds, c_r_iz)
else
aabasis = P4ML.SparseSymmProd(spec_AA_inds)
aadot = AADot(c_r_iz, aabasis)
end

return UFACE_inner(rbasis_new, rYlm_basis_sc, A_basis, aadot)
end
Expand Down
Loading

0 comments on commit 193a184

Please sign in to comment.