Skip to content

Commit

Permalink
Add parameter list support to JAX solver (permitting multithreading/G…
Browse files Browse the repository at this point in the history
…PU execution) (pybamm-team#3121)

* Allow parameter lists to Jax solver

* Use Jax (experimental) sparse matrices

* Add parameter list solver to JAX with multithreaded and gpu support

* style: pre-commit fixes

* Add changelog entry

* Installer now specifies Jax minor version only

* Revert inclusion of experimental jax sparse matrices

* Permit multiple inputs in Jax unit tests

* Only run Jax tests when Jax installed

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: John Brittain <[email protected]>
  • Loading branch information
3 people authored Sep 19, 2023
1 parent 8da245e commit a68c038
Show file tree
Hide file tree
Showing 7 changed files with 141 additions and 64 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
## Optimizations

- Improved how steps are processed in simulations to reduce memory usage ([#3261](https://github.com/pybamm-team/PyBaMM/pull/3261))
- Added parameter list support to JAX solver, permitting multithreading / GPU execution ([#3121](https://github.com/pybamm-team/PyBaMM/pull/3121))

## Breaking changes

Expand Down
1 change: 0 additions & 1 deletion pybamm/expression_tree/operations/evaluate_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,6 @@ def find_symbols(symbol, constant_symbols, variable_symbols, output_jax=False):
if output_jax and scipy.sparse.issparse(value):
# convert any remaining sparse matrices to our custom coo matrix
constant_symbols[symbol.id] = create_jax_coo_matrix(value)

else:
constant_symbols[symbol.id] = value
return
Expand Down
35 changes: 18 additions & 17 deletions pybamm/solvers/base_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -748,13 +748,6 @@ def solve(
self._set_up_model_inputs(model, inputs) for inputs in inputs_list
]

# Cannot use multiprocessing with model in "jax" format
if (len(inputs_list) > 1) and model.convert_to_format == "jax":
raise pybamm.SolverError(
"Cannot solve list of inputs with multiprocessing "
'when model in format "jax".'
)

# Check that calculate_sensitivities have not been updated
calculate_sensitivities_list.sort()
if not hasattr(model, "calculate_sensitivities"):
Expand Down Expand Up @@ -864,17 +857,25 @@ def solve(
)
new_solutions = [new_solution]
else:
with mp.Pool(processes=nproc) as p:
new_solutions = p.starmap(
self._integrate,
zip(
[model] * ninputs,
[t_eval[start_index:end_index]] * ninputs,
model_inputs_list,
),
if model.convert_to_format == "jax":
# Jax can parallelize over the inputs efficiently
new_solutions = self._integrate(
model,
t_eval[start_index:end_index],
model_inputs_list,
)
p.close()
p.join()
else:
with mp.Pool(processes=nproc) as p:
new_solutions = p.starmap(
self._integrate,
zip(
[model] * ninputs,
[t_eval[start_index:end_index]] * ninputs,
model_inputs_list,
),
)
p.close()
p.join()
# Setting the solve time for each segment.
# pybamm.Solution.__add__ assumes attribute solve_time.
solve_time = timer.time()
Expand Down
98 changes: 88 additions & 10 deletions pybamm/solvers/jax_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Solver class using Scipy's adaptive time stepper
#
import numpy as onp
import asyncio

import pybamm

Expand Down Expand Up @@ -164,7 +165,7 @@ def solve_model_rk45(inputs):
inputs,
rtol=self.rtol,
atol=self.atol,
**self.extra_options
**self.extra_options,
)
return jnp.transpose(y)

Expand All @@ -177,7 +178,7 @@ def solve_model_bdf(inputs):
rtol=self.rtol,
atol=self.atol,
mass=mass,
**self.extra_options
**self.extra_options,
)
return jnp.transpose(y)

Expand All @@ -186,7 +187,7 @@ def solve_model_bdf(inputs):
else:
return jax.jit(solve_model_bdf)

def _integrate(self, model, t_eval, inputs_dict=None):
def _integrate(self, model, t_eval, inputs=None):
"""
Solve a model defined by dydt with initial conditions y0.
Expand All @@ -196,7 +197,7 @@ def _integrate(self, model, t_eval, inputs_dict=None):
The model whose solution to calculate.
t_eval : :class:`numpy.array`, size (k,)
The times at which to compute the solution
inputs_dict : dict, optional
inputs : dict, list[dict], optional
Any input parameters to pass to the model when solving
Returns
Expand All @@ -206,11 +207,74 @@ def _integrate(self, model, t_eval, inputs_dict=None):
various diagnostic messages.
"""
if isinstance(inputs, dict):
inputs = [inputs]
timer = pybamm.Timer()
if model not in self._cached_solves:
self._cached_solves[model] = self.create_solve(model, t_eval)

y = self._cached_solves[model](inputs_dict).block_until_ready()
y = []
platform = jax.lib.xla_bridge.get_backend().platform.casefold()
if platform.startswith("cpu"):
# cpu execution runs faster when multithreaded
async def solve_model_for_inputs():
async def solve_model_async(inputs_v):
return self._cached_solves[model](inputs_v)

coro = []
for inputs_v in inputs:
coro.append(asyncio.create_task(solve_model_async(inputs_v)))
return await asyncio.gather(*coro)

y = asyncio.run(solve_model_for_inputs())
elif platform.startswith("gpu") or platform.startswith("tpu"):
# gpu execution runs faster when parallelised with vmap
# (see also comment below regarding single-program multiple-data
# execution (SPMD) using pmap on multiple XLA devices)

# convert inputs (array of dict) to a dict of arrays for vmap
inputs_v = {
key: jnp.array([dic[key] for dic in inputs]) for key in inputs[0]
}
y.extend(jax.vmap(self._cached_solves[model])(inputs_v))
else:
# Unknown platform, use serial execution as fallback
print(
f'Unknown platform requested: "{platform}", '
"falling back to serial execution"
)
for inputs_v in inputs:
y.append(self._cached_solves[model](inputs_v))

# This code block implements single-program multiple-data execution
# using pmap across multiple XLA devices. It is currently commented out
# because it produces bus errors for even moderate-sized models.
# It is suspected that this is due to either a bug in JAX, insufficient
# sparse matrix support in JAX resulting in high memory usage, or a bug
# in the BDF solver.
#
# This issue on GitHub appears related:
# https://github.com/google/jax/discussions/13930
#
# # Split input list based on the number of available xla devices
# device_count = jax.local_device_count()
# inputs_listoflists = [inputs[x:x + device_count]
# for x in range(0, len(inputs), device_count)]
# if len(inputs_listoflists) > 1:
# print(f"{len(inputs)} parameter sets were provided, "
# f"but only {device_count} XLA devices are available")
# print(f"Parameter sets split into {len(inputs_listoflists)} "
# "lists for parallel processing")
# y = []
# for k, inputs_list in enumerate(inputs_listoflists):
# if len(inputs_listoflists) > 1:
# print(f" Solving list {k+1} of {len(inputs_listoflists)} "
# f"({len(inputs_list)} parameter sets)")
# # convert inputs to a dict of arrays for pmap
# inputs_v = {key: jnp.array([dic[key] for dic in inputs_list])
# for key in inputs_list[0]}
# y.extend(jax.pmap(self._cached_solves[model])(inputs_v))

integration_time = timer.time()

# convert to a normal numpy array
Expand All @@ -219,8 +283,22 @@ def _integrate(self, model, t_eval, inputs_dict=None):
termination = "final time"
t_event = None
y_event = onp.array(None)
sol = pybamm.Solution(
t_eval, y, model, inputs_dict, t_event, y_event, termination
)
sol.integration_time = integration_time
return sol

# Extract solutions from y with their associated input dicts
solutions = []
for k, inputs_dict in enumerate(inputs):
sol = pybamm.Solution(
t_eval,
jnp.reshape(y[k,], y.shape[1:]),
model,
inputs_dict,
t_event,
y_event,
termination,
)
sol.integration_time = integration_time
solutions.append(sol)

if len(solutions) == 1:
return solutions[0]
return solutions
15 changes: 7 additions & 8 deletions pybamm/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@
import pybamm

# versions of jax and jaxlib compatible with PyBaMM
JAX_VERSION = "0.4.8"
JAXLIB_VERSION = "0.4.7"
JAX_VERSION = "0.4"
JAXLIB_VERSION = "0.4"


def root_dir():
Expand Down Expand Up @@ -271,10 +271,9 @@ def have_jax():

def is_jax_compatible():
"""Check if the available version of jax and jaxlib are compatible with PyBaMM"""
return (
pkg_resources.get_distribution("jax").version == JAX_VERSION
and pkg_resources.get_distribution("jaxlib").version == JAXLIB_VERSION
)
return pkg_resources.get_distribution("jax").version.startswith(
JAX_VERSION
) and pkg_resources.get_distribution("jaxlib").version.startswith(JAXLIB_VERSION)


def is_constant_and_can_evaluate(symbol):
Expand Down Expand Up @@ -341,7 +340,7 @@ def install_jax(arguments=None): # pragma: no cover
"-m",
"pip",
"install",
f"jax=={JAX_VERSION}",
f"jaxlib=={JAXLIB_VERSION}",
f"jax>={JAX_VERSION}",
f"jaxlib>={JAXLIB_VERSION}",
]
)
2 changes: 1 addition & 1 deletion tests/testcase.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def FixRandomSeed(method):
explicitly reinstate the random seed within their method bodies as desired,
e.g. by calling np.random.seed(None) to restore normal behaviour.
Generatig a random seed from the method name allows particularly awkward
Generating a random seed from the method name allows particularly awkward
sequences to be altered by changing the method name, such as by adding a
trailing underscore, or other hash modifier, if required.
"""
Expand Down
53 changes: 26 additions & 27 deletions tests/unit/test_solvers/test_scipy_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,35 +338,34 @@ def test_model_solver_multiple_inputs_initial_conditions_error(self):
):
solver.solve(model, t_eval, inputs=inputs_list, nproc=2)

def test_model_solver_multiple_inputs_jax_format_error(self):
# Create model
model = pybamm.BaseModel()
model.convert_to_format = "jax"
domain = ["negative electrode", "separator", "positive electrode"]
var = pybamm.Variable("var", domain=domain)
model.rhs = {var: -pybamm.InputParameter("rate") * var}
model.initial_conditions = {var: 2 * pybamm.InputParameter("rate")}
# No need to set parameters; can use base discretisation (no spatial
# operators)
# create discretisation
mesh = get_mesh_for_testing()
spatial_methods = {"macroscale": pybamm.FiniteVolume()}
disc = pybamm.Discretisation(mesh, spatial_methods)
disc.process_model(model)
def test_model_solver_multiple_inputs_jax_format(self):
if pybamm.have_jax():
# Create model
model = pybamm.BaseModel()
model.convert_to_format = "jax"
domain = ["negative electrode", "separator", "positive electrode"]
var = pybamm.Variable("var", domain=domain)
model.rhs = {var: -pybamm.InputParameter("rate") * var}
model.initial_conditions = {var: 1}
# create discretisation
mesh = get_mesh_for_testing()
spatial_methods = {"macroscale": pybamm.FiniteVolume()}
disc = pybamm.Discretisation(mesh, spatial_methods)
disc.process_model(model)

solver = pybamm.ScipySolver(rtol=1e-8, atol=1e-8, method="RK45")
t_eval = np.linspace(0, 10, 100)
ninputs = 8
inputs_list = [{"rate": 0.01 * (i + 1)} for i in range(ninputs)]
solver = pybamm.JaxSolver(rtol=1e-8, atol=1e-8, method="RK45")
t_eval = np.linspace(0, 10, 100)
ninputs = 8
inputs_list = [{"rate": 0.01 * (i + 1)} for i in range(ninputs)]

with self.assertRaisesRegex(
pybamm.SolverError,
(
"Cannot solve list of inputs with multiprocessing "
'when model in format "jax".'
),
):
solver.solve(model, t_eval, inputs=inputs_list, nproc=2)
solutions = solver.solve(model, t_eval, inputs=inputs_list, nproc=2)
for i in range(ninputs):
with self.subTest(i=i):
solution = solutions[i]
np.testing.assert_array_equal(solution.t, t_eval)
np.testing.assert_allclose(
solution.y[0], np.exp(-0.01 * (i + 1) * solution.t)
)

def test_model_solver_with_event_with_casadi(self):
# Create model
Expand Down

0 comments on commit a68c038

Please sign in to comment.