Skip to content

Commit

Permalink
Standardise BaseSolver.solve() multiprocessing context to fork (pyb…
Browse files Browse the repository at this point in the history
…amm-team#3975)

* Enforce multiprocessing context fork, add example & updt integration test

* updt changelog

* Add OS conditional context

---------

Co-authored-by: Eric G. Kratz <[email protected]>
  • Loading branch information
BradyPlanden and kratman authored Apr 10, 2024
1 parent f050d58 commit b068ba6
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 4 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

## Features

- Updates multiprocess `Pool` in `BaseSolver.solve()` to be constructed with context `fork`. Adds small example for multiprocess inputs. ([#3974](https://github.com/pybamm-team/PyBaMM/pull/3974))
- Added custom experiment steps ([#3835](https://github.com/pybamm-team/PyBaMM/pull/3835))
- Added support for macOS arm64 (M-series) platforms. ([#3789](https://github.com/pybamm-team/PyBaMM/pull/3789))
- Added the ability to specify a custom solver tolerance in `get_initial_stoichiometries` and related functions ([#3714](https://github.com/pybamm-team/PyBaMM/pull/3714))
Expand Down
18 changes: 18 additions & 0 deletions examples/scripts/multiprocess_inputs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import pybamm
import numpy as np

# create the model
model = pybamm.lithium_ion.DFN()

# set the default model parameters
param = model.default_parameter_values

# change the current function to be an input parameter
param["Current function [A]"] = "[input]"

simulation = pybamm.Simulation(model, parameter_values=param)

# solve the model at the given time points, passing multiple current values as inputs
t_eval = np.linspace(0, 600, 300)
inputs = [{"Current function [A]": x} for x in range(1, 3)]
sol = simulation.solve(t_eval, inputs=inputs)
5 changes: 5 additions & 0 deletions pybamm/solvers/base_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,18 @@
import numbers
import sys
import warnings
import platform

import casadi
import numpy as np

import pybamm
from pybamm.expression_tree.binary_operators import _Heaviside

# Set context for parallel processing depending on the platform
if platform.system() == "Darwin" or platform.system() == "Linux":
mp.set_start_method("fork")


class BaseSolver:
"""Solve a discretised model.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,8 @@ def compare_outputs_two_phase_silicon_graphite(self, model_class):

sim = pybamm.Simulation(model, parameter_values=param)
t_eval = np.linspace(0, 9000, 1000)
sol1 = sim.solve(t_eval, inputs={"x": 0.01})
sol2 = sim.solve(t_eval, inputs={"x": 0.1})
inputs = [{"x": 0.01}, {"x": 0.1}]
sol = sim.solve(t_eval, inputs=inputs)

# Starting values should be close
for var in [
Expand All @@ -155,11 +155,11 @@ def compare_outputs_two_phase_silicon_graphite(self, model_class):
"Average negative secondary particle concentration",
]:
np.testing.assert_allclose(
sol1[var].data[:20], sol2[var].data[:20], rtol=1e-2
sol[0][var].data[:20], sol[1][var].data[:20], rtol=1e-2
)

# More silicon means longer sim
self.assertLess(sol1["Time [s]"].data[-1], sol2["Time [s]"].data[-1])
self.assertLess(sol[0]["Time [s]"].data[-1], sol[1]["Time [s]"].data[-1])

def test_compare_SPM_silicon_graphite(self):
model_class = pybamm.lithium_ion.SPM
Expand Down

0 comments on commit b068ba6

Please sign in to comment.