Skip to content

Commit

Permalink
Time Discretisation Improvements (#560)
Browse files Browse the repository at this point in the history
  • Loading branch information
atb1995 authored Oct 29, 2024
1 parent 2e376b3 commit 2a929f8
Show file tree
Hide file tree
Showing 8 changed files with 153 additions and 63 deletions.
1 change: 1 addition & 0 deletions gusto/core/labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ def __call__(self, target, value=None):

# labels for terms in the equations
time_derivative = Label("time_derivative")
nonlinear_time_derivative = Label("nonlinear_time_derivative")
transport = Label("transport",
validator=lambda value: type(value) == TransportEquationType)
diffusion = Label("diffusion")
Expand Down
7 changes: 4 additions & 3 deletions gusto/equations/prognostic_equations.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
replace_subject, replace_trial_function
)
from gusto.core import PrescribedFields
from gusto.core.labels import time_derivative, prognostic, linearisation, mass_weighted
from gusto.core.labels import (nonlinear_time_derivative, time_derivative,
prognostic, linearisation, mass_weighted)
from gusto.equations.common_forms import (
advection_form, continuity_form, tracer_conservative_form
)
Expand Down Expand Up @@ -163,8 +164,8 @@ def generate_mass_terms(self):
ref_density_idx = self.field_names.index(self.active_tracers[j].density_name)
ref_density = split(self.X)[ref_density_idx]
q = prog*ref_density
mass_weighted_form = time_derivative(subject(prognostic(inner(q, test)*dx,
field_name), self.X))
mass_weighted_form = nonlinear_time_derivative(time_derivative(
subject(prognostic(inner(q, test)*dx, field_name), self.X)))

mass = mass_weighted(standard_mass_form, mass_weighted_form)
if i == 0:
Expand Down
19 changes: 10 additions & 9 deletions gusto/time_discretisation/explicit_runge_kutta.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,9 +160,6 @@ def solver(self):
return super().solver

elif self.rk_formulation == RungeKuttaFormulation.predictor:
# In this case, don't set snes_type to ksp only, as we do want the
# outer Newton iteration. This is achieved by not calling the
# "super" method, in which the default snes_type is set to ksp_only
solver_list = []

for stage in range(self.nStages):
Expand All @@ -180,9 +177,6 @@ def solver(self):
return solver_list

elif self.rk_formulation == RungeKuttaFormulation.linear:
# In this case, don't set snes_type to ksp only, as we do want the
# outer Newton iteration. This is achieved by not calling the
# "super" method, in which the default snes_type is set to ksp_only
problem = NonlinearVariationalProblem(
self.lhs - self.rhs[0], self.x1, bcs=self.bcs
)
Expand Down Expand Up @@ -358,6 +352,10 @@ def solve_stage(self, x0, stage):
evaluate(self.x1, self.dt)
if self.limiter is not None:
self.limiter.apply(self.x1)

# Set initial guess for solver
if stage > 0:
self.x_out.assign(self.k[stage-1])
self.solver.solve()

self.k[stage].assign(self.x_out)
Expand All @@ -376,8 +374,8 @@ def solve_stage(self, x0, stage):
if stage == 0:
self.field_i[0].assign(x0)

# Use x0 as a first guess (otherwise may not converge)
self.field_i[stage+1].assign(x0)
# Use previous stage value as a first guess (otherwise may not converge)
self.field_i[stage+1].assign(self.field_i[stage])

# Update field_i for physics / limiters
for evaluate in self.evaluate_source:
Expand Down Expand Up @@ -423,6 +421,8 @@ def solve_stage(self, x0, stage):
if self.limiter is not None:
self.limiter.apply(self.field_rhs)

# Use previous stage value as a first guess (otherwise may not converge)
self.x1.assign(self.field_lhs[cycle_stage])
# Solve problem, placing solution in self.x1
self.solver[0].solve()

Expand All @@ -445,7 +445,8 @@ def solve_stage(self, x0, stage):
evaluate(self.field_rhs, self.original_dt)
if self.limiter is not None:
self.limiter.apply(self.field_rhs)

# Use x0 as a first guess (otherwise may not converge)
self.x1.assign(x0)
# Solve problem, placing solution in self.x1
self.solver[1].solve()

Expand Down
107 changes: 81 additions & 26 deletions gusto/time_discretisation/imex_runge_kutta.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ class IMEXRungeKutta(TimeDiscretisation):
# --------------------------------------------------------------------------

def __init__(self, domain, butcher_imp, butcher_exp, field_name=None,
solver_parameters=None, limiter=None, options=None):
linear_solver_parameters=None, nonlinear_solver_parameters=None,
limiter=None, options=None):
"""
Args:
domain (:class:`Domain`): the model's domain object, containing the
Expand All @@ -73,20 +74,39 @@ def __init__(self, domain, butcher_imp, butcher_exp, field_name=None,
Runge Kutta time discretisation.
field_name (str, optional): name of the field to be evolved.
Defaults to None.
solver_parameters (dict, optional): dictionary of parameters to
pass to the underlying solver. Defaults to None.
linear_solver_parameters (dict, optional): dictionary of parameters to
pass to the underlying linear solver. Defaults to None.
nonlinear_solver_parameters (dict, optional): dictionary of parameters to
pass to the underlying nonlinear solver. Defaults to None.
options (:class:`AdvectionOptions`, optional): an object containing
options to either be passed to the spatial discretisation, or
to control the "wrapper" methods, such as Embedded DG or a
recovery method. Defaults to None.
"""
super().__init__(domain, field_name=field_name,
solver_parameters=solver_parameters,
solver_parameters=nonlinear_solver_parameters,
options=options)
self.butcher_imp = butcher_imp
self.butcher_exp = butcher_exp
self.nStages = int(np.shape(self.butcher_imp)[1])

# Set default linear and nonlinear solver options if none passed in
if linear_solver_parameters is None:
self.linear_solver_parameters = {'snes_type': 'ksponly',
'ksp_type': 'cg',
'pc_type': 'bjacobi',
'sub_pc_type': 'ilu'}
else:
self.linear_solver_parameters = linear_solver_parameters

if nonlinear_solver_parameters is None:
self.nonlinear_solver_parameters = {'snes_type': 'newtonls',
'ksp_type': 'gmres',
'pc_type': 'bjacobi',
'sub_pc_type': 'ilu'}
else:
self.nonlinear_solver_parameters = nonlinear_solver_parameters

def setup(self, equation, apply_bcs=True, *active_labels):
"""
Set up the time discretisation based on the equation.
Expand Down Expand Up @@ -200,7 +220,7 @@ def solvers(self):
# setup solver using residual defined in derived class
problem = NonlinearVariationalProblem(self.res(stage), self.x_out, bcs=self.bcs)
solver_name = self.field_name+self.__class__.__name__ + "%s" % (stage)
solvers.append(NonlinearVariationalSolver(problem, solver_parameters=self.solver_parameters, options_prefix=solver_name))
solvers.append(NonlinearVariationalSolver(problem, solver_parameters=self.nonlinear_solver_parameters, options_prefix=solver_name))
return solvers

@cached_property
Expand All @@ -209,19 +229,31 @@ def final_solver(self):
# setup solver using lhs and rhs defined in derived class
problem = NonlinearVariationalProblem(self.final_res, self.x_out, bcs=self.bcs)
solver_name = self.field_name+self.__class__.__name__
return NonlinearVariationalSolver(problem, solver_parameters=self.solver_parameters, options_prefix=solver_name)
return NonlinearVariationalSolver(problem, solver_parameters=self.linear_solver_parameters, options_prefix=solver_name)

@wrapper_apply
def apply(self, x_out, x_in):
self.x1.assign(x_in)
self.x_out.assign(x_in)
solver_list = self.solvers

for stage in range(self.nStages):
self.solver = solver_list[stage]
# Set initial solver guess
if (stage > 0):
self.x_out.assign(self.xs[stage-1])
self.solver.solve()

# Apply limiter
if self.limiter is not None:
self.limiter.apply(self.x_out)
self.xs[stage].assign(self.x_out)

self.final_solver.solve()

# Apply limiter
if self.limiter is not None:
self.limiter.apply(self.x_out)
x_out.assign(self.x_out)


Expand All @@ -236,16 +268,19 @@ class IMEX_Euler(IMEXRungeKutta):
y_1 = y^n + dt*F[y_1] + dt*S[y_0] \n
y^(n+1) = y^n + dt*F[y_1] + dt*S[y_0]
"""
def __init__(self, domain, field_name=None, solver_parameters=None,
def __init__(self, domain, field_name=None,
linear_solver_parameters=None, nonlinear_solver_parameters=None,
limiter=None, options=None):
"""
Args:
domain (:class:`Domain`): the model's domain object, containing the
mesh and the compatible function spaces.
field_name (str, optional): name of the field to be evolved.
Defaults to None.
solver_parameters (dict, optional): dictionary of parameters to
pass to the underlying solver. Defaults to None.
linear_solver_parameters (dict, optional): dictionary of parameters to
pass to the underlying linear solver. Defaults to None.
nonlinear_solver_parameters (dict, optional): dictionary of parameters to
pass to the underlying nonlinear solver. Defaults to None.
limiter (:class:`Limiter` object, optional): a limiter to apply to
the evolving field to enforce monotonicity. Defaults to None.
options (:class:`AdvectionOptions`, optional): an object containing
Expand All @@ -256,7 +291,8 @@ def __init__(self, domain, field_name=None, solver_parameters=None,
butcher_imp = np.array([[0., 0.], [0., 1.], [0., 1.]])
butcher_exp = np.array([[0., 0.], [1., 0.], [1., 0.]])
super().__init__(domain, butcher_imp, butcher_exp, field_name,
solver_parameters=solver_parameters,
linear_solver_parameters=linear_solver_parameters,
nonlinear_solver_parameters=nonlinear_solver_parameters,
limiter=limiter, options=options)


Expand All @@ -276,16 +312,19 @@ class IMEX_ARS3(IMEXRungeKutta):
y^(n+1) = y^n + dt*(g*F[y_1]+(1-g)*F[y_2]) \n
+ dt*(0.5*S[y_1]+0.5*S[y_2])
"""
def __init__(self, domain, field_name=None, solver_parameters=None,
def __init__(self, domain, field_name=None,
linear_solver_parameters=None, nonlinear_solver_parameters=None,
limiter=None, options=None):
"""
Args:
domain (:class:`Domain`): the model's domain object, containing the
mesh and the compatible function spaces.
field_name (str, optional): name of the field to be evolved.
Defaults to None.
solver_parameters (dict, optional): dictionary of parameters to
pass to the underlying solver. Defaults to None.
linear_solver_parameters (dict, optional): dictionary of parameters to
pass to the underlying linear solver. Defaults to None.
nonlinear_solver_parameters (dict, optional): dictionary of parameters to
pass to the underlying nonlinear solver. Defaults to None.
limiter (:class:`Limiter` object, optional): a limiter to apply to
the evolving field to enforce monotonicity. Defaults to None.
options (:class:`AdvectionOptions`, optional): an object containing
Expand All @@ -298,7 +337,8 @@ def __init__(self, domain, field_name=None, solver_parameters=None,
butcher_exp = np.array([[0., 0., 0.], [g, 0., 0.], [g-1., 2.*(1.-g), 0.], [0., 0.5, 0.5]])

super().__init__(domain, butcher_imp, butcher_exp, field_name,
solver_parameters=solver_parameters,
linear_solver_parameters=linear_solver_parameters,
nonlinear_solver_parameters=nonlinear_solver_parameters,
limiter=limiter, options=options)


Expand All @@ -318,15 +358,19 @@ class IMEX_ARK2(IMEXRungeKutta):
y^(n+1) = y^n + dt*(d*F[y_0]+d*F[y_1]+g*F[y_2]) \n
+ dt*(d*S[y_0]+d*S[y_1]+g*S[y_2])
"""
def __init__(self, domain, field_name=None, solver_parameters=None, limiter=None, options=None):
def __init__(self, domain, field_name=None,
linear_solver_parameters=None, nonlinear_solver_parameters=None,
limiter=None, options=None):
"""
Args:
domain (:class:`Domain`): the model's domain object, containing the
mesh and the compatible function spaces.
field_name (str, optional): name of the field to be evolved.
Defaults to None.
solver_parameters (dict, optional): dictionary of parameters to
pass to the underlying solver. Defaults to None.
linear_solver_parameters (dict, optional): dictionary of parameters to
pass to the underlying linear solver. Defaults to None.
nonlinear_solver_parameters (dict, optional): dictionary of parameters to
pass to the underlying nonlinear solver. Defaults to None.
limiter (:class:`Limiter` object, optional): a limiter to apply to
the evolving field to enforce monotonicity. Defaults to None.
options (:class:`AdvectionOptions`, optional): an object containing
Expand All @@ -340,7 +384,8 @@ def __init__(self, domain, field_name=None, solver_parameters=None, limiter=None
butcher_imp = np.array([[0., 0., 0.], [g, g, 0.], [d, d, g], [d, d, g]])
butcher_exp = np.array([[0., 0., 0.], [2.*g, 0., 0.], [1.-a, a, 0.], [d, d, g]])
super().__init__(domain, butcher_imp, butcher_exp, field_name,
solver_parameters=solver_parameters,
linear_solver_parameters=linear_solver_parameters,
nonlinear_solver_parameters=nonlinear_solver_parameters,
limiter=limiter, options=options)


Expand All @@ -358,15 +403,19 @@ class IMEX_SSP3(IMEXRungeKutta):
y^(n+1) = y^n + dt*(1/6*F[y_1]+1/6*F[y_2]+2/3*F[y_3]) \n
+ dt*(1/6*S[y_1]+1/6*S[y_2]+2/3*S[y_3])
"""
def __init__(self, domain, field_name=None, solver_parameters=None, limiter=None, options=None):
def __init__(self, domain, field_name=None,
linear_solver_parameters=None, nonlinear_solver_parameters=None,
limiter=None, options=None):
"""
Args:
domain (:class:`Domain`): the model's domain object, containing the
mesh and the compatible function spaces.
field_name (str, optional): name of the field to be evolved.
Defaults to None.
solver_parameters (dict, optional): dictionary of parameters to
pass to the underlying solver. Defaults to None.
linear_solver_parameters (dict, optional): dictionary of parameters to
pass to the underlying linear solver. Defaults to None.
nonlinear_solver_parameters (dict, optional): dictionary of parameters to
pass to the underlying nonlinear solver. Defaults to None.
limiter (:class:`Limiter` object, optional): a limiter to apply to
the evolving field to enforce monotonicity. Defaults to None.
options (:class:`AdvectionOptions`, optional): an object containing
Expand All @@ -378,7 +427,8 @@ def __init__(self, domain, field_name=None, solver_parameters=None, limiter=None
butcher_imp = np.array([[g, 0., 0.], [1-2.*g, g, 0.], [0.5-g, 0., g], [(1./6.), (1./6.), (2./3.)]])
butcher_exp = np.array([[0., 0., 0.], [1., 0., 0.], [0.25, 0.25, 0.], [(1./6.), (1./6.), (2./3.)]])
super().__init__(domain, butcher_imp, butcher_exp, field_name,
solver_parameters=solver_parameters,
linear_solver_parameters=linear_solver_parameters,
nonlinear_solver_parameters=nonlinear_solver_parameters,
limiter=limiter, options=options)


Expand All @@ -396,15 +446,19 @@ class IMEX_Trap2(IMEXRungeKutta):
y_3 = y^n + dt*(0.5*F[y_0]+0.5*F[y_3]) + dt*(0.5*S[y_0]+0.5*S[y_2]) \n
y^(n+1) = y^n + dt*(0.5*F[y_0]+0.5*F[y_3]) + dt*(0.5*S[y_0] + 0.5*S[y_2]) \n
"""
def __init__(self, domain, field_name=None, solver_parameters=None, limiter=None, options=None):
def __init__(self, domain, field_name=None,
linear_solver_parameters=None, nonlinear_solver_parameters=None,
limiter=None, options=None):
"""
Args:
domain (:class:`Domain`): the model's domain object, containing the
mesh and the compatible function spaces.
field_name (str, optional): name of the field to be evolved.
Defaults to None.
solver_parameters (dict, optional): dictionary of parameters to
pass to the underlying solver. Defaults to None.
linear_solver_parameters (dict, optional): dictionary of parameters to
pass to the underlying linear solver. Defaults to None.
nonlinear_solver_parameters (dict, optional): dictionary of parameters to
pass to the underlying nonlinear solver. Defaults to None.
limiter (:class:`Limiter` object, optional): a limiter to apply to
the evolving field to enforce monotonicity. Defaults to None.
options (:class:`AdvectionOptions`, optional): an object containing
Expand All @@ -416,5 +470,6 @@ def __init__(self, domain, field_name=None, solver_parameters=None, limiter=None
butcher_imp = np.array([[0., 0., 0., 0.], [e, 0., 0., 0.], [0.5, 0., 0.5, 0.], [0.5, 0., 0., 0.5], [0.5, 0., 0., 0.5]])
butcher_exp = np.array([[0., 0., 0., 0.], [1., 0., 0., 0.], [0.5, 0.5, 0., 0.], [0.5, 0., 0.5, 0.], [0.5, 0., 0.5, 0.]])
super().__init__(domain, butcher_imp, butcher_exp, field_name,
solver_parameters=solver_parameters,
linear_solver_parameters=linear_solver_parameters,
nonlinear_solver_parameters=nonlinear_solver_parameters,
limiter=limiter, options=options)
Loading

0 comments on commit 2a929f8

Please sign in to comment.