Skip to content

Commit 48dbb68

Browse files
feat: add OpenMP parallelization to IDAKLU solver for lists of input parameters (#4449)
* new solver option `num_solvers`, indicates how many solves run in parallel * existing `num_threads` gives total number of threads which are distributed among `num_solvers`
1 parent e1118ec commit 48dbb68

20 files changed

+677
-256
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
## Features
44
- Added sensitivity calculation support for `pybamm.Simulation` and `pybamm.Experiment` ([#4415](https://github.com/pybamm-team/PyBaMM/pull/4415))
5+
- Added OpenMP parallelization to IDAKLU solver for lists of input parameters ([#4449](https://github.com/pybamm-team/PyBaMM/pull/4449))
56

67
## Optimizations
78
- Removed the `start_step_offset` setting and disabled minimum `dt` warnings for drive cycles with the (`IDAKLUSolver`). ([#4416](https://github.com/pybamm-team/PyBaMM/pull/4416))

CMakeLists.txt

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ endif()
1919

2020
project(idaklu)
2121

22-
set(CMAKE_CXX_STANDARD 14)
22+
set(CMAKE_CXX_STANDARD 17)
2323
set(CMAKE_CXX_STANDARD_REQUIRED ON)
2424
set(CMAKE_CXX_EXTENSIONS OFF)
2525
set(CMAKE_EXPORT_COMPILE_COMMANDS 1)
@@ -82,6 +82,8 @@ pybind11_add_module(idaklu
8282
src/pybamm/solvers/c_solvers/idaklu/idaklu_solver.hpp
8383
src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolver.cpp
8484
src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolver.hpp
85+
src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverGroup.cpp
86+
src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverGroup.hpp
8587
src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP.inl
8688
src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP.hpp
8789
src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP_solvers.cpp
@@ -94,6 +96,8 @@ pybind11_add_module(idaklu
9496
src/pybamm/solvers/c_solvers/idaklu/common.cpp
9597
src/pybamm/solvers/c_solvers/idaklu/Solution.cpp
9698
src/pybamm/solvers/c_solvers/idaklu/Solution.hpp
99+
src/pybamm/solvers/c_solvers/idaklu/SolutionData.cpp
100+
src/pybamm/solvers/c_solvers/idaklu/SolutionData.hpp
97101
src/pybamm/solvers/c_solvers/idaklu/Options.hpp
98102
src/pybamm/solvers/c_solvers/idaklu/Options.cpp
99103
# IDAKLU expressions / function evaluation [abstract]
@@ -138,6 +142,23 @@ set_target_properties(
138142
INSTALL_RPATH_USE_LINK_PATH TRUE
139143
)
140144

145+
# openmp
146+
if (${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
147+
execute_process(
148+
COMMAND "brew" "--prefix"
149+
OUTPUT_VARIABLE HOMEBREW_PREFIX
150+
OUTPUT_STRIP_TRAILING_WHITESPACE)
151+
if (OpenMP_ROOT)
152+
set(OpenMP_ROOT "${OpenMP_ROOT}:${HOMEBREW_PREFIX}/opt/libomp")
153+
else()
154+
set(OpenMP_ROOT "${HOMEBREW_PREFIX}/opt/libomp")
155+
endif()
156+
endif()
157+
find_package(OpenMP)
158+
if(OpenMP_CXX_FOUND)
159+
target_link_libraries(idaklu PRIVATE OpenMP::OpenMP_CXX)
160+
endif()
161+
141162
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} ${PROJECT_SOURCE_DIR})
142163
# Sundials
143164
find_package(SUNDIALS REQUIRED)

src/pybamm/solvers/base_solver.py

Lines changed: 35 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,10 @@ def supports_interp(self):
8686
def root_method(self):
8787
return self._root_method
8888

89+
@property
90+
def supports_parallel_solve(self):
91+
return False
92+
8993
@root_method.setter
9094
def root_method(self, method):
9195
if method == "casadi":
@@ -896,36 +900,37 @@ def solve(
896900
pybamm.logger.verbose(
897901
f"Calling solver for {t_eval[start_index]} < t < {t_eval[end_index - 1]}"
898902
)
899-
ninputs = len(model_inputs_list)
900-
if ninputs == 1:
901-
new_solution = self._integrate(
902-
model,
903-
t_eval[start_index:end_index],
904-
model_inputs_list[0],
905-
t_interp=t_interp,
906-
)
907-
new_solutions = [new_solution]
908-
elif model.convert_to_format == "jax":
909-
# Jax can parallelize over the inputs efficiently
903+
if self.supports_parallel_solve:
904+
# Jax and IDAKLU solver can accept a list of inputs
910905
new_solutions = self._integrate(
911906
model,
912907
t_eval[start_index:end_index],
913908
model_inputs_list,
914909
t_interp,
915910
)
916911
else:
917-
with mp.get_context(self._mp_context).Pool(processes=nproc) as p:
918-
new_solutions = p.starmap(
919-
self._integrate,
920-
zip(
921-
[model] * ninputs,
922-
[t_eval[start_index:end_index]] * ninputs,
923-
model_inputs_list,
924-
[t_interp] * ninputs,
925-
),
912+
ninputs = len(model_inputs_list)
913+
if ninputs == 1:
914+
new_solution = self._integrate(
915+
model,
916+
t_eval[start_index:end_index],
917+
model_inputs_list[0],
918+
t_interp=t_interp,
926919
)
927-
p.close()
928-
p.join()
920+
new_solutions = [new_solution]
921+
else:
922+
with mp.get_context(self._mp_context).Pool(processes=nproc) as p:
923+
new_solutions = p.starmap(
924+
self._integrate,
925+
zip(
926+
[model] * ninputs,
927+
[t_eval[start_index:end_index]] * ninputs,
928+
model_inputs_list,
929+
[t_interp] * ninputs,
930+
),
931+
)
932+
p.close()
933+
p.join()
929934
# Setting the solve time for each segment.
930935
# pybamm.Solution.__add__ assumes attribute solve_time.
931936
solve_time = timer.time()
@@ -995,7 +1000,7 @@ def solve(
9951000
)
9961001

9971002
# Return solution(s)
998-
if ninputs == 1:
1003+
if len(solutions) == 1:
9991004
return solutions[0]
10001005
else:
10011006
return solutions
@@ -1350,7 +1355,13 @@ def step(
13501355
# Step
13511356
pybamm.logger.verbose(f"Stepping for {t_start_shifted:.0f} < t < {t_end:.0f}")
13521357
timer.reset()
1353-
solution = self._integrate(model, t_eval, model_inputs, t_interp)
1358+
1359+
# API for _integrate is different for JaxSolver and IDAKLUSolver
1360+
if self.supports_parallel_solve:
1361+
solutions = self._integrate(model, t_eval, [model_inputs], t_interp)
1362+
solution = solutions[0]
1363+
else:
1364+
solution = self._integrate(model, t_eval, model_inputs, t_interp)
13541365
solution.solve_time = timer.time()
13551366

13561367
# Check if extrapolation occurred

src/pybamm/solvers/c_solvers/idaklu.cpp

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include <pybind11/stl_bind.h>
1010

1111
#include "idaklu/idaklu_solver.hpp"
12+
#include "idaklu/IDAKLUSolverGroup.hpp"
1213
#include "idaklu/IdakluJax.hpp"
1314
#include "idaklu/common.hpp"
1415
#include "idaklu/Expressions/Casadi/CasadiFunctions.hpp"
@@ -26,15 +27,17 @@ casadi::Function generate_casadi_function(const std::string &data)
2627
namespace py = pybind11;
2728

2829
PYBIND11_MAKE_OPAQUE(std::vector<np_array>);
30+
PYBIND11_MAKE_OPAQUE(std::vector<Solution>);
2931

3032
PYBIND11_MODULE(idaklu, m)
3133
{
3234
m.doc() = "sundials solvers"; // optional module docstring
3335

3436
py::bind_vector<std::vector<np_array>>(m, "VectorNdArray");
37+
py::bind_vector<std::vector<Solution>>(m, "VectorSolution");
3538

36-
py::class_<IDAKLUSolver>(m, "IDAKLUSolver")
37-
.def("solve", &IDAKLUSolver::solve,
39+
py::class_<IDAKLUSolverGroup>(m, "IDAKLUSolverGroup")
40+
.def("solve", &IDAKLUSolverGroup::solve,
3841
"perform a solve",
3942
py::arg("t_eval"),
4043
py::arg("t_interp"),
@@ -43,8 +46,8 @@ PYBIND11_MODULE(idaklu, m)
4346
py::arg("inputs"),
4447
py::return_value_policy::take_ownership);
4548

46-
m.def("create_casadi_solver", &create_idaklu_solver<CasadiFunctions>,
47-
"Create a casadi idaklu solver object",
49+
m.def("create_casadi_solver_group", &create_idaklu_solver_group<CasadiFunctions>,
50+
"Create a group of casadi idaklu solver objects",
4851
py::arg("number_of_states"),
4952
py::arg("number_of_parameters"),
5053
py::arg("rhs_alg"),
@@ -70,8 +73,8 @@ PYBIND11_MODULE(idaklu, m)
7073
py::return_value_policy::take_ownership);
7174

7275
#ifdef IREE_ENABLE
73-
m.def("create_iree_solver", &create_idaklu_solver<IREEFunctions>,
74-
"Create a iree idaklu solver object",
76+
m.def("create_iree_solver_group", &create_idaklu_solver_group<IREEFunctions>,
77+
"Create a group of iree idaklu solver objects",
7578
py::arg("number_of_states"),
7679
py::arg("number_of_parameters"),
7780
py::arg("rhs_alg"),

src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolver.hpp

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
#define PYBAMM_IDAKLU_CASADI_SOLVER_HPP
33

44
#include "common.hpp"
5-
#include "Solution.hpp"
5+
#include "SolutionData.hpp"
6+
67

78
/**
89
* Abstract base class for solutions that can use different solvers and vector
@@ -24,14 +25,17 @@ class IDAKLUSolver
2425
~IDAKLUSolver() = default;
2526

2627
/**
27-
* @brief Abstract solver method that returns a Solution class
28+
* @brief Abstract solver method that executes the solver
2829
*/
29-
virtual Solution solve(
30-
np_array t_eval_np,
31-
np_array t_interp_np,
32-
np_array y0_np,
33-
np_array yp0_np,
34-
np_array_dense inputs) = 0;
30+
virtual SolutionData solve(
31+
const std::vector<realtype> &t_eval,
32+
const std::vector<realtype> &t_interp,
33+
const realtype *y0,
34+
const realtype *yp0,
35+
const realtype *inputs,
36+
bool save_adaptive_steps,
37+
bool save_interp_steps
38+
) = 0;
3539

3640
/**
3741
* Abstract method to initialize the solver, once vectors and solver classes
Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
#include "IDAKLUSolverGroup.hpp"
2+
#include <omp.h>
3+
#include <optional>
4+
5+
std::vector<Solution> IDAKLUSolverGroup::solve(
6+
np_array t_eval_np,
7+
np_array t_interp_np,
8+
np_array y0_np,
9+
np_array yp0_np,
10+
np_array inputs) {
11+
DEBUG("IDAKLUSolverGroup::solve");
12+
13+
// If t_interp is empty, save all adaptive steps
14+
bool save_adaptive_steps = t_interp_np.size() == 0;
15+
16+
const realtype* t_eval_begin = t_eval_np.data();
17+
const realtype* t_eval_end = t_eval_begin + t_eval_np.size();
18+
const realtype* t_interp_begin = t_interp_np.data();
19+
const realtype* t_interp_end = t_interp_begin + t_interp_np.size();
20+
21+
// Process the time inputs
22+
// 1. Get the sorted and unique t_eval vector
23+
auto const t_eval = makeSortedUnique(t_eval_begin, t_eval_end);
24+
25+
// 2.1. Get the sorted and unique t_interp vector
26+
auto const t_interp_unique_sorted = makeSortedUnique(t_interp_begin, t_interp_end);
27+
28+
// 2.2 Remove the t_eval values from t_interp
29+
auto const t_interp_setdiff = setDiff(t_interp_unique_sorted.begin(), t_interp_unique_sorted.end(), t_eval_begin, t_eval_end);
30+
31+
// 2.3 Finally, get the sorted and unique t_interp vector with t_eval values removed
32+
auto const t_interp = makeSortedUnique(t_interp_setdiff.begin(), t_interp_setdiff.end());
33+
34+
int const number_of_evals = t_eval.size();
35+
int const number_of_interps = t_interp.size();
36+
37+
// setDiff removes entries of t_interp that overlap with
38+
// t_eval, so we need to check if we need to interpolate any unique points.
39+
// This is not the same as save_adaptive_steps since some entries of t_interp
40+
// may be removed by setDiff
41+
bool save_interp_steps = number_of_interps > 0;
42+
43+
// 3. Check if the timestepping entries are valid
44+
if (number_of_evals < 2) {
45+
throw std::invalid_argument(
46+
"t_eval must have at least 2 entries"
47+
);
48+
} else if (save_interp_steps) {
49+
if (t_interp.front() < t_eval.front()) {
50+
throw std::invalid_argument(
51+
"t_interp values must be greater than the smallest t_eval value: "
52+
+ std::to_string(t_eval.front())
53+
);
54+
} else if (t_interp.back() > t_eval.back()) {
55+
throw std::invalid_argument(
56+
"t_interp values must be less than the greatest t_eval value: "
57+
+ std::to_string(t_eval.back())
58+
);
59+
}
60+
}
61+
62+
auto n_coeffs = number_of_states + number_of_parameters * number_of_states;
63+
64+
// check y0 and yp0 and inputs have the correct dimensions
65+
if (y0_np.ndim() != 2)
66+
throw std::domain_error("y0 has wrong number of dimensions. Expected 2 but got " + std::to_string(y0_np.ndim()));
67+
if (yp0_np.ndim() != 2)
68+
throw std::domain_error("yp0 has wrong number of dimensions. Expected 2 but got " + std::to_string(yp0_np.ndim()));
69+
if (inputs.ndim() != 2)
70+
throw std::domain_error("inputs has wrong number of dimensions. Expected 2 but got " + std::to_string(inputs.ndim()));
71+
72+
auto number_of_groups = y0_np.shape()[0];
73+
74+
// check y0 and yp0 and inputs have the correct shape
75+
if (y0_np.shape()[1] != n_coeffs)
76+
throw std::domain_error(
77+
"y0 has wrong number of cols. Expected " + std::to_string(n_coeffs) +
78+
" but got " + std::to_string(y0_np.shape()[1]));
79+
80+
if (yp0_np.shape()[1] != n_coeffs)
81+
throw std::domain_error(
82+
"yp0 has wrong number of cols. Expected " + std::to_string(n_coeffs) +
83+
" but got " + std::to_string(yp0_np.shape()[1]));
84+
85+
if (yp0_np.shape()[0] != number_of_groups)
86+
throw std::domain_error(
87+
"yp0 has wrong number of rows. Expected " + std::to_string(number_of_groups) +
88+
" but got " + std::to_string(yp0_np.shape()[0]));
89+
90+
if (inputs.shape()[0] != number_of_groups)
91+
throw std::domain_error(
92+
"inputs has wrong number of rows. Expected " + std::to_string(number_of_groups) +
93+
" but got " + std::to_string(inputs.shape()[0]));
94+
95+
const std::size_t solves_per_thread = number_of_groups / m_solvers.size();
96+
const std::size_t remainder_solves = number_of_groups % m_solvers.size();
97+
98+
const realtype *y0 = y0_np.data();
99+
const realtype *yp0 = yp0_np.data();
100+
const realtype *inputs_data = inputs.data();
101+
102+
std::vector<SolutionData> results(number_of_groups);
103+
104+
std::optional<std::exception> exception;
105+
106+
omp_set_num_threads(m_solvers.size());
107+
#pragma omp parallel for
108+
for (int i = 0; i < m_solvers.size(); i++) {
109+
try {
110+
for (int j = 0; j < solves_per_thread; j++) {
111+
const std::size_t index = i * solves_per_thread + j;
112+
const realtype *y = y0 + index * y0_np.shape(1);
113+
const realtype *yp = yp0 + index * yp0_np.shape(1);
114+
const realtype *input = inputs_data + index * inputs.shape(1);
115+
results[index] = m_solvers[i]->solve(t_eval, t_interp, y, yp, input, save_adaptive_steps, save_interp_steps);
116+
}
117+
} catch (std::exception &e) {
118+
// If an exception is thrown, we need to catch it and rethrow it outside the parallel region
119+
#pragma omp critical
120+
{
121+
exception = e;
122+
}
123+
}
124+
}
125+
126+
if (exception.has_value()) {
127+
py::set_error(PyExc_ValueError, exception->what());
128+
throw py::error_already_set();
129+
}
130+
131+
for (int i = 0; i < remainder_solves; i++) {
132+
const std::size_t index = number_of_groups - remainder_solves + i;
133+
const realtype *y = y0 + index * y0_np.shape(1);
134+
const realtype *yp = yp0 + index * yp0_np.shape(1);
135+
const realtype *input = inputs_data + index * inputs.shape(1);
136+
results[index] = m_solvers[i]->solve(t_eval, t_interp, y, yp, input, save_adaptive_steps, save_interp_steps);
137+
}
138+
139+
// create solutions (needs to be serial as we're using the Python GIL)
140+
std::vector<Solution> solutions(number_of_groups);
141+
for (int i = 0; i < number_of_groups; i++) {
142+
solutions[i] = results[i].generate_solution();
143+
}
144+
return solutions;
145+
}

0 commit comments

Comments
 (0)