Skip to content

Commit a430116

Browse files
#1996 working pass at nomemalloc
1 parent d98ea9b commit a430116

File tree

5 files changed

+98
-75
lines changed

5 files changed

+98
-75
lines changed

pybamm/solvers/base_solver.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1294,7 +1294,7 @@ def step(
12941294
t = old_solution.all_ts[-1][-1]
12951295
if old_solution.all_models[-1] == model:
12961296
# initialize with old solution
1297-
model.y0 = old_solution.all_ys[-1][:, -1]
1297+
model.y0 = old_solution.y_last
12981298
else:
12991299
model.y0 = model.set_initial_conditions_from(
13001300
old_solution

pybamm/solvers/casadi_solver.py

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,7 @@ def _integrate(self, model, t_eval, inputs_dict=None):
284284
# update time
285285
t = t_window[-1]
286286
# update y0
287-
y0 = solution.all_ys[-1][:, -1]
287+
y0 = solution.y_last
288288

289289
# now we extract sensitivities from the solution
290290
if bool(model.calculate_sensitivities):
@@ -312,7 +312,7 @@ def find_t_event(sol, typ):
312312

313313
# Check most recent y to see if any events have been crossed
314314
if model.terminate_events_eval:
315-
y_last = sol.all_ys[-1][:, -1]
315+
y_last = sol.y_last
316316
crossed_events = np.sign(
317317
init_event_signs
318318
* np.concatenate(
@@ -676,7 +676,8 @@ def _run_integrator(
676676
else:
677677
integrator = self.integrators[model]["no grid"]
678678

679-
len_rhs = model.concatenated_rhs.size
679+
len_rhs = model.len_rhs
680+
len_t = len(t_eval)
680681

681682
# Check y0 to see if it includes sensitivities
682683
if explicit_sensitivities:
@@ -699,17 +700,20 @@ def _run_integrator(
699700
x0=y0_diff, z0=y0_alg, p=inputs_with_tmin, **self.extra_options_call
700701
)
701702
integration_time = timer.time()
703+
x_sols = casadi.horzsplit(casadi_sol["xf"])
702704
if casadi_sol["zf"].is_empty():
703-
y_sol = casadi_sol["xf"]
705+
y_sols = x_sols
704706
else:
705-
y_sol = pybamm.NoMemAllocVertcat(casadi_sol["xf"], casadi_sol["zf"])
707+
z_sols = casadi.horzsplit(casadi_sol["zf"])
708+
y_sols = [
709+
pybamm.NoMemAllocVertcat(x, z) for x, z in zip(x_sols, z_sols)
710+
]
706711
else:
707712
# Repeated calls to the integrator
713+
y_sols = [y0]
708714
x = y0_diff
709715
z = y0_alg
710-
y_diff = x
711-
y_alg = z
712-
for i in range(len(t_eval) - 1):
716+
for i in range(len_t - 1):
713717
t_min = t_eval[i]
714718
t_max = t_eval[i + 1]
715719
inputs_with_tlims = casadi.vertcat(inputs, t_min, t_max)
@@ -720,19 +724,16 @@ def _run_integrator(
720724
integration_time = timer.time()
721725
x = casadi_sol["xf"]
722726
z = casadi_sol["zf"]
723-
y_diff = casadi.horzcat(y_diff, x)
724-
if not z.is_empty():
725-
y_alg = casadi.horzcat(y_alg, z)
726-
if z.is_empty():
727-
y_sol = y_diff
728-
else:
729-
y_sol = pybamm.NoMemAllocVertcat(y_diff, y_alg)
727+
if z.is_empty():
728+
y_sols.append(x)
729+
else:
730+
y_sols.append(pybamm.NoMemAllocVertcat(x, z))
730731

731732
sol = pybamm.Solution(
732-
t_eval,
733-
y_sol,
734-
model,
735-
inputs_dict,
733+
np.array_split(t_eval, len_t),
734+
y_sols,
735+
[model] * len_t,
736+
[inputs_dict] * len_t,
736737
sensitivities=extract_sensitivities_in_solution,
737738
check_solution=False,
738739
)

pybamm/solvers/processed_variable.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,14 @@ def __init__(self, base_variables, base_variables_casadi, solution, warn=True):
3939

4040
self.all_ts = solution.all_ts
4141
self.all_ys = solution.all_ys
42+
self.y_last = solution.y_last
4243
self.all_inputs = solution.all_inputs
4344
self.all_inputs_casadi = solution.all_inputs_casadi
4445

46+
# utilities for no_mem_alloc
47+
self.uses_no_mem_alloc = isinstance(self.all_ys[0], pybamm.NoMemAllocVertcat)
48+
self.out = casadi.DM.zeros(self.all_ys[0].shape)
49+
4550
self.mesh = base_variables[0].mesh
4651
self.domain = base_variables[0].domain
4752
self.domains = base_variables[0].domains
@@ -60,9 +65,9 @@ def __init__(self, base_variables, base_variables_casadi, solution, warn=True):
6065
# Store length scales
6166
self.length_scales = solution.length_scales_eval
6267

63-
# Evaluate base variable at initial time
68+
# Evaluate base variable at final time
6469
self.base_eval = self.base_variables_casadi[0](
65-
self.all_ts[0][0], self.all_ys[0][:, 0], self.all_inputs_casadi[0]
70+
self.all_ts[-1][-1], self.y_last, self.all_inputs_casadi[-1]
6671
).full()
6772

6873
# handle 2D (in space) finite element variables differently
@@ -114,7 +119,7 @@ def initialise_0D(self):
114119
):
115120
for inner_idx, t in enumerate(ts):
116121
t = ts[inner_idx]
117-
y = ys[:, inner_idx]
122+
y = self.getitem(ys, idx, self.out)
118123
entries[idx] = base_var_casadi(t, y, inputs).full()[0, 0]
119124
idx += 1
120125

@@ -146,7 +151,7 @@ def initialise_1D(self, fixed_t=False):
146151
):
147152
for inner_idx, t in enumerate(ts):
148153
t = ts[inner_idx]
149-
y = ys[:, inner_idx]
154+
y = self.getitem(ys, idx, self.out)
150155
entries[:, idx] = base_var_casadi(t, y, inputs).full()[:, 0]
151156
idx += 1
152157

@@ -247,7 +252,7 @@ def initialise_2D(self):
247252
):
248253
for inner_idx, t in enumerate(ts):
249254
t = ts[inner_idx]
250-
y = ys[:, inner_idx]
255+
y = self.getitem(ys, idx, self.out)
251256
entries[:, :, idx] = np.reshape(
252257
base_var_casadi(t, y, inputs).full(),
253258
[first_dim_size, second_dim_size],
@@ -419,7 +424,7 @@ def initialise_2D_scikit_fem(self):
419424
):
420425
for inner_idx, t in enumerate(ts):
421426
t = ts[inner_idx]
422-
y = ys[:, inner_idx]
427+
y = self.getitem(ys, idx, self.out)
423428
entries[:, :, idx] = np.reshape(
424429
base_var_casadi(t, y, inputs).full(),
425430
[len_y, len_z],
@@ -454,6 +459,13 @@ def initialise_2D_scikit_fem(self):
454459
bounds_error=False,
455460
)
456461

462+
def getitem(self, ys, idx, out):
463+
if self.uses_no_mem_alloc:
464+
ys.get_value(out)
465+
return out
466+
else:
467+
return ys[:, idx]
468+
457469
def __call__(self, t=None, x=None, r=None, y=None, z=None, R=None, warn=True):
458470
"""
459471
Evaluate the variable at arbitrary *dimensional* t (and x, r, y, z and/or R),

pybamm/solvers/solution.py

Lines changed: 49 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ def __init__(
8181
all_ys = [all_ys]
8282
if not isinstance(all_models, list):
8383
all_models = [all_models]
84+
8485
self._all_ts = all_ts
8586
self._all_ys = all_ys
8687
self._all_ys_and_sens = all_ys
@@ -96,6 +97,25 @@ def __init__(
9697
else:
9798
self.all_inputs = all_inputs
9899

100+
if not (
101+
len(self.all_ts) == len(self.all_ys)
102+
or len(self.all_ts) == 1
103+
or len(self.all_ys) == 1
104+
):
105+
raise ValueError("all_ts and all_ys must be the same length")
106+
if not (
107+
len(self.all_ts) == len(self.all_models)
108+
or len(self.all_ts) == 1
109+
or len(self.all_models) == 1
110+
):
111+
raise ValueError("all_ts and all_models must be the same length")
112+
if not (
113+
len(self.all_ts) == len(self.all_inputs)
114+
or len(self.all_ts) == 1
115+
or len(self.all_inputs) == 1
116+
):
117+
raise ValueError("all_ts and all_inputs must be the same length")
118+
99119
self.sensitivities = sensitivities
100120

101121
self._t_event = t_event
@@ -106,7 +126,6 @@ def __init__(
106126
isinstance(v, casadi.MX) for v in self.all_inputs[0].values()
107127
)
108128

109-
# Check no ys are too large
110129
if check_solution and not self.has_symbolic_inputs:
111130
self.check_ys_are_not_too_large()
112131

@@ -321,8 +340,7 @@ def check_ys_are_not_too_large(self):
321340
# Only check last one so that it doesn't take too long
322341
# We only care about the cases where y is growing too large without any
323342
# restraint, so if y gets large in the middle then comes back down that is ok
324-
y, model = self.all_ys[-1], self.all_models[-1]
325-
y = y[:, -1]
343+
y, model = self.y_last, self.all_models[-1]
326344
if np.any(y > pybamm.settings.max_y_value):
327345
for var in [*model.rhs.keys(), *model.algebraic.keys()]:
328346
y_var = y[model.variables[var.name].y_slices[0]]
@@ -358,6 +376,20 @@ def all_inputs_casadi(self):
358376
]
359377
return self._all_inputs_casadi
360378

379+
@property
380+
def y_last(self):
381+
try:
382+
return self._y_last
383+
except AttributeError:
384+
all_ys_last = self.all_ys[-1]
385+
if isinstance(all_ys_last, pybamm.NoMemAllocVertcat):
386+
self._y_last = all_ys_last.get_value()
387+
elif all_ys_last.shape[1] == 1:
388+
self._y_last = all_ys_last
389+
else:
390+
self._y_last = all_ys_last[:, -1]
391+
return self._y_last
392+
361393
@property
362394
def t_event(self):
363395
"""Time at which the event happens"""
@@ -719,21 +751,30 @@ def __add__(self, other):
719751
# Update list of sub-solutions
720752
if other.all_ts[0][0] == self.all_ts[-1][-1]:
721753
# Skip first time step if it is repeated
722-
all_ts = self.all_ts + [other.all_ts[0][1:]] + other.all_ts[1:]
723-
all_ys = self.all_ys + [other.all_ys[0][:, 1:]] + other.all_ys[1:]
754+
if len(other.all_ts[0]) == 1:
755+
all_ts = self.all_ts + other.all_ts[1:]
756+
all_ys = self.all_ys + other.all_ys[1:]
757+
all_models = self.all_models + other.all_models[1:]
758+
all_inputs = self.all_inputs + other.all_inputs[1:]
759+
else:
760+
all_ts = self.all_ts + [other.all_ts[0][1:]] + other.all_ts[1:]
761+
all_ys = self.all_ys + [other.all_ys[0][:, 1:]] + other.all_ys[1:]
762+
all_models = (self.all_models + other.all_models,)
763+
all_inputs = (self.all_inputs + other.all_inputs,)
724764
else:
725765
all_ts = self.all_ts + other.all_ts
726766
all_ys = self.all_ys + other.all_ys
727767

728768
new_sol = Solution(
729769
all_ts,
730770
all_ys,
731-
self.all_models + other.all_models,
732-
self.all_inputs + other.all_inputs,
771+
all_models,
772+
all_inputs,
733773
other.t_event,
734774
other.y_event,
735775
other.termination,
736776
bool(self.sensitivities),
777+
check_solution=False,
737778
)
738779

739780
new_sol.closest_event_idx = other.closest_event_idx
@@ -769,6 +810,7 @@ def copy(self):
769810
self.t_event,
770811
self.y_event,
771812
self.termination,
813+
check_solution=False,
772814
)
773815
new_sol._all_inputs_casadi = self.all_inputs_casadi
774816
new_sol._sub_solutions = self.sub_solutions

pybamm/solvers/solver_utils.py

Lines changed: 10 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#
22
# Utility functions and classes for solvers
33
#
4+
import numpy as np
45

56

67
class NoMemAllocVertcat:
@@ -11,51 +12,18 @@ class NoMemAllocVertcat:
1112
def __init__(self, a, b):
1213
arrays = [a, b]
1314
self.arrays = arrays
14-
15-
for array in arrays:
16-
if not 1 <= len(array.shape) <= 2:
17-
raise ValueError("Only 1D or 2D arrays are supported")
18-
self._ndim = len(array.shape)
19-
2015
self.len_a = a.shape[0]
21-
shape0 = a.shape[0] + b.shape[0]
22-
23-
if self._ndim == 1:
24-
self._shape = (shape0,)
25-
self._size = shape0
26-
else:
27-
if a.shape[1] != b.shape[1]:
28-
raise ValueError("All arrays must have the same number of columns")
29-
shape1 = a.shape[1]
30-
31-
self._shape = (shape0, shape1)
32-
self._size = shape0 * shape1
16+
self.len_b = b.shape[0]
17+
self.len = self.len_a + self.len_b
18+
self._shape = (self.len, 1)
3319

3420
@property
3521
def shape(self):
3622
return self._shape
3723

38-
@property
39-
def size(self):
40-
return self._size
41-
42-
@property
43-
def ndim(self):
44-
return self._ndim
45-
46-
def __getitem__(self, key):
47-
if self._ndim == 1 or isinstance(key, int):
48-
if key < self.len_a:
49-
return self.arrays[0][key]
50-
else:
51-
return self.arrays[1][key - self.len_a]
52-
53-
if key[0] == slice(None):
54-
return NoMemAllocVertcat(*[arr[:, key[1]] for arr in self.arrays])
55-
elif isinstance(key[0], int):
56-
if key[0] < self.len_a:
57-
return self.arrays[0][key[0], key[1]]
58-
else:
59-
return self.arrays[1][key[0] - self.len_a, key[1]]
60-
else:
61-
raise NotImplementedError
24+
def get_value(self, out=None):
25+
if out is None:
26+
out = np.empty((self.len, 1))
27+
out[: self.len_a] = self.arrays[0]
28+
out[self.len_a :] = self.arrays[1]
29+
return out

0 commit comments

Comments
 (0)