Skip to content

Commit

Permalink
#1996 working on casadi vertcat
Browse files Browse the repository at this point in the history
  • Loading branch information
valentinsulzer committed Jul 28, 2022
1 parent 4c42312 commit d98ea9b
Show file tree
Hide file tree
Showing 6 changed files with 123 additions and 16 deletions.
6 changes: 3 additions & 3 deletions examples/scripts/compare_lithium_ion.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@

# load models
models = [
pybamm.lithium_ion.SPM(),
pybamm.lithium_ion.SPMe(),
# pybamm.lithium_ion.SPM(),
# pybamm.lithium_ion.SPMe(),
pybamm.lithium_ion.DFN(),
pybamm.lithium_ion.NewmanTobias(),
# pybamm.lithium_ion.NewmanTobias(),
]

# create and run simulations
Expand Down
2 changes: 2 additions & 0 deletions pybamm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,12 +209,14 @@
from .solvers.scikits_dae_solver import ScikitsDaeSolver
from .solvers.scikits_ode_solver import ScikitsOdeSolver, have_scikits_odes
from .solvers.scipy_solver import ScipySolver
from .solvers.solver_utils import NoMemAllocVertcat

from .solvers.jax_solver import JaxSolver
from .solvers.jax_bdf_solver import jax_bdf_integrate

from .solvers.idaklu_solver import IDAKLUSolver, have_idaklu


#
# Experiments
#
Expand Down
17 changes: 5 additions & 12 deletions pybamm/solvers/casadi_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -699,17 +699,10 @@ def _run_integrator(
x0=y0_diff, z0=y0_alg, p=inputs_with_tmin, **self.extra_options_call
)
integration_time = timer.time()
y_sol = casadi.vertcat(casadi_sol["xf"], casadi_sol["zf"])
sol = pybamm.Solution(
t_eval,
y_sol,
model,
inputs_dict,
sensitivities=extract_sensitivities_in_solution,
check_solution=False,
)
sol.integration_time = integration_time
return sol
if casadi_sol["zf"].is_empty():
y_sol = casadi_sol["xf"]
else:
y_sol = pybamm.NoMemAllocVertcat(casadi_sol["xf"], casadi_sol["zf"])
else:
# Repeated calls to the integrator
x = y0_diff
Expand All @@ -733,7 +726,7 @@ def _run_integrator(
if z.is_empty():
y_sol = y_diff
else:
y_sol = casadi.vertcat(y_diff, y_alg)
y_sol = pybamm.NoMemAllocVertcat(y_diff, y_alg)

sol = pybamm.Solution(
t_eval,
Expand Down
61 changes: 61 additions & 0 deletions pybamm/solvers/solver_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
#
# Utility functions and classes for solvers
#


class NoMemAllocVertcat:
"""
Acts like a vertcat, but does not allocate new memory.
"""

def __init__(self, a, b):
arrays = [a, b]
self.arrays = arrays

for array in arrays:
if not 1 <= len(array.shape) <= 2:
raise ValueError("Only 1D or 2D arrays are supported")
self._ndim = len(array.shape)

self.len_a = a.shape[0]
shape0 = a.shape[0] + b.shape[0]

if self._ndim == 1:
self._shape = (shape0,)
self._size = shape0
else:
if a.shape[1] != b.shape[1]:
raise ValueError("All arrays must have the same number of columns")
shape1 = a.shape[1]

self._shape = (shape0, shape1)
self._size = shape0 * shape1

@property
def shape(self):
return self._shape

@property
def size(self):
return self._size

@property
def ndim(self):
return self._ndim

def __getitem__(self, key):
if self._ndim == 1 or isinstance(key, int):
if key < self.len_a:
return self.arrays[0][key]
else:
return self.arrays[1][key - self.len_a]

if key[0] == slice(None):
return NoMemAllocVertcat(*[arr[:, key[1]] for arr in self.arrays])
elif isinstance(key[0], int):
if key[0] < self.len_a:
return self.arrays[0][key[0], key[1]]
else:
return self.arrays[1][key[0] - self.len_a, key[1]]
else:
raise NotImplementedError
2 changes: 1 addition & 1 deletion tests/integration/test_models/standard_model_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def test_solving(
if Crate == 0:
Crate = 1
if t_eval is None:
t_eval = np.linspace(0, 3600 / Crate, 100)
t_eval = np.linspace(0, 3600 / Crate, 1000)

self.solution = self.solver.solve(
self.model,
Expand Down
51 changes: 51 additions & 0 deletions tests/unit/test_solvers/test_solver_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
#
# Tests for the solver utility functions and classes
#
import json
import pybamm
import unittest
import numpy as np
import pandas as pd
from scipy.io import loadmat
from tests import get_discretisation_for_testing


class TestSolverUtils(unittest.TestCase):
def test_compare_numpy_vertcat(self):
a0 = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
a1 = np.array([[1, 2, 3]])
b0 = np.array([[13, 14, 15], [16, 17, 18]])

for a, b in zip([a0, b0], [a1, b0]):
pybamm_vertcat = pybamm.NoMemAllocVertcat(a, b)
np_vertcat = np.concatenate((a, b), axis=0)
self.assertEqual(pybamm_vertcat.shape, np_vertcat.shape)
self.assertEqual(pybamm_vertcat.size, np_vertcat.size)
for i in range(pybamm_vertcat.shape[0]):
for j in range(pybamm_vertcat.shape[1]):
self.assertEqual(pybamm_vertcat[i, j], np_vertcat[i, j])
self.assertEqual(pybamm_vertcat[:, j][i], np_vertcat[:, j][i])
for i in range(pybamm_vertcat.shape[0]):
np.testing.assert_array_equal(pybamm_vertcat[i, :], np_vertcat[i, :])

def test_errors(self):
a = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
b = np.ones((4, 5, 6))
with self.assertRaisesRegex(ValueError, "Only 1D or 2D arrays are supported"):
pybamm.NoMemAllocVertcat(a, b)

b = np.array([[10, 11], [13, 14]])
with self.assertRaisesRegex(
ValueError, "All arrays must have the same number of columns"
):
pybamm.NoMemAllocVertcat(a, b)


if __name__ == "__main__":
print("Add -v for more debug output")
import sys

if "-v" in sys.argv:
debug = True
pybamm.settings.debug_mode = True
unittest.main()

0 comments on commit d98ea9b

Please sign in to comment.