diff --git a/examples/scripts/compare_lithium_ion.py b/examples/scripts/compare_lithium_ion.py index 6108036b9b..dd2f5c646e 100644 --- a/examples/scripts/compare_lithium_ion.py +++ b/examples/scripts/compare_lithium_ion.py @@ -7,10 +7,10 @@ # load models models = [ - pybamm.lithium_ion.SPM(), - pybamm.lithium_ion.SPMe(), + # pybamm.lithium_ion.SPM(), + # pybamm.lithium_ion.SPMe(), pybamm.lithium_ion.DFN(), - pybamm.lithium_ion.NewmanTobias(), + # pybamm.lithium_ion.NewmanTobias(), ] # create and run simulations diff --git a/pybamm/__init__.py b/pybamm/__init__.py index dd8e10768b..5b03691db4 100644 --- a/pybamm/__init__.py +++ b/pybamm/__init__.py @@ -208,12 +208,14 @@ from .solvers.scikits_dae_solver import ScikitsDaeSolver from .solvers.scikits_ode_solver import ScikitsOdeSolver, have_scikits_odes from .solvers.scipy_solver import ScipySolver +from .solvers.solver_utils import NoMemAllocVertcat from .solvers.jax_solver import JaxSolver from .solvers.jax_bdf_solver import jax_bdf_integrate from .solvers.idaklu_solver import IDAKLUSolver, have_idaklu + # # Experiments # diff --git a/pybamm/solvers/base_solver.py b/pybamm/solvers/base_solver.py index 9a0e94be03..178abbf91a 100644 --- a/pybamm/solvers/base_solver.py +++ b/pybamm/solvers/base_solver.py @@ -1168,7 +1168,7 @@ def step( if not isinstance(old_solution, pybamm.EmptySolution): if old_solution.all_models[-1] == model: # initialize with old solution - model.y0 = old_solution.all_ys[-1][:, -1] + model.y0 = old_solution.y_last else: _, concatenated_initial_conditions = model.set_initial_conditions_from( old_solution, return_type="ics" diff --git a/pybamm/solvers/casadi_solver.py b/pybamm/solvers/casadi_solver.py index d70569fbb3..3007599fb3 100644 --- a/pybamm/solvers/casadi_solver.py +++ b/pybamm/solvers/casadi_solver.py @@ -250,7 +250,7 @@ def _integrate(self, model, t_eval, inputs_dict=None): # update time t = t_window[-1] # update y0 - y0 = solution.all_ys[-1][:, -1] + y0 = solution.y_last # now we extract sensitivities from the solution if bool(model.calculate_sensitivities): @@ -275,14 +275,16 @@ def _solve_for_event(self, coarse_solution): inputs = casadi.vertcat(*[x for x in inputs_dict.values()]) def find_t_event(sol, typ): + sol_t = sol.all_ts[0] + sol_y = sol.all_ys[0] # Check most recent y to see if any events have been crossed if model.terminate_events_eval: - y_last = sol.all_ys[-1][:, -1] + y_last = sol.y_last crossed_events = np.sign( np.concatenate( [ - event(sol.t[-1], y_last, inputs) + event(sol_t[-1], y_last, inputs) for event in model.terminate_events_eval ] ) @@ -315,13 +317,16 @@ def f(idx): # We take away 1e-5 to deal with the case where the event sits # exactly on zero, as can happen when the event switch is used # (fast with events mode) - f_eval[idx] = event(sol.t[idx], sol.y[:, idx], inputs) - 1e-5 + f_eval[idx] = ( + init_event_sign * event(sol_t[idx], sol_y[:, idx], inputs) + - 1e-5 + ) return f_eval[idx] def integer_bisect(): a_n = 0 - b_n = len(sol.t) - 1 - for _ in range(len(sol.t)): + b_n = len(sol_t) - 1 + for _ in range(len(sol_t)): if a_n + 1 == b_n: return a_n m_n = (a_n + b_n) // 2 @@ -343,8 +348,8 @@ def integer_bisect(): # Linear interpolation between the two indices to find the root time # We could do cubic interpolation here instead but it would be # slower - t_lower = sol.t[event_idx_lower] - t_upper = sol.t[event_idx_lower + 1] + t_lower = sol_t[event_idx_lower] + t_upper = sol_t[event_idx_lower + 1] event_lower = abs(f(event_idx_lower)) event_upper = abs(f(event_idx_lower + 1)) @@ -360,7 +365,11 @@ def integer_bisect(): t_event = np.nanmin(t_events) # create interpolant to evaluate y in the current integration # window - y_sol = interp1d(sol.t, sol.y, kind="linear") + sol_event_idx = np.where(t_event < sol_t)[0][0] + sol_t_interp = sol_t[sol_event_idx - 1 : sol_event_idx + 1] + sol_y_interp = sol_y[:, sol_event_idx - 1 : sol_event_idx + 1] + sol_y_interp = casadi.horzcat(sol_y_interp[:, 0], sol_y_interp[:, 1]) + y_sol = interp1d(sol_t_interp, sol_y_interp, kind="linear") y_event = y_sol(t_event) closest_event_idx = event_idx[np.nanargmin(t_events)] @@ -393,7 +402,7 @@ def integer_bisect(): self.create_integrator(model, inputs, t_window_event_dense) use_grid = True - y0 = coarse_solution.y[:, event_idx_lower] + y0 = coarse_solution.all_ys[0][:, event_idx_lower] dense_step_sol = self._run_integrator( model, y0, @@ -418,8 +427,8 @@ def integer_bisect(): # Return solution truncated at the first coarse event time # Also assign t_event - t_sol = coarse_solution.t[: event_idx_lower + 1] - y_sol = coarse_solution.y[:, : event_idx_lower + 1] + t_sol = coarse_solution.all_ts[0][: event_idx_lower + 1] + y_sol = coarse_solution.all_ys[0][:, : event_idx_lower + 1] solution = pybamm.Solution( t_sol, y_sol, @@ -627,7 +636,9 @@ def _run_integrator( else: integrator = self.integrators[model]["no grid"] - len_rhs = model.concatenated_rhs.size + len_rhs = model.len_rhs + len_alg = model.len_alg + len_t = len(t_eval) # Check y0 to see if it includes sensitivities if explicit_sensitivities: @@ -652,24 +663,19 @@ def _run_integrator( ) pybamm.logger.debug("Finished casadi integrator") integration_time = timer.time() - y_sol = casadi.vertcat(casadi_sol["xf"], casadi_sol["zf"]) - sol = pybamm.Solution( - t_eval, - y_sol, - model, - inputs_dict, - sensitivities=extract_sensitivities_in_solution, - check_solution=False, - ) - sol.integration_time = integration_time - return sol + x_sols = casadi.horzsplit(casadi_sol["xf"]) + if casadi_sol["zf"].is_empty(): + y_sols = x_sols + else: + z_sols = casadi.horzsplit(casadi_sol["zf"]) + y_sols = pybamm.NoMemAllocVertcat(x_sols, z_sols, len_rhs, len_alg) else: # Repeated calls to the integrator x = y0_diff z = y0_alg - y_diff = x - y_alg = z - for i in range(len(t_eval) - 1): + x_sols = [x] + z_sols = [z] + for i in range(len_t - 1): t_min = t_eval[i] t_max = t_eval[i + 1] inputs_with_tlims = casadi.vertcat(inputs, t_min, t_max) @@ -680,17 +686,17 @@ def _run_integrator( integration_time = timer.time() x = casadi_sol["xf"] z = casadi_sol["zf"] - y_diff = casadi.horzcat(y_diff, x) + x_sols.append(x) if not z.is_empty(): - y_alg = casadi.horzcat(y_alg, z) + z_sols.append(z) if z.is_empty(): - y_sol = y_diff + y_sols = x_sols else: - y_sol = casadi.vertcat(y_diff, y_alg) + y_sols = pybamm.NoMemAllocVertcat(x_sols, z_sols, len_rhs, len_alg) sol = pybamm.Solution( t_eval, - y_sol, + y_sols, model, inputs_dict, sensitivities=extract_sensitivities_in_solution, diff --git a/pybamm/solvers/processed_variable.py b/pybamm/solvers/processed_variable.py index 968e95e1a6..18a79d27ba 100644 --- a/pybamm/solvers/processed_variable.py +++ b/pybamm/solvers/processed_variable.py @@ -39,9 +39,14 @@ def __init__(self, base_variables, base_variables_casadi, solution, warn=True): self.all_ts = solution.all_ts self.all_ys = solution.all_ys + self.y_last = solution.y_last self.all_inputs = solution.all_inputs self.all_inputs_casadi = solution.all_inputs_casadi + # utilities for no_mem_alloc + self.uses_no_mem_alloc = isinstance(self.all_ys[0], pybamm.NoMemAllocVertcat) + self.out = casadi.DM.zeros(self.all_ys[0].shape) + self.mesh = base_variables[0].mesh self.domain = base_variables[0].domain self.domains = base_variables[0].domains @@ -58,10 +63,10 @@ def __init__(self, base_variables, base_variables_casadi, solution, warn=True): # Store length scales self.length_scales = solution.length_scales_eval - # Evaluate base variable at initial time + # Evaluate base variable at final time self.base_eval = self.base_variables_casadi[0]( - self.all_ts[0][0], self.all_ys[0][:, 0], self.all_inputs_casadi[0] - ).full() + self.all_ts[-1][-1], self.y_last, self.all_inputs_casadi[-1] + ).toarray() # handle 2D (in space) finite element variables differently if ( @@ -113,7 +118,7 @@ def initialise_0D(self): for inner_idx, t in enumerate(ts): t = ts[inner_idx] y = ys[:, inner_idx] - entries[idx] = base_var_casadi(t, y, inputs).full()[0, 0] + entries[idx] = float(base_var_casadi(t, y, inputs)) idx += 1 # set up interpolation @@ -145,7 +150,7 @@ def initialise_1D(self, fixed_t=False): for inner_idx, t in enumerate(ts): t = ts[inner_idx] y = ys[:, inner_idx] - entries[:, idx] = base_var_casadi(t, y, inputs).full()[:, 0] + entries[:, idx] = np.array(base_var_casadi(t, y, inputs).elements()) idx += 1 # Get node and edge values @@ -244,7 +249,7 @@ def initialise_2D(self): t = ts[inner_idx] y = ys[:, inner_idx] entries[:, :, idx] = np.reshape( - base_var_casadi(t, y, inputs).full(), + np.array(base_var_casadi(t, y, inputs).elements()), [first_dim_size, second_dim_size], order="F", ) @@ -395,7 +400,7 @@ def initialise_2D_scikit_fem(self): t = ts[inner_idx] y = ys[:, inner_idx] entries[:, :, idx] = np.reshape( - base_var_casadi(t, y, inputs).full(), + base_var_casadi(t, y, inputs).toarray(), [len_y, len_z], order="C", ) @@ -556,11 +561,11 @@ def initialise_sensitivity_explicit_forward(self): "dvar_dp", [t_casadi, y_casadi, p_casadi_stacked], [dvar_dp] ) for index, (ts, ys) in enumerate(zip(self.all_ts, self.all_ys)): - for idx, t in enumerate(ts): - u = ys[:, idx] + for inner_idx, t in enumerate(ts): + u = ys[:, inner_idx] next_dvar_dy_eval = dvar_dy_func(t, u, inputs_stacked) next_dvar_dp_eval = dvar_dp_func(t, u, inputs_stacked) - if index == 0 and idx == 0: + if index == 0 and inner_idx == 0: dvar_dy_eval = next_dvar_dy_eval dvar_dp_eval = next_dvar_dp_eval else: diff --git a/pybamm/solvers/solution.py b/pybamm/solvers/solution.py index 3b3b0f17be..d530fd0856 100644 --- a/pybamm/solvers/solution.py +++ b/pybamm/solvers/solution.py @@ -81,6 +81,7 @@ def __init__( all_ys = [all_ys] if not isinstance(all_models, list): all_models = [all_models] + self._all_ts = all_ts self._all_ys = all_ys self._all_ys_and_sens = all_ys @@ -317,8 +318,7 @@ def check_ys_are_not_too_large(self): # Only check last one so that it doesn't take too long # We only care about the cases where y is growing too large without any # restraint, so if y gets large in the middle then comes back down that is ok - y, model = self.all_ys[-1], self.all_models[-1] - y = y[:, -1] + y, model = self.y_last, self.all_models[-1] if np.any(y > pybamm.settings.max_y_value): for var in [*model.rhs.keys(), *model.algebraic.keys()]: y_var = y[model.variables[var.name].y_slices[0]] @@ -354,6 +354,20 @@ def all_inputs_casadi(self): ] return self._all_inputs_casadi + @property + def y_last(self): + try: + return self._y_last + except AttributeError: + all_ys_last = self.all_ys[-1] + if isinstance(all_ys_last, pybamm.NoMemAllocVertcat): + self._y_last = all_ys_last[:, -1] + elif all_ys_last.shape[1] == 1: + self._y_last = all_ys_last + else: + self._y_last = all_ys_last[:, -1] + return self._y_last + @property def t_event(self): """Time at which the event happens""" @@ -705,16 +719,19 @@ def __add__(self, other): else: all_ts = self.all_ts + other.all_ts all_ys = self.all_ys + other.all_ys + all_models = self.all_models + other.all_models + all_inputs = self.all_inputs + other.all_inputs new_sol = Solution( all_ts, all_ys, - self.all_models + other.all_models, - self.all_inputs + other.all_inputs, + all_models, + all_inputs, other.t_event, other.y_event, other.termination, bool(self.sensitivities), + check_solution=False, ) new_sol.closest_event_idx = other.closest_event_idx @@ -741,6 +758,7 @@ def copy(self): self.t_event, self.y_event, self.termination, + check_solution=False, ) new_sol._all_inputs_casadi = self.all_inputs_casadi new_sol._sub_solutions = self.sub_solutions diff --git a/pybamm/solvers/solver_utils.py b/pybamm/solvers/solver_utils.py new file mode 100644 index 0000000000..5bc711b051 --- /dev/null +++ b/pybamm/solvers/solver_utils.py @@ -0,0 +1,41 @@ +# +# Utility functions and classes for solvers +# +import casadi + + +class NoMemAllocVertcat: + """ + Acts like a vertcat, but does not allocate new memory. + """ + + def __init__(self, xs, zs, len_x=None, len_z=None, items=None): + self.xs = xs + self.zs = zs + self.len_x = len_x or xs[0].shape[0] + self.len_z = len_z or zs[0].shape[0] + len_items = len(xs) + self.shape = (self.len_x + self.len_z, len_items) + + if items is None: + items = [None] * len_items + for idx in range(len_items): + out = casadi.DM.zeros((self.shape[0], 1)) + out[: self.len_x] = self.xs[idx] + out[self.len_x :] = self.zs[idx] + items[idx] = out + + self.items = items + + def __getitem__(self, idx): + if idx[0] != slice(None): + raise NotImplementedError( + "Only full slices are supported in the first entry of the index" + ) + idx = idx[1] + if isinstance(idx, slice): + return NoMemAllocVertcat( + self.xs[idx], self.zs[idx], self.len_x, self.len_z, self.items[idx] + ) + else: + return self.items[idx] diff --git a/tests/integration/test_models/standard_model_tests.py b/tests/integration/test_models/standard_model_tests.py index 17cac23bb7..0bff62778a 100644 --- a/tests/integration/test_models/standard_model_tests.py +++ b/tests/integration/test_models/standard_model_tests.py @@ -80,7 +80,7 @@ def test_solving( if Crate == 0: Crate = 1 if t_eval is None: - t_eval = np.linspace(0, 3600 / Crate, 100) + t_eval = np.linspace(0, 3600 / Crate, 1000) self.solution = self.solver.solve( self.model, diff --git a/tests/unit/test_solvers/test_solver_utils.py b/tests/unit/test_solvers/test_solver_utils.py new file mode 100644 index 0000000000..3c557dfaed --- /dev/null +++ b/tests/unit/test_solvers/test_solver_utils.py @@ -0,0 +1,44 @@ +# +# Tests for the solver utility functions and classes +# +import pybamm +import unittest +import numpy as np + + +class TestSolverUtils(unittest.TestCase): + def test_compare_numpy_vertcat(self): + x = [np.array([1, 2]), np.array([4, 5]), np.array([7, 8])] + z = [np.array([13, 14, 15]), np.array([16, 17, 18]), np.array([19, 20, 21])] + + pybamm_vertcat = pybamm.NoMemAllocVertcat(x, z) + + self.assertEqual(pybamm_vertcat.shape, (5, 3)) + np.testing.assert_array_equal( + pybamm_vertcat[:, 0], np.array([1, 2, 13, 14, 15])[:, np.newaxis] + ) + np.testing.assert_array_equal( + pybamm_vertcat[:, -1], np.array([7, 8, 19, 20, 21])[:, np.newaxis] + ) + + pybamm_sub_vertcat = pybamm_vertcat[:, 1:] + self.assertEqual(pybamm_sub_vertcat.shape, (5, 2)) + np.testing.assert_array_equal(pybamm_vertcat[:, 1], pybamm_sub_vertcat[:, 0]) + + def test_errors(self): + x = [np.array([1, 2, 3]), np.array([4, 5, 6]), np.array([7, 8, 9])] + z = [np.array([13, 14, 15]), np.array([16, 17, 18]), np.array([19, 20, 21])] + + pybamm_vertcat = pybamm.NoMemAllocVertcat(x, z) + with self.assertRaisesRegex(NotImplementedError, "Only full slices"): + pybamm_vertcat[0, 0] + + +if __name__ == "__main__": + print("Add -v for more debug output") + import sys + + if "-v" in sys.argv: + debug = True + pybamm.settings.debug_mode = True + unittest.main()