Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensemble parallelism #505

Draft
wants to merge 10 commits into
base: main
Choose a base branch
from
2 changes: 1 addition & 1 deletion gusto/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def __setattr__(self, name, value):

When attributes are provided as floats or integers, these are converted
to Firedrake :class:`Constant` objects, other than a handful of special
integers (dumpfreq, pddumpfreq, chkptfreq and log_level).
integers (dumpfreq, pddumpfreq and chkptfreq).

Args:
name: the attribute's name.
Expand Down
3 changes: 2 additions & 1 deletion gusto/equations.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,8 @@ def __init__(self, domain, function_space, field_name, Vu=None,
equation's prognostic is defined on.
field_name (str): name of the prognostic field.
Vu (:class:`FunctionSpace`, optional): the function space for the
velocity field. If this is Defaults to None.
velocity field. Defaults to None in which case use the
HDiv space.
diffusion_parameters (:class:`DiffusionParameters`, optional):
parameters describing the diffusion to be applied.
"""
Expand Down
57 changes: 42 additions & 15 deletions gusto/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,8 @@ def dump(self, field_creator, t):

class DiagnosticsOutput(object):
"""Object for outputting global diagnostic data."""
def __init__(self, filename, diagnostics, description, comm, create=True):
def __init__(self, filename, diagnostics, description, comm,
ensemble_comm=None, create=True):
"""
Args:
filename (str): name of file to output to.
Expand All @@ -155,9 +156,14 @@ def __init__(self, filename, diagnostics, description, comm, create=True):
self.filename = filename
self.diagnostics = diagnostics
self.comm = comm
self.ensemble_comm = ensemble_comm
if ensemble_comm is not None:
self.write_to_file = ensemble_comm.rank == 0 and comm.rank == 0
else:
self.write_to_file = comm.rank == 0
if not create:
return
if self.comm.rank == 0:
if self.write_to_file:
with Dataset(filename, "w") as dataset:
dataset.description = "Diagnostics data for simulation {desc}".format(desc=description)
dataset.history = "Created {t}".format(t=time.ctime())
Expand Down Expand Up @@ -185,7 +191,7 @@ def dump(self, state_fields, t):
diagnostic = getattr(self.diagnostics, dname)
diagnostics.append((fname, dname, diagnostic(field)))

if self.comm.rank == 0:
if self.write_to_file:
with Dataset(self.filename, "a") as dataset:
idx = dataset.dimensions["time"].size
dataset.variables["time"][idx:idx + 1] = t
Expand Down Expand Up @@ -354,7 +360,7 @@ def setup_diagnostics(self, state_fields):
if fname in state_fields.to_dump:
self.diagnostics.register(fname)

def setup_dump(self, state_fields, t, pick_up=False):
def setup_dump(self, state_fields, t, ensemble, pick_up=False):
"""
Sets up a series of things used for outputting.

Expand All @@ -377,6 +383,18 @@ def setup_dump(self, state_fields, t, pick_up=False):
raise_parallel_exception = 0
error = None

if ensemble is not None:
ens_comm = ensemble.ensemble_comm
comm = ensemble.comm
create_dir = ens_comm.rank + comm.rank == 0
create_files = ens_comm.rank == 0
else:
ens_comm = None
comm = self.mesh.comm
create_dir = comm.Get_rank() == 0
create_files = True
self.ensemble = ensemble

if any([self.output.dump_vtus, self.output.dump_nc,
self.output.dumplist_latlon, self.output.dump_diagnostics,
self.output.point_data, self.output.checkpoint and not pick_up]):
Expand All @@ -385,7 +403,7 @@ def setup_dump(self, state_fields, t, pick_up=False):
running_tests = '--running-tests' in sys.argv or "pytest" in self.output.dirname

# Raising exceptions needs to be done in parallel
if self.mesh.comm.Get_rank() == 0:
if create_dir:
# Create results directory if it doesn't already exist
if not path.exists(self.dumpdir):
try:
Expand All @@ -400,7 +418,10 @@ def setup_dump(self, state_fields, t, pick_up=False):

# Gather errors from each rank and raise appropriate error everywhere
# This allreduce also ensures that all ranks are in sync wrt the results dir
raise_exception = self.mesh.comm.allreduce(raise_parallel_exception, op=MPI.MAX)
raise_exception = comm.allreduce(raise_parallel_exception, op=MPI.MAX)
if ensemble is not None:
raise_exception = ens_comm.allreduce(raise_exception, op=MPI.MAX)

if raise_exception == 1:
raise GustoIOError(f'results directory {self.dumpdir} already exists')
elif raise_exception == 2:
Expand All @@ -421,12 +442,12 @@ def setup_dump(self, state_fields, t, pick_up=False):
if pick_up:
next(self.dumpcount)

if self.output.dump_vtus:
if self.output.dump_vtus and create_files:
# setup pvd output file
outfile_pvd = path.join(self.dumpdir, "field_output.pvd")
self.pvd_dumpfile = VTKFile(
outfile_pvd, project_output=self.output.project_fields,
comm=self.mesh.comm)
comm=comm)

if self.output.dump_nc:
self.nc_filename = path.join(self.dumpdir, "field_output.nc")
Expand All @@ -453,10 +474,11 @@ def setup_dump(self, state_fields, t, pick_up=False):
# setup the latlon coordinate mesh and make output file
if len(self.output.dumplist_latlon) > 0:
mesh_ll = get_flat_latlon_mesh(self.mesh)
outfile_ll = path.join(self.dumpdir, "field_output_latlon.pvd")
self.dumpfile_ll = VTKFile(outfile_ll,
project_output=self.output.project_fields,
comm=self.mesh.comm)
if create_files:
outfile_ll = path.join(self.dumpdir, "field_output_latlon.pvd")
self.dumpfile_ll = VTKFile(outfile_ll,
project_output=self.output.project_fields,
comm=comm)

# make functions on latlon mesh, as specified by dumplist_latlon
self.to_dump_latlon = []
Expand All @@ -472,11 +494,12 @@ def setup_dump(self, state_fields, t, pick_up=False):
# already exist, in which case we just need the filenames
if self.output.dump_diagnostics:
diagnostics_filename = self.dumpdir+"/diagnostics.nc"
to_create = not (path.isfile(diagnostics_filename) and pick_up)
to_create = not (path.isfile(diagnostics_filename) and pick_up) and create_files
self.diagnostic_output = DiagnosticsOutput(diagnostics_filename,
self.diagnostics,
self.output.dirname,
self.mesh.comm,
comm=comm,
ensemble_comm=ens_comm,
create=to_create)

# if picking-up, don't do initial dump
Expand Down Expand Up @@ -665,6 +688,10 @@ def dump(self, state_fields, t, step, initial_steps=None):
completed by a multi-level time scheme. Defaults to None.
"""
output = self.output
if self.ensemble is not None:
write_file = self.ensemble.ensemble_comm.rank == 0
else:
write_file = True

# Diagnostics:
# Compute diagnostic fields
Expand Down Expand Up @@ -703,7 +730,7 @@ def dump(self, state_fields, t, step, initial_steps=None):
# dump fields
self.write_nc_dump(t)

if output.dump_vtus:
if output.dump_vtus and write_file:
# dump fields
self.pvd_dumpfile.write(*self.to_dump)

Expand Down
10 changes: 6 additions & 4 deletions gusto/timeloop.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
class BaseTimestepper(object, metaclass=ABCMeta):
"""Base class for timesteppers."""

def __init__(self, equation, io):
def __init__(self, equation, io, ensemble=None):
"""
Args:
equation (:class:`PrognosticEquation`): the prognostic equation.
Expand All @@ -33,6 +33,7 @@ def __init__(self, equation, io):

self.equation = equation
self.io = io
self.ensemble = ensemble
self.dt = self.equation.domain.dt
self.t = self.equation.domain.t
self.reference_profiles_initialised = False
Expand Down Expand Up @@ -177,7 +178,7 @@ def run(self, t, tmax, pick_up=False):

# Set up dump, which may also include an initial dump
with timed_stage("Dump output"):
self.io.setup_dump(self.fields, t, pick_up)
self.io.setup_dump(self.fields, t, self.ensemble, pick_up)

self.t.assign(t)

Expand Down Expand Up @@ -249,7 +250,7 @@ class Timestepper(BaseTimestepper):
"""

def __init__(self, equation, scheme, io, spatial_methods=None,
physics_parametrisations=None):
physics_parametrisations=None, ensemble=None):
"""
Args:
equation (:class:`PrognosticEquation`): the prognostic equation
Expand Down Expand Up @@ -284,7 +285,7 @@ def __init__(self, equation, scheme, io, spatial_methods=None,
else:
self.physics_parametrisations = []

super().__init__(equation=equation, io=io)
super().__init__(equation=equation, io=io, ensemble=ensemble)

@property
def transporting_velocity(self):
Expand Down Expand Up @@ -716,6 +717,7 @@ def timestep(self):

with timed_stage("Apply forcing terms"):
logger.info('SIQN: Explicit forcing')

# Put explicit forcing into xstar
self.forcing.apply(x_after_slow, xn, xstar(self.field_name), "explicit")

Expand Down
32 changes: 32 additions & 0 deletions integration-tests/model/test_parallel_io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from firedrake import Ensemble, COMM_WORLD, PeriodicUnitSquareMesh
from gusto import *
import pytest


@pytest.mark.parallel(nprocs=4)
@pytest.mark.parametrize("spatial_parallelism", [True, False])
def test_parallel_io(tmpdir, spatial_parallelism):

if spatial_parallelism:
ensemble = Ensemble(COMM_WORLD, 2)
else:
ensemble = Ensemble(COMM_WORLD, 1)

mesh = PeriodicUnitSquareMesh(10, 10, comm=ensemble.comm)
dt = 0.1
domain = Domain(mesh, dt, "BDM", 1)

# Equation
parameters = ShallowWaterParameters(H=100)
equation = ShallowWaterEquations(domain, parameters)

# I/O
output = OutputParameters(dirname=str(tmpdir))
io = IO(domain, output)

# Time stepper
spatial_methods = [DGUpwind(equation, "u"), DGUpwind(equation, "D")]
stepper = Timestepper(equation, SSPRK3(domain), io, spatial_methods,
ensemble=ensemble)

stepper.run(0, 3*dt)
Loading