Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Issue 1996 vertcat #2208

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions examples/scripts/compare_lithium_ion.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@

# load models
models = [
pybamm.lithium_ion.SPM(),
pybamm.lithium_ion.SPMe(),
# pybamm.lithium_ion.SPM(),
# pybamm.lithium_ion.SPMe(),
pybamm.lithium_ion.DFN(),
pybamm.lithium_ion.NewmanTobias(),
# pybamm.lithium_ion.NewmanTobias(),
]

# create and run simulations
Expand Down
2 changes: 2 additions & 0 deletions pybamm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,12 +208,14 @@
from .solvers.scikits_dae_solver import ScikitsDaeSolver
from .solvers.scikits_ode_solver import ScikitsOdeSolver, have_scikits_odes
from .solvers.scipy_solver import ScipySolver
from .solvers.solver_utils import NoMemAllocVertcat

from .solvers.jax_solver import JaxSolver
from .solvers.jax_bdf_solver import jax_bdf_integrate

from .solvers.idaklu_solver import IDAKLUSolver, have_idaklu


#
# Experiments
#
Expand Down
2 changes: 1 addition & 1 deletion pybamm/solvers/base_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -1168,7 +1168,7 @@ def step(
if not isinstance(old_solution, pybamm.EmptySolution):
if old_solution.all_models[-1] == model:
# initialize with old solution
model.y0 = old_solution.all_ys[-1][:, -1]
model.y0 = old_solution.y_last
else:
_, concatenated_initial_conditions = model.set_initial_conditions_from(
old_solution, return_type="ics"
Expand Down
70 changes: 38 additions & 32 deletions pybamm/solvers/casadi_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ def _integrate(self, model, t_eval, inputs_dict=None):
# update time
t = t_window[-1]
# update y0
y0 = solution.all_ys[-1][:, -1]
y0 = solution.y_last

# now we extract sensitivities from the solution
if bool(model.calculate_sensitivities):
Expand All @@ -275,14 +275,16 @@ def _solve_for_event(self, coarse_solution):
inputs = casadi.vertcat(*[x for x in inputs_dict.values()])

def find_t_event(sol, typ):
sol_t = sol.all_ts[0]
sol_y = sol.all_ys[0]

# Check most recent y to see if any events have been crossed
if model.terminate_events_eval:
y_last = sol.all_ys[-1][:, -1]
y_last = sol.y_last
crossed_events = np.sign(
np.concatenate(
[
event(sol.t[-1], y_last, inputs)
event(sol_t[-1], y_last, inputs)
for event in model.terminate_events_eval
]
)
Expand Down Expand Up @@ -315,13 +317,16 @@ def f(idx):
# We take away 1e-5 to deal with the case where the event sits
# exactly on zero, as can happen when the event switch is used
# (fast with events mode)
f_eval[idx] = event(sol.t[idx], sol.y[:, idx], inputs) - 1e-5
f_eval[idx] = (
init_event_sign * event(sol_t[idx], sol_y[:, idx], inputs)
- 1e-5
)
return f_eval[idx]

def integer_bisect():
a_n = 0
b_n = len(sol.t) - 1
for _ in range(len(sol.t)):
b_n = len(sol_t) - 1
for _ in range(len(sol_t)):
if a_n + 1 == b_n:
return a_n
m_n = (a_n + b_n) // 2
Expand All @@ -343,8 +348,8 @@ def integer_bisect():
# Linear interpolation between the two indices to find the root time
# We could do cubic interpolation here instead but it would be
# slower
t_lower = sol.t[event_idx_lower]
t_upper = sol.t[event_idx_lower + 1]
t_lower = sol_t[event_idx_lower]
t_upper = sol_t[event_idx_lower + 1]
event_lower = abs(f(event_idx_lower))
event_upper = abs(f(event_idx_lower + 1))

Expand All @@ -360,7 +365,11 @@ def integer_bisect():
t_event = np.nanmin(t_events)
# create interpolant to evaluate y in the current integration
# window
y_sol = interp1d(sol.t, sol.y, kind="linear")
sol_event_idx = np.where(t_event < sol_t)[0][0]
sol_t_interp = sol_t[sol_event_idx - 1 : sol_event_idx + 1]
sol_y_interp = sol_y[:, sol_event_idx - 1 : sol_event_idx + 1]
sol_y_interp = casadi.horzcat(sol_y_interp[:, 0], sol_y_interp[:, 1])
y_sol = interp1d(sol_t_interp, sol_y_interp, kind="linear")
y_event = y_sol(t_event)

closest_event_idx = event_idx[np.nanargmin(t_events)]
Expand Down Expand Up @@ -393,7 +402,7 @@ def integer_bisect():
self.create_integrator(model, inputs, t_window_event_dense)
use_grid = True

y0 = coarse_solution.y[:, event_idx_lower]
y0 = coarse_solution.all_ys[0][:, event_idx_lower]
dense_step_sol = self._run_integrator(
model,
y0,
Expand All @@ -418,8 +427,8 @@ def integer_bisect():

# Return solution truncated at the first coarse event time
# Also assign t_event
t_sol = coarse_solution.t[: event_idx_lower + 1]
y_sol = coarse_solution.y[:, : event_idx_lower + 1]
t_sol = coarse_solution.all_ts[0][: event_idx_lower + 1]
y_sol = coarse_solution.all_ys[0][:, : event_idx_lower + 1]
solution = pybamm.Solution(
t_sol,
y_sol,
Expand Down Expand Up @@ -627,7 +636,9 @@ def _run_integrator(
else:
integrator = self.integrators[model]["no grid"]

len_rhs = model.concatenated_rhs.size
len_rhs = model.len_rhs
len_alg = model.len_alg
len_t = len(t_eval)

# Check y0 to see if it includes sensitivities
if explicit_sensitivities:
Expand All @@ -652,24 +663,19 @@ def _run_integrator(
)
pybamm.logger.debug("Finished casadi integrator")
integration_time = timer.time()
y_sol = casadi.vertcat(casadi_sol["xf"], casadi_sol["zf"])
sol = pybamm.Solution(
t_eval,
y_sol,
model,
inputs_dict,
sensitivities=extract_sensitivities_in_solution,
check_solution=False,
)
sol.integration_time = integration_time
return sol
x_sols = casadi.horzsplit(casadi_sol["xf"])
if casadi_sol["zf"].is_empty():
y_sols = x_sols
else:
z_sols = casadi.horzsplit(casadi_sol["zf"])
y_sols = pybamm.NoMemAllocVertcat(x_sols, z_sols, len_rhs, len_alg)
else:
# Repeated calls to the integrator
x = y0_diff
z = y0_alg
y_diff = x
y_alg = z
for i in range(len(t_eval) - 1):
x_sols = [x]
z_sols = [z]
for i in range(len_t - 1):
t_min = t_eval[i]
t_max = t_eval[i + 1]
inputs_with_tlims = casadi.vertcat(inputs, t_min, t_max)
Expand All @@ -680,17 +686,17 @@ def _run_integrator(
integration_time = timer.time()
x = casadi_sol["xf"]
z = casadi_sol["zf"]
y_diff = casadi.horzcat(y_diff, x)
x_sols.append(x)
if not z.is_empty():
y_alg = casadi.horzcat(y_alg, z)
z_sols.append(z)
if z.is_empty():
y_sol = y_diff
y_sols = x_sols
else:
y_sol = casadi.vertcat(y_diff, y_alg)
y_sols = pybamm.NoMemAllocVertcat(x_sols, z_sols, len_rhs, len_alg)

sol = pybamm.Solution(
t_eval,
y_sol,
y_sols,
model,
inputs_dict,
sensitivities=extract_sensitivities_in_solution,
Expand Down
25 changes: 15 additions & 10 deletions pybamm/solvers/processed_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,14 @@ def __init__(self, base_variables, base_variables_casadi, solution, warn=True):

self.all_ts = solution.all_ts
self.all_ys = solution.all_ys
self.y_last = solution.y_last
self.all_inputs = solution.all_inputs
self.all_inputs_casadi = solution.all_inputs_casadi

# utilities for no_mem_alloc
self.uses_no_mem_alloc = isinstance(self.all_ys[0], pybamm.NoMemAllocVertcat)
self.out = casadi.DM.zeros(self.all_ys[0].shape)

self.mesh = base_variables[0].mesh
self.domain = base_variables[0].domain
self.domains = base_variables[0].domains
Expand All @@ -58,10 +63,10 @@ def __init__(self, base_variables, base_variables_casadi, solution, warn=True):
# Store length scales
self.length_scales = solution.length_scales_eval

# Evaluate base variable at initial time
# Evaluate base variable at final time
self.base_eval = self.base_variables_casadi[0](
self.all_ts[0][0], self.all_ys[0][:, 0], self.all_inputs_casadi[0]
).full()
self.all_ts[-1][-1], self.y_last, self.all_inputs_casadi[-1]
).toarray()

# handle 2D (in space) finite element variables differently
if (
Expand Down Expand Up @@ -113,7 +118,7 @@ def initialise_0D(self):
for inner_idx, t in enumerate(ts):
t = ts[inner_idx]
y = ys[:, inner_idx]
entries[idx] = base_var_casadi(t, y, inputs).full()[0, 0]
entries[idx] = float(base_var_casadi(t, y, inputs))
idx += 1

# set up interpolation
Expand Down Expand Up @@ -145,7 +150,7 @@ def initialise_1D(self, fixed_t=False):
for inner_idx, t in enumerate(ts):
t = ts[inner_idx]
y = ys[:, inner_idx]
entries[:, idx] = base_var_casadi(t, y, inputs).full()[:, 0]
entries[:, idx] = np.array(base_var_casadi(t, y, inputs).elements())
idx += 1

# Get node and edge values
Expand Down Expand Up @@ -244,7 +249,7 @@ def initialise_2D(self):
t = ts[inner_idx]
y = ys[:, inner_idx]
entries[:, :, idx] = np.reshape(
base_var_casadi(t, y, inputs).full(),
np.array(base_var_casadi(t, y, inputs).elements()),
[first_dim_size, second_dim_size],
order="F",
)
Expand Down Expand Up @@ -395,7 +400,7 @@ def initialise_2D_scikit_fem(self):
t = ts[inner_idx]
y = ys[:, inner_idx]
entries[:, :, idx] = np.reshape(
base_var_casadi(t, y, inputs).full(),
base_var_casadi(t, y, inputs).toarray(),
[len_y, len_z],
order="C",
)
Expand Down Expand Up @@ -556,11 +561,11 @@ def initialise_sensitivity_explicit_forward(self):
"dvar_dp", [t_casadi, y_casadi, p_casadi_stacked], [dvar_dp]
)
for index, (ts, ys) in enumerate(zip(self.all_ts, self.all_ys)):
for idx, t in enumerate(ts):
u = ys[:, idx]
for inner_idx, t in enumerate(ts):
u = ys[:, inner_idx]
next_dvar_dy_eval = dvar_dy_func(t, u, inputs_stacked)
next_dvar_dp_eval = dvar_dp_func(t, u, inputs_stacked)
if index == 0 and idx == 0:
if index == 0 and inner_idx == 0:
dvar_dy_eval = next_dvar_dy_eval
dvar_dp_eval = next_dvar_dp_eval
else:
Expand Down
26 changes: 22 additions & 4 deletions pybamm/solvers/solution.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ def __init__(
all_ys = [all_ys]
if not isinstance(all_models, list):
all_models = [all_models]

self._all_ts = all_ts
self._all_ys = all_ys
self._all_ys_and_sens = all_ys
Expand Down Expand Up @@ -317,8 +318,7 @@ def check_ys_are_not_too_large(self):
# Only check last one so that it doesn't take too long
# We only care about the cases where y is growing too large without any
# restraint, so if y gets large in the middle then comes back down that is ok
y, model = self.all_ys[-1], self.all_models[-1]
y = y[:, -1]
y, model = self.y_last, self.all_models[-1]
if np.any(y > pybamm.settings.max_y_value):
for var in [*model.rhs.keys(), *model.algebraic.keys()]:
y_var = y[model.variables[var.name].y_slices[0]]
Expand Down Expand Up @@ -354,6 +354,20 @@ def all_inputs_casadi(self):
]
return self._all_inputs_casadi

@property
def y_last(self):
try:
return self._y_last
except AttributeError:
all_ys_last = self.all_ys[-1]
if isinstance(all_ys_last, pybamm.NoMemAllocVertcat):
self._y_last = all_ys_last[:, -1]
elif all_ys_last.shape[1] == 1:
self._y_last = all_ys_last
else:
self._y_last = all_ys_last[:, -1]
return self._y_last

@property
def t_event(self):
"""Time at which the event happens"""
Expand Down Expand Up @@ -705,16 +719,19 @@ def __add__(self, other):
else:
all_ts = self.all_ts + other.all_ts
all_ys = self.all_ys + other.all_ys
all_models = self.all_models + other.all_models
all_inputs = self.all_inputs + other.all_inputs

new_sol = Solution(
all_ts,
all_ys,
self.all_models + other.all_models,
self.all_inputs + other.all_inputs,
all_models,
all_inputs,
other.t_event,
other.y_event,
other.termination,
bool(self.sensitivities),
check_solution=False,
)

new_sol.closest_event_idx = other.closest_event_idx
Expand All @@ -741,6 +758,7 @@ def copy(self):
self.t_event,
self.y_event,
self.termination,
check_solution=False,
)
new_sol._all_inputs_casadi = self.all_inputs_casadi
new_sol._sub_solutions = self.sub_solutions
Expand Down
41 changes: 41 additions & 0 deletions pybamm/solvers/solver_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
#
# Utility functions and classes for solvers
#
import casadi


class NoMemAllocVertcat:
"""
Acts like a vertcat, but does not allocate new memory.
"""

def __init__(self, xs, zs, len_x=None, len_z=None, items=None):
self.xs = xs
self.zs = zs
self.len_x = len_x or xs[0].shape[0]
self.len_z = len_z or zs[0].shape[0]
len_items = len(xs)
self.shape = (self.len_x + self.len_z, len_items)

if items is None:
items = [None] * len_items
for idx in range(len_items):
out = casadi.DM.zeros((self.shape[0], 1))
out[: self.len_x] = self.xs[idx]
out[self.len_x :] = self.zs[idx]
items[idx] = out

self.items = items

def __getitem__(self, idx):
if idx[0] != slice(None):
raise NotImplementedError(
"Only full slices are supported in the first entry of the index"
)
idx = idx[1]
if isinstance(idx, slice):
return NoMemAllocVertcat(
self.xs[idx], self.zs[idx], self.len_x, self.len_z, self.items[idx]
)
else:
return self.items[idx]
2 changes: 1 addition & 1 deletion tests/integration/test_models/standard_model_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def test_solving(
if Crate == 0:
Crate = 1
if t_eval is None:
t_eval = np.linspace(0, 3600 / Crate, 100)
t_eval = np.linspace(0, 3600 / Crate, 1000)

self.solution = self.solver.solve(
self.model,
Expand Down
Loading