Skip to content

Commit

Permalink
Merge #152
Browse files Browse the repository at this point in the history
152: Rename "Tableau"s to "AlgorithmName"s r=charleskawczynski a=dennisYatunin



Co-authored-by: Dennis Yatunin <[email protected]>
Co-authored-by: Charles Kawczynski <[email protected]>
  • Loading branch information
3 people authored Jan 11, 2023
2 parents 5252f24 + a4fcddb commit 32551d3
Showing 15 changed files with 252 additions and 249 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ClimaTimeSteppers"
uuid = "595c0a79-7f3d-439a-bc5a-b232dc3bde79"
authors = ["Climate Modeling Alliance"]
version = "0.6.0"
version = "0.7.0"

[deps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
13 changes: 5 additions & 8 deletions docs/src/api/algorithms.md
Original file line number Diff line number Diff line change
@@ -8,26 +8,23 @@ CurrentModule = ClimaTimeSteppers

```@docs
ForwardEulerODEFunction
AbstractAlgorithmConstraint
Unconstrained
SSPConstrained
IMEXAlgorithm
IMEXTableau
```

## IMEX SSP methods

```@docs
IMEXSSPRKAlgorithm
SSP433
SSP222
SSP332
SSP333
SSP322
```

## IMEX ARK methods

```@docs
IMEXARKAlgorithm
IMEXARKTableau
```

## Low-Storage Runge--Kutta (LSRK) methods

Low-storage Runger--Kutta methods reduce the number stages that need to be stored.
2 changes: 1 addition & 1 deletion docs/src/dev/report_gen.jl
Original file line number Diff line number Diff line change
@@ -17,7 +17,7 @@ include(joinpath(cts_dir, "test", "problems.jl"))
tab3 = (ARS233, ARS343, ARS443, IMKG342a, IMKG343a, DBM453)
tabs = [tab1..., tab2..., tab3...]
tabs = map(t -> t(), tabs)
get_alg = (tab, test_case) -> IMEXARKAlgorithm(tab, NewtonsMethod(; max_iters = test_case.linear_implicit ? 1 : 2))
get_alg = (tab, test_case) -> IMEXAlgorithm(tab, NewtonsMethod(; max_iters = test_case.linear_implicit ? 1 : 2))
test_algs("IMEX ARK", get_alg, tabs, ark_analytic_nonlin_test_cts(Float64), 400)
test_algs("IMEX ARK", get_alg, tabs, ark_analytic_sys_test_cts(Float64), 60)
test_algs("IMEX ARK", get_alg, tabs, ark_analytic_test_cts(Float64), 16000; super_convergence = ARS121())
2 changes: 1 addition & 1 deletion docs/src/dev/types.md
Original file line number Diff line number Diff line change
@@ -19,5 +19,5 @@ import AbstractTrees as AT
import InteractiveUtils as IU
import ClimaTimeSteppers as CTS
AT.children(x::Type) = IU.subtypes(x)
AT.print_tree(CTS.AbstractTableau)
AT.print_tree(CTS.AbstractAlgorithmName)
```
2 changes: 1 addition & 1 deletion perf/benchmark.jl
Original file line number Diff line number Diff line change
@@ -5,7 +5,7 @@ using BenchmarkTools, DiffEqBase
include(joinpath(pkgdir(CTS), "test", "problems.jl"))

function main()
algorithm = CTS.IMEXARKAlgorithm(CTS.ARS343(), CTS.NewtonsMethod(; max_iters = 2))
algorithm = CTS.IMEXAlgorithm(CTS.ARS343(), CTS.NewtonsMethod(; max_iters = 2))
dt = 0.01
for problem in (split_linear_prob_wfact_split(), split_linear_prob_wfact_split_fe())
integrator = DiffEqBase.init(problem, algorithm; dt)
2 changes: 1 addition & 1 deletion perf/flame.jl
Original file line number Diff line number Diff line change
@@ -30,7 +30,7 @@ elseif problem_str == "fe"
else
error("Bad option")
end
algorithm = CTS.IMEXARKAlgorithm(ARS343(), NewtonsMethod(; max_iters = 2))
algorithm = CTS.IMEXAlgorithm(ARS343(), NewtonsMethod(; max_iters = 2))
dt = 0.01
integrator = DiffEqBase.init(prob, algorithm; dt)
cache = CTS.init_cache(prob, algorithm)
2 changes: 1 addition & 1 deletion perf/jet.jl
Original file line number Diff line number Diff line change
@@ -15,7 +15,7 @@ end
cts = joinpath(dirname(@__DIR__));
include(joinpath(cts, "test", "problems.jl"))
function config_integrators(problem)
algorithm = CTS.IMEXARKAlgorithm(ARS343(), NewtonsMethod(; max_iters = 2))
algorithm = CTS.IMEXAlgorithm(ARS343(), NewtonsMethod(; max_iters = 2))
dt = 0.01
integrator = DiffEqBase.init(problem, algorithm; dt)
integrator.cache = CTS.init_cache(problem, algorithm)
38 changes: 29 additions & 9 deletions src/ClimaTimeSteppers.jl
Original file line number Diff line number Diff line change
@@ -51,7 +51,7 @@ using LinearOperators
using StaticArrays
using CUDA

export AbstractIMEXARKTableau
export AbstractAlgorithmName, AbstractAlgorithmConstraint, Unconstrained, SSPConstrained

array_device(::Union{Array, SArray, MArray}) = CPU()
array_device(::CuArray) = CUDADevice()
@@ -66,18 +66,38 @@ include("operators.jl")

abstract type DistributedODEAlgorithm <: DiffEqBase.AbstractODEAlgorithm end

abstract type AbstractIMEXARKAlgorithm <: DistributedODEAlgorithm end
abstract type AbstractAlgorithmName end

abstract type AbstractTableau end
abstract type AbstractIMEXARKTableau <: AbstractTableau end
abstract type AbstractIMEXSSPARKTableau <: AbstractTableau end
"""
AbstractAlgorithmConstraint
A mechanism for restricting which operations can be performed by an algorithm
for solving ODEs.
For example, an unconstrained algorithm might compute a Runge-Kutta stage by
taking linear combinations of tendencies; i.e., by adding quantities of the form
`dt * tendency(state)`. On the other hand, a "strong stability preserving"
algorithm can only take linear combinations of "incremented states"; i.e., it
only adds quantities of the form `state + dt * coefficient * tendency(state)`.
"""
abstract type AbstractAlgorithmConstraint end

"""
Unconstrained
Indicates that an algorithm may perform any supported operations.
"""
struct Unconstrained <: AbstractAlgorithmConstraint end

"""
tableau(::DistributedODEAlgorithm)
SSPConstrained
Returns the tableau for a particular algorithm.
Indicates that an algorithm must be "strong stability preserving", which makes
it easier to guarantee that the algorithm will preserve monotonicity properties
satisfied by the initial state. For example, this ensures that the algorithm
will be able to use limiters in a mathematically consistent way.
"""
function tableau end
struct SSPConstrained <: AbstractAlgorithmConstraint end

SciMLBase.allowscomplex(alg::DistributedODEAlgorithm) = true
include("integrators.jl")
@@ -92,7 +112,7 @@ n_stages_ntuple(::Type{<:NTuple{Nstages}}) where {Nstages} = Nstages
n_stages_ntuple(::Type{<:SVector{Nstages}}) where {Nstages} = Nstages

# Include concrete implementations
include("solvers/imex_ark_tableaus.jl")
include("solvers/imex_tableaus.jl")
include("solvers/imex_ark.jl")
include("solvers/imex_ssp.jl")
include("solvers/multirate.jl")
45 changes: 13 additions & 32 deletions src/solvers/imex_ark.jl
Original file line number Diff line number Diff line change
@@ -1,31 +1,16 @@
export IMEXARKAlgorithm

"""
IMEXARKAlgorithm(
tabname::AbstractTableau,
newtons_method
) <: DistributedODEAlgorithm
A generic implementation of an IMEX ARK algorithm that can handle arbitrary
Butcher tableaus and problems specified using either `ForwardEulerODEFunction`s
or regular `ODEFunction`s.
"""
struct IMEXARKAlgorithm{T <: IMEXARKTableau, NM} <: DistributedODEAlgorithm
tab::T
newtons_method::NM
function IMEXARKAlgorithm(tabname::AbstractTableau, newtons_method)
tab = tableau(tabname)
T = typeof(tab)
new{T, typeof(newtons_method)}(tab, newtons_method)
end
end

has_jac(T_imp!) =
hasfield(typeof(T_imp!), :Wfact) &&
hasfield(typeof(T_imp!), :jac_prototype) &&
!isnothing(T_imp!.Wfact) &&
!isnothing(T_imp!.jac_prototype)

sdirk_error(name) = error("$(isnothing(name) ? "The given IMEXTableau" : name) \
has implicit stages with distinct coefficients (it \
is not SDIRK), and an update is required whenever a \
stage has a different coefficient from the previous \
stage. Do not update on the NewTimeStep signal when \
using $(isnothing(name) ? "this tableau" : name).")

struct IMEXARKCache{SCU, SCE, SCI, T, Γ, NMC}
U::SCU # sparse container of length s
T_lim::SCE # sparse container of length s
@@ -36,11 +21,11 @@ struct IMEXARKCache{SCU, SCE, SCI, T, Γ, NMC}
newtons_method_cache::NMC
end

function init_cache(prob::DiffEqBase.AbstractODEProblem, alg::IMEXARKAlgorithm; kwargs...)
function init_cache(prob::DiffEqBase.AbstractODEProblem, alg::IMEXAlgorithm{Unconstrained}; kwargs...)
(; u0, f) = prob
(; T_imp!) = f
(; tab, newtons_method) = alg
(; a_exp, b_exp, a_imp, b_imp) = tab
(; tableau, newtons_method) = alg
(; a_exp, b_exp, a_imp, b_imp) = tableau
s = length(b_exp)
inds = ntuple(i -> i, s)
inds_T_exp = filter(i -> !all(iszero, a_exp[:, i]) || !iszero(b_exp[i]), inds)
@@ -61,8 +46,8 @@ function step_u!(integrator, cache::IMEXARKCache)
(; u, p, t, dt, sol, alg) = integrator
(; f) = sol.prob
(; T_lim!, T_exp!, T_imp!, lim!, dss!, stage_callback!) = f
(; tab, newtons_method) = alg
(; a_exp, b_exp, a_imp, b_imp, c_exp, c_imp) = tab
(; name, tableau, newtons_method) = alg
(; a_exp, b_exp, a_imp, b_imp, c_exp, c_imp) = tableau
(; U, T_lim, T_exp, T_imp, temp, γ, newtons_method_cache) = cache
s = length(b_exp)

@@ -71,11 +56,7 @@ function step_u!(integrator, cache::IMEXARKCache)
newtons_method,
newtons_method_cache,
NewTimeStep(t),
jacobian ->
isnothing(γ) ? error("The tableau does not specify a unique value of γ for the \
duration of each time step; do not update based on the \
NewTimeStep signal when using this tableau.") :
T_imp!.Wfact(jacobian, u, p, dt * γ, t),
jacobian -> isnothing(γ) ? sdirk_error(name) : T_imp!.Wfact(jacobian, u, p, dt * γ, t),
)
end

96 changes: 28 additions & 68 deletions src/solvers/imex_ssp.jl
Original file line number Diff line number Diff line change
@@ -1,75 +1,21 @@
export IMEXSSPRKAlgorithm

#=
U[i] = (1 - β[i-1]) * u + β[i-1] * (U[i-1] + dt * T_exp(U[i-1])) for i > 1 ==>
U[2] = (1 - β[1]) * u + β[1] * (U[1] + dt * T_exp(U[1])) =
u + dt * β[1] * T_exp(U[1])
U[3] = (1 - β[2]) * u + β[2] * (U[2] + dt * T_exp(U[2])) =
u +
dt * β[1] * β[2] * T_exp(U[1]) +
dt * β[2] * T_exp(U[2])
U[4] = (1 - β[3]) * u + β[3] * (U[3] + dt * T_exp(U[3])) =
u +
dt * β[1] * β[2] * Β[3] * T_exp(U[1]) +
dt * β[2] * Β[3] * T_exp(U[2]) +
dt * Β[3] * T_exp(U[3])
...
U[i] = u + ∑_{j = 1}^{i - 1} dt * a_exp[i, j] * T_exp(U[j]) ==>
a_exp = [
0 0 0 …
β[1] 0 0 …
β[1] * β[2] β[2] 0 …
β[1] * β[2] * β[3] β[2] * β[3] β[3] …
⋮ ⋮ ⋮ ⋱
] ==>
a_exp[i+1:s+1, i] = cumprod(β[i:s])
=#

"""
IMEXSSPRKAlgorithm(
tabname::AbstractTableau,
newtons_method
) <: DistributedODEAlgorithm
A generic implementation of an IMEX SSP RK algorithm that can handle arbitrary
Butcher tableaus.
"""
struct IMEXSSPRKAlgorithm{B, T <: IMEXARKTableau, NM} <: DistributedODEAlgorithm
β::B
tab::T
newtons_method::NM
end
function IMEXSSPRKAlgorithm(tabname::AbstractTableau, newtons_method)
tab = tableau(tabname)
(; a_exp, b_exp) = tab
â_exp = vcat(a_exp, b_exp')
β = diag(â_exp, -1)
for i in 1:length(β)
if â_exp[(i + 1):end, i] != cumprod(β[i:end])
error("Tableau does not satisfy requirements for an SSP RK method")
end
end
IMEXSSPRKAlgorithm(β, tab, newtons_method)
end

struct IMEXSSPRKCache{U, SCI, Γ, NMC}
struct IMEXSSPRKCache{U, SCI, B, Γ, NMC}
U::U
U_exp::U
U_lim::U
T_lim::U
T_exp::U
T_imp::SCI # sparse container of length s
temp::U
β::B
γ::Γ
newtons_method_cache::NMC
end

function init_cache(prob::DiffEqBase.AbstractODEProblem, alg::IMEXSSPRKAlgorithm; kwargs...)
function init_cache(prob::DiffEqBase.AbstractODEProblem, alg::IMEXAlgorithm{SSPConstrained}; kwargs...)
(; u0, f) = prob
(; T_imp!) = f
(; tab, newtons_method) = alg
(; a_exp, b_exp, a_imp, b_imp) = tab
(; tableau, newtons_method) = alg
(; a_exp, b_exp, a_imp, b_imp) = tableau
s = length(b_exp)
inds = ntuple(i -> i, s)
inds_T_imp = filter(i -> !all(iszero, a_imp[:, i]) || !iszero(b_imp[i]), inds)
@@ -80,32 +26,46 @@ function init_cache(prob::DiffEqBase.AbstractODEProblem, alg::IMEXSSPRKAlgorithm
U_lim = similar(u0)
T_imp = SparseContainer(map(i -> similar(u0), collect(1:length(inds_T_imp))), inds_T_imp)
temp = similar(u0)
â_exp = vcat(a_exp, b_exp')
β = diag(â_exp, -1)
for i in 1:length(β)
if â_exp[(i + 1):end, i] != cumprod(β[i:end])
error("The SSPConstrained IMEXAlgorithm currently only supports an \
IMEXTableau that specifies a \"low-storage\" IMEX SSPRK \
algorithm, where the canonical Shu-Osher representation of \
the i-th explicit stage for i > 1 must have the form U[i] = \
(1 - β[i-1]) * u + β[i-1] * (U[i-1] + dt * T_exp(U[i-1])). \
So, it must be possible to express vcat(a_exp, b_exp') as\n \
0 0 0 …\n \
β[1] 0 0 …\n \
β[1] * β[2] β[2] 0 …\n \
β[1] * β[2] * β[3] β[2] * β[3] β[3] …\n \
⋮ ⋮ ⋮ ⋱\n \
The given IMEXTableau does not satisfy this property.")
end
end
γs = unique(filter(!iszero, diag(a_imp)))
γ = length(γs) == 1 ? γs[1] : nothing # TODO: This could just be a constant.
jac_prototype = has_jac(T_imp!) ? T_imp!.jac_prototype : nothing
newtons_method_cache = isnothing(T_imp!) ? nothing : allocate_cache(newtons_method, u0, jac_prototype)
return IMEXSSPRKCache(U, U_exp, U_lim, T_lim, T_exp, T_imp, temp, γ, newtons_method_cache)
return IMEXSSPRKCache(U, U_exp, U_lim, T_lim, T_exp, T_imp, temp, β, γ, newtons_method_cache)
end

function step_u!(integrator, cache::IMEXSSPRKCache)
(; u, p, t, dt, sol, alg) = integrator
(; f) = sol.prob
(; T_lim!, T_exp!, T_imp!, lim!, dss!, stage_callback!) = f
(; β, tab, newtons_method) = alg
(; a_imp, b_imp, c_exp, c_imp) = tab
(; U, U_lim, U_exp, T_lim, T_exp, T_imp, temp, γ, newtons_method_cache) = cache
(; name, tableau, newtons_method) = alg
(; a_imp, b_imp, c_exp, c_imp) = tableau
(; U, U_lim, U_exp, T_lim, T_exp, T_imp, temp, β, γ, newtons_method_cache) = cache
s = length(b_imp)

if !isnothing(T_imp!)
update!(
newtons_method,
newtons_method_cache,
NewTimeStep(t),
jacobian ->
isnothing(γ) ? error("The tableau does not specify a unique value of γ for the \
duration of each time step; do not update based on the \
NewTimeStep signal when using this tableau.") :
T_imp!.Wfact(jacobian, u, p, dt * γ, t),
jacobian -> isnothing(γ) ? sdirk_error(name) : T_imp!.Wfact(jacobian, u, p, dt * γ, t),
)
end

Loading

2 comments on commit 32551d3

@charleskawczynski
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/75499

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.7.0 -m "<description of version>" 32551d3c8f56b95a41a16005cc34a73a88e5a380
git push origin v0.7.0

Please sign in to comment.