Skip to content

Commit

Permalink
Allow to specify RNG
Browse files Browse the repository at this point in the history
  • Loading branch information
Kolaru committed Jan 22, 2024
1 parent fb15d1e commit 5c1d591
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 5 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
PeriodicTable = "7b2266bf-644c-5ea3-82d8-af4bbd25a884"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
UnitfulAtomic = "a7773ee8-282e-5fa2-be4e-bd808c38a91a"
Expand Down
14 changes: 9 additions & 5 deletions src/NormalModes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ using Distributions
using LinearAlgebra
using PeriodicTable
using Printf
using Random
using StatsBase
using Unitful
using UnitfulAtomic
Expand Down Expand Up @@ -159,23 +160,26 @@ Perform ground state Wigner sampling according to the normal decomposition.
Return the deviation from the average geometry and the deviation from zero
momentum.
"""
function StatsBase.sample(nm::NormalDecomposition, n_samples)
function StatsBase.sample(rng::AbstractRNG, nm::NormalDecomposition, n_samples)
hbar = 1 # Atomic units
Δx_dist = MvNormal(Diagonal(1/2 * (hbar ./ nm.ωs)))
Δp_dist = MvNormal(Diagonal(1/2 * hbar * nm.ωs))
MU = nm.M * nm.U

Δx = MU * rand(Δx_dist, n_samples)
Δp = inv(nm.M)^2 * MU * rand(Δp_dist, n_samples)
Δx = MU * rand(rng, Δx_dist, n_samples)
Δp = inv(nm.M)^2 * MU * rand(rng, Δp_dist, n_samples)

return Δx * aunit(u"m"), Δp * aunit(u"kg*m/s")
end

function StatsBase.sample(nm::NormalDecomposition)
geometries, momenta = sample(nm, 1)
function StatsBase.sample(rng::AbstractRNG, nm::NormalDecomposition)
geometries, momenta = sample(rng, nm, 1)
return geometries[:, 1], momenta[:, 1]
end

StatsBase.sample(nm::NormalDecomposition, n_samples) = sample(Random.GLOBAL_RNG, nm, n_samples)
StatsBase.sample(nm::NormalDecomposition) = sample(Random.GLOBAL_RNG, nm)

function normal_from_positions(data, elems ; skip_modes = 6)
M = mass_weight_matrix(elems)
Uz = inv(M) * to_atomic_units(reshape(data, 3*length(elems), :))
Expand Down

0 comments on commit 5c1d591

Please sign in to comment.