Skip to content

Commit

Permalink
Merge pull request pybamm-team#3217 from jsbrittain/solver
Browse files Browse the repository at this point in the history
Idaklu solver can be given a list of variables to calculate during the solve
  • Loading branch information
martinjrobins authored Sep 29, 2023
2 parents fc570af + c7d109b commit d238192
Show file tree
Hide file tree
Showing 29 changed files with 3,211 additions and 1,188 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@

## Features

- Idaklu solver can be given a list of variables to calculate during the solve ([#3217](https://github.com/pybamm-team/PyBaMM/pull/3217))
- Enable multithreading in IDAKLU solver ([#2947](https://github.com/pybamm-team/PyBaMM/pull/2947))
- If a solution contains cycles and steps, the cycle number and step number are now saved when `solution.save_data()` is called ([#2931](https://github.com/pybamm-team/PyBaMM/pull/2931))
- Experiments can now be given a `start_time` to define when each step should be triggered ([#2616](https://github.com/pybamm-team/PyBaMM/pull/2616))
Expand Down
8 changes: 7 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,14 @@ pybind11_add_module(idaklu
pybamm/solvers/c_solvers/idaklu/casadi_functions.hpp
pybamm/solvers/c_solvers/idaklu/casadi_solver.cpp
pybamm/solvers/c_solvers/idaklu/casadi_solver.hpp
pybamm/solvers/c_solvers/idaklu/casadi_sundials_functions.hpp
pybamm/solvers/c_solvers/idaklu/CasadiSolver.cpp
pybamm/solvers/c_solvers/idaklu/CasadiSolver.hpp
pybamm/solvers/c_solvers/idaklu/CasadiSolverOpenMP.cpp
pybamm/solvers/c_solvers/idaklu/CasadiSolverOpenMP.hpp
pybamm/solvers/c_solvers/idaklu/CasadiSolverOpenMP_solvers.cpp
pybamm/solvers/c_solvers/idaklu/CasadiSolverOpenMP_solvers.hpp
pybamm/solvers/c_solvers/idaklu/casadi_sundials_functions.cpp
pybamm/solvers/c_solvers/idaklu/casadi_sundials_functions.hpp
pybamm/solvers/c_solvers/idaklu/common.hpp
pybamm/solvers/c_solvers/idaklu/python.hpp
pybamm/solvers/c_solvers/idaklu/python.cpp
Expand Down
1 change: 1 addition & 0 deletions pybamm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,7 @@
#
from .solvers.solution import Solution, EmptySolution, make_cycle_solution
from .solvers.processed_variable import ProcessedVariable
from .solvers.processed_variable_computed import ProcessedVariableComputed
from .solvers.base_solver import BaseSolver
from .solvers.dummy_solver import DummySolver
from .solvers.algebraic_solver import AlgebraicSolver
Expand Down
94 changes: 82 additions & 12 deletions pybamm/solvers/base_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ class BaseSolver(object):
The tolerance for the initial-condition solver (default is 1e-6).
extrap_tol : float, optional
The tolerance to assert whether extrapolation occurs or not. Default is 0.
output_variables : list[str], optional
List of variables to calculate and return. If none are specified then
the complete state vector is returned (can be very large) (default is [])
"""

def __init__(
Expand All @@ -48,20 +51,23 @@ def __init__(
root_method=None,
root_tol=1e-6,
extrap_tol=None,
output_variables=[],
):
self.method = method
self.rtol = rtol
self.atol = atol
self.root_tol = root_tol
self.root_method = root_method
self.extrap_tol = extrap_tol or -1e-10
self.output_variables = output_variables
self._model_set_up = {}

# Defaults, can be overwritten by specific solver
self.name = "Base solver"
self.ode_solver = False
self.algebraic_solver = False
self._on_extrapolation = "warn"
self.computed_var_fcns = {}

@property
def root_method(self):
Expand Down Expand Up @@ -250,8 +256,57 @@ def set_up(self, model, inputs=None, t_eval=None, ics_only=False):
model.casadi_sensitivities_rhs = jacp_rhs
model.casadi_sensitivities_algebraic = jacp_algebraic

# if output_variables specified then convert functions to casadi
# expressions for evaluation within the respective solver
self.computed_var_fcns = {}
self.computed_dvar_dy_fcns = {}
self.computed_dvar_dp_fcns = {}
for key in self.output_variables:
# ExplicitTimeIntegral's are not computed as part of the solver and
# do not need to be converted
if isinstance(
model.variables_and_events[key], pybamm.ExplicitTimeIntegral
):
continue
# Generate Casadi function to calculate variable and derivates
# to enable sensitivites to be computed within the solver
(
self.computed_var_fcns[key],
self.computed_dvar_dy_fcns[key],
self.computed_dvar_dp_fcns[key],
_,
) = process(
model.variables_and_events[key],
BaseSolver._wrangle_name(key),
vars_for_processing,
use_jacobian=True,
return_jacp_stacked=True,
)

pybamm.logger.info("Finish solver set-up")

@classmethod
def _wrangle_name(cls, name: str) -> str:
"""
Wrangle a function name to replace special characters
"""
replacements = [
(" ", "_"),
("[", ""),
("]", ""),
(".", "_"),
("-", "_"),
("(", ""),
(")", ""),
("%", "prc"),
(",", ""),
(".", ""),
]
name = "v_" + name.casefold()
for string, replacement in replacements:
name = name.replace(string, replacement)
return name

def _check_and_prepare_model_inplace(self, model, inputs, ics_only):
"""
Performs checks on the model and prepares it for solving.
Expand Down Expand Up @@ -1366,7 +1421,9 @@ def _set_up_model_inputs(self, model, inputs):
return ordered_inputs


def process(symbol, name, vars_for_processing, use_jacobian=None):
def process(
symbol, name, vars_for_processing, use_jacobian=None, return_jacp_stacked=None
):
"""
Parameters
----------
Expand All @@ -1376,6 +1433,8 @@ def process(symbol, name, vars_for_processing, use_jacobian=None):
function evaluators created will have this base name
use_jacobian: bool, optional
whether to return Jacobian functions
return_jacp_stacked: bool, optional
returns Jacobian function wrt stacked parameters instead of jacp
Returns
-------
Expand Down Expand Up @@ -1553,17 +1612,28 @@ def jacp(*args, **kwargs):
"CasADi"
)
)
# WARNING, jacp for convert_to_format=casadi does not return a dict
# instead it returns multiple return values, one for each param
# TODO: would it be faster to do the jacobian wrt pS_casadi_stacked?
jacp = casadi.Function(
name + "_jacp",
[t_casadi, y_and_S, p_casadi_stacked],
[
casadi.densify(casadi.jacobian(casadi_expression, p_casadi[pname]))
for pname in model.calculate_sensitivities
],
)
# Compute derivate wrt p-stacked (can be passed to solver to
# compute sensitivities online)
if return_jacp_stacked:
jacp = casadi.Function(
f"d{name}_dp",
[t_casadi, y_casadi, p_casadi_stacked],
[casadi.jacobian(casadi_expression, p_casadi_stacked)],
)
else:
# WARNING, jacp for convert_to_format=casadi does not return a dict
# instead it returns multiple return values, one for each param
# TODO: would it be faster to do the jacobian wrt pS_casadi_stacked?
jacp = casadi.Function(
name + "_jacp",
[t_casadi, y_and_S, p_casadi_stacked],
[
casadi.densify(
casadi.jacobian(casadi_expression, p_casadi[pname])
)
for pname in model.calculate_sensitivities
],
)

if use_jacobian:
report(f"Calculating jacobian for {name} using CasADi")
Expand Down
91 changes: 64 additions & 27 deletions pybamm/solvers/c_solvers/idaklu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <pybind11/functional.h>
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/stl_bind.h>

#include <vector>
Expand All @@ -25,39 +26,75 @@ PYBIND11_MODULE(idaklu, m)
py::bind_vector<std::vector<np_array>>(m, "VectorNdArray");

m.def("solve_python", &solve_python,
"The solve function for python evaluators", py::arg("t"), py::arg("y0"),
py::arg("yp0"), py::arg("res"), py::arg("jac"), py::arg("sens"),
py::arg("get_jac_data"), py::arg("get_jac_row_vals"),
py::arg("get_jac_col_ptr"), py::arg("nnz"), py::arg("events"),
py::arg("number_of_events"), py::arg("use_jacobian"),
py::arg("rhs_alg_id"), py::arg("atol"), py::arg("rtol"),
py::arg("inputs"), py::arg("number_of_sensitivity_parameters"),
py::return_value_policy::take_ownership);
"The solve function for python evaluators",
py::arg("t"),
py::arg("y0"),
py::arg("yp0"),
py::arg("res"),
py::arg("jac"),
py::arg("sens"),
py::arg("get_jac_data"),
py::arg("get_jac_row_vals"),
py::arg("get_jac_col_ptr"),
py::arg("nnz"),
py::arg("events"),
py::arg("number_of_events"),
py::arg("use_jacobian"),
py::arg("rhs_alg_id"),
py::arg("atol"),
py::arg("rtol"),
py::arg("inputs"),
py::arg("number_of_sensitivity_parameters"),
py::return_value_policy::take_ownership);

py::class_<CasadiSolver>(m, "CasadiSolver")
.def("solve", &CasadiSolver::solve, "perform a solve", py::arg("t"),
py::arg("y0"), py::arg("yp0"), py::arg("inputs"),
py::return_value_policy::take_ownership);
.def("solve", &CasadiSolver::solve,
"perform a solve",
py::arg("t"),
py::arg("y0"),
py::arg("yp0"),
py::arg("inputs"),
py::return_value_policy::take_ownership);

//py::bind_vector<std::vector<Function>>(m, "VectorFunction");
//py::implicitly_convertible<py::iterable, std::vector<Function>>();

m.def("create_casadi_solver", &create_casadi_solver,
"Create a casadi idaklu solver object", py::arg("number_of_states"),
py::arg("number_of_parameters"), py::arg("rhs_alg"),
py::arg("jac_times_cjmass"), py::arg("jac_times_cjmass_colptrs"),
py::arg("jac_times_cjmass_rowvals"), py::arg("jac_times_cjmass_nnz"),
py::arg("jac_bandwidth_lower"), py::arg("jac_bandwidth_upper"),
py::arg("jac_action"), py::arg("mass_action"), py::arg("sens"),
py::arg("events"), py::arg("number_of_events"), py::arg("rhs_alg_id"),
py::arg("atol"), py::arg("rtol"), py::arg("inputs"), py::arg("options"),
py::return_value_policy::take_ownership);

m.def("generate_function", &generate_function, "Generate a casadi function",
py::arg("string"), py::return_value_policy::take_ownership);
"Create a casadi idaklu solver object",
py::arg("number_of_states"),
py::arg("number_of_parameters"),
py::arg("rhs_alg"),
py::arg("jac_times_cjmass"),
py::arg("jac_times_cjmass_colptrs"),
py::arg("jac_times_cjmass_rowvals"),
py::arg("jac_times_cjmass_nnz"),
py::arg("jac_bandwidth_lower"),
py::arg("jac_bandwidth_upper"),
py::arg("jac_action"),
py::arg("mass_action"),
py::arg("sens"),
py::arg("events"),
py::arg("number_of_events"),
py::arg("rhs_alg_id"),
py::arg("atol"),
py::arg("rtol"),
py::arg("inputs"),
py::arg("var_casadi_fcns"),
py::arg("dvar_dy_fcns"),
py::arg("dvar_dp_fcns"),
py::arg("options"),
py::return_value_policy::take_ownership);

m.def("generate_function", &generate_function,
"Generate a casadi function",
py::arg("string"),
py::return_value_policy::take_ownership);

py::class_<Function>(m, "Function");

py::class_<Solution>(m, "solution")
.def_readwrite("t", &Solution::t)
.def_readwrite("y", &Solution::y)
.def_readwrite("yS", &Solution::yS)
.def_readwrite("flag", &Solution::flag);
.def_readwrite("t", &Solution::t)
.def_readwrite("y", &Solution::y)
.def_readwrite("yS", &Solution::yS)
.def_readwrite("flag", &Solution::flag);
}
1 change: 1 addition & 0 deletions pybamm/solvers/c_solvers/idaklu/CasadiSolver.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
#include "CasadiSolver.hpp"
49 changes: 49 additions & 0 deletions pybamm/solvers/c_solvers/idaklu/CasadiSolver.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
#ifndef PYBAMM_IDAKLU_CASADI_SOLVER_HPP
#define PYBAMM_IDAKLU_CASADI_SOLVER_HPP

#include <casadi/casadi.hpp>
using Function = casadi::Function;

#include "casadi_functions.hpp"
#include "common.hpp"
#include "options.hpp"
#include "solution.hpp"
#include "sundials_legacy_wrapper.hpp"

/**
* Abstract base class for solutions that can use different solvers and vector
* implementations.
* @brief An abstract base class for the Idaklu solver
*/
class CasadiSolver
{
public:

/**
* @brief Default constructor
*/
CasadiSolver() = default;

/**
* @brief Default destructor
*/
~CasadiSolver() = default;

/**
* @brief Abstract solver method that returns a Solution class
*/
virtual Solution solve(
np_array t_np,
np_array y0_np,
np_array yp0_np,
np_array_dense inputs) = 0;

/**
* Abstract method to initialize the solver, once vectors and solver classes
* are set
* @brief Abstract initialization method
*/
virtual void Initialize() = 0;
};

#endif // PYBAMM_IDAKLU_CASADI_SOLVER_HPP
Loading

0 comments on commit d238192

Please sign in to comment.