Skip to content

Commit

Permalink
Add project per atoms
Browse files Browse the repository at this point in the history
  • Loading branch information
Kolaru committed Feb 22, 2024
1 parent d56421b commit 3a840c8
Showing 1 changed file with 17 additions and 14 deletions.
31 changes: 17 additions & 14 deletions src/NormalModes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ using Unitful
using UnitfulAtomic

export NormalDecomposition
export project, normal_modes, normal_mode, frequencies, wave_number, reduced_masses
export project, project_per_atom
export normal_modes, normal_mode, frequencies, wave_number, reduced_masses
export normal_from_positions
export sample

Expand All @@ -28,6 +29,7 @@ end

# TODO add the number of releveant mode somewhere
struct NormalDecomposition{T}
elements::Vector{Element}
ωs::Vector{T} # Angular frequencies
M::Diagonal{T, Vector{T}} # Inverse square root of the masses
U::Matrix{T} # Orthonormal modes
Expand Down Expand Up @@ -57,7 +59,7 @@ function NormalDecomposition(hessian::AbstractMatrix, elements ; skip_modes = 6)
ωs = ωs[perm]
U = U[:, perm]

return NormalDecomposition(ωs, M, U)
return NormalDecomposition(elements, ωs, M, U)
end

function Base.show(io::IO, nm::NormalDecomposition)
Expand Down Expand Up @@ -104,6 +106,19 @@ function project(nm::NormalDecomposition, geometries::AbstractArray)
return projector * reshape(geometries, size(nm.M, 1), :)
end

function project_per_atom(nm::NormalDecomposition, geometries)
n_atoms = length(nm.elements)
xx = reshape(geometries, 3, n_atoms, :)
uu = reshape(inv(nm.M) * nm.U, 3, n_atoms, :)

projections = zeros(n_atoms, size(nm.U, 2), size(xx, 3))
for (k, x) in enumerate(eachslice(xx ; dims = 3))
projections[:, :, k] = sum(x .* uu ; dims = 1)
end

return projections
end

"""
normal_modes(nm::NormalDecomposition)
Expand Down Expand Up @@ -180,16 +195,4 @@ end
StatsBase.sample(nm::NormalDecomposition, n_samples) = sample(Random.GLOBAL_RNG, nm, n_samples)
StatsBase.sample(nm::NormalDecomposition) = sample(Random.GLOBAL_RNG, nm)

function normal_from_positions(data, elems ; skip_modes = 6)
M = mass_weight_matrix(elems)
Uz = inv(M) * to_atomic_units(reshape(data, 3*length(elems), :))
MHM = cov(Uz')
σs, U = eigen(MHM)
perm = reverse(sortperm(σs)[(skip_modes + 1):end])
σs = σs[perm]
U = U[:, perm]

return NormalDecomposition(1 ./ 2σs, M, U)
end

end

0 comments on commit 3a840c8

Please sign in to comment.