From 65f633b632030de2062de97a86761fa27eb72122 Mon Sep 17 00:00:00 2001 From: Ilian Pihlajamaa Date: Tue, 10 Sep 2024 11:20:39 +0200 Subject: [PATCH] bugfix --- Project.toml | 3 ++- src/Solvers.jl | 6 ++++-- src/Solvers/DensityRamp.jl | 19 +++++++++++++++++-- src/Solvers/NgIteration.jl | 17 ++++++++++++----- 4 files changed, 35 insertions(+), 10 deletions(-) diff --git a/Project.toml b/Project.toml index d3312a1..fbedaa4 100644 --- a/Project.toml +++ b/Project.toml @@ -14,13 +14,14 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Roots = "f2b01f46-fcfa-551c-844a-d8ac1e96c665" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" +UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228" [compat] -FunctionZeros = "0.3" Bessels = "0.2" Dierckx = "0.5" FFTW = "1" ForwardDiff = "0.10" +FunctionZeros = "0.3" Hankel = "0.5" Roots = "2.0" SpecialFunctions = "2" diff --git a/src/Solvers.jl b/src/Solvers.jl index abd21c2..40ee5a4 100644 --- a/src/Solvers.jl +++ b/src/Solvers.jl @@ -156,13 +156,15 @@ function DensityRamp(method, densities::AbstractVector{T}; verbose=true) where T return DensityRamp(method, densities, verbose) elseif T <: AbstractVector @assert issorted(sum.(densities)) + Ns = length(densities[1]) + @assert allequal(length.(densities)) densities = Diagonal.(SVector{Ns}.(densities)) return DensityRamp(method, densities, verbose) - elseif T <: AbstractMatrix + elseif T <: Diagonal @assert issorted(sum.(densities)) - @assert all(isdiag.(densities)) return DensityRamp(method, densities, verbose) end + error("Invalid type for densities.") end """ diff --git a/src/Solvers/DensityRamp.jl b/src/Solvers/DensityRamp.jl index 69c6d3c..66b7b61 100644 --- a/src/Solvers/DensityRamp.jl +++ b/src/Solvers/DensityRamp.jl @@ -1,3 +1,16 @@ +function recast_γ(γ::Array{T, 3}) where {T} + Ns = size(γ, 2) + Nk = size(γ, 1) + γ2 = zeros(SMatrix{Ns, Ns, T, Ns^2}, Nk) + for i = 1:Nk + γ2[i] = γ[i, :, :] + end + return γ2 +end + +recast_γ(γ::Array{T, 1}) where {T} = γ + + function solve(system::SimpleLiquid, closure::Closure, method::DensityRamp) densities = method.densities ρtarget = system.ρ @@ -14,7 +27,8 @@ function solve(system::SimpleLiquid, closure::Closure, method::DensityRamp) println("\nSolving the system at ρ = $(densities[i]).\n") end system.ρ = densities[i] - sol = solve(system, closure, method.method, init=γ_old) + γ_old2 = recast_γ(γ_old) + sol = solve(system, closure, method.method, init=γ_old2) push!(sols, sol) @. γ_old = sol.gr - one(eltype(sol.gr)) - sol.cr end @@ -22,7 +36,8 @@ function solve(system::SimpleLiquid, closure::Closure, method::DensityRamp) println("\nSolving the system at ρ = $(ρtarget).\n") end system.ρ = ρtarget - sol = solve(system, closure, method.method, init=γ_old) + γ_old2 = recast_γ(γ_old) + sol = solve(system, closure, method.method, init=γ_old2) push!(sols, sol) return sols end diff --git a/src/Solvers/NgIteration.jl b/src/Solvers/NgIteration.jl index f498c3b..b3cdf8d 100644 --- a/src/Solvers/NgIteration.jl +++ b/src/Solvers/NgIteration.jl @@ -122,21 +122,28 @@ function solve(system::SimpleLiquid{dims, species, T1, T2, P}, closure::Closure, Ĉ = copy(mayer_f) Γ_new = copy(mayer_f) + gn = initialize_vector_of_vectors(TT, N_stages+1, Ns*Ns*Nr) # first element is g_n, second is g_{n-1} etc fn = initialize_vector_of_vectors(TT, N_stages+1, Ns*Ns*Nr) dn = initialize_vector_of_vectors(TT, N_stages+1, Ns*Ns*Nr) d0n = initialize_vector_of_vectors(TT, N_stages, Ns*Ns*Nr) # first element is d01 second is d02 etc - if !(isnothing(init)) - fn[end] .= init.*r - else - fn[end] .= zero(eltype(eltype(fn))) - end + # the elements of fn etc are vectors of length Ns*Ns*Nr + + + A = zeros(TT, N_stages, N_stages) b = zeros(TT, N_stages) Γ_new_full = reshape(reinterpret(reshape, TT, Γ_new), Ns*Ns*Nr) # for going back and forth between vec{float} and vec{Smat} gn_red = reinterpret_vector_of_vectors(gn, T, Ns*Ns, Nr)#[reinterpret(reshape, T, reshape(gn[i], (Ns*Ns, Nr))) for i in eachindex(gn)] fn_red = reinterpret_vector_of_vectors(fn, T, Ns*Ns, Nr)#[reinterpret(reshape, T, reshape(fn[i], (Ns*Ns, Nr))) for i in eachindex(fn)] + if !(isnothing(init)) + fn_red[end] .= init.*r + else + @show zero(eltype(eltype(fn_red))) + fn_red[end] .= (zero(eltype(eltype(fn_red))), ) + end + max_iterations = method.max_iterations tolerance = method.tolerance