Skip to content

Commit a24feda

Browse files
#1996 working on nomemalloc
1 parent a430116 commit a24feda

File tree

4 files changed

+72
-85
lines changed

4 files changed

+72
-85
lines changed

pybamm/solvers/casadi_solver.py

Lines changed: 26 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -309,6 +309,8 @@ def _solve_for_event(self, coarse_solution, init_event_signs):
309309
inputs = casadi.vertcat(*[x for x in inputs_dict.values()])
310310

311311
def find_t_event(sol, typ):
312+
sol_t = sol.all_ts[0]
313+
sol_y = sol.all_ys[0]
312314

313315
# Check most recent y to see if any events have been crossed
314316
if model.terminate_events_eval:
@@ -317,7 +319,7 @@ def find_t_event(sol, typ):
317319
init_event_signs
318320
* np.concatenate(
319321
[
320-
event(sol.t[-1], y_last, inputs)
322+
event(sol_t[-1], y_last, inputs)
321323
for event in model.terminate_events_eval
322324
]
323325
)
@@ -353,15 +355,15 @@ def f(idx):
353355
# exactly on zero, as can happen when the event switch is used
354356
# (fast with events mode)
355357
f_eval[idx] = (
356-
init_event_sign * event(sol.t[idx], sol.y[:, idx], inputs)
358+
init_event_sign * event(sol_t[idx], sol_y[:, idx], inputs)
357359
- 1e-5
358360
)
359361
return f_eval[idx]
360362

361363
def integer_bisect():
362364
a_n = 0
363-
b_n = len(sol.t) - 1
364-
for _ in range(len(sol.t)):
365+
b_n = len(sol_t) - 1
366+
for _ in range(len(sol_t)):
365367
if a_n + 1 == b_n:
366368
return a_n
367369
m_n = (a_n + b_n) // 2
@@ -383,8 +385,8 @@ def integer_bisect():
383385
# Linear interpolation between the two indices to find the root time
384386
# We could do cubic interpolation here instead but it would be
385387
# slower
386-
t_lower = sol.t[event_idx_lower]
387-
t_upper = sol.t[event_idx_lower + 1]
388+
t_lower = sol_t[event_idx_lower]
389+
t_upper = sol_t[event_idx_lower + 1]
388390
event_lower = abs(f(event_idx_lower))
389391
event_upper = abs(f(event_idx_lower + 1))
390392

@@ -400,7 +402,7 @@ def integer_bisect():
400402
t_event = np.nanmin(t_events)
401403
# create interpolant to evaluate y in the current integration
402404
# window
403-
y_sol = interp1d(sol.t, sol.y, kind="linear")
405+
y_sol = interp1d(sol_t, sol_y, kind="linear")
404406
y_event = y_sol(t_event)
405407

406408
closest_event_idx = event_idx[np.nanargmin(t_events)]
@@ -433,7 +435,7 @@ def integer_bisect():
433435
self.create_integrator(model, inputs, t_window_event_dense)
434436
use_grid = True
435437

436-
y0 = coarse_solution.y[:, event_idx_lower]
438+
y0 = coarse_solution.all_ys[0][:, event_idx_lower]
437439
dense_step_sol = self._run_integrator(
438440
model,
439441
y0,
@@ -458,8 +460,8 @@ def integer_bisect():
458460

459461
# Return solution truncated at the first coarse event time
460462
# Also assign t_event
461-
t_sol = coarse_solution.t[: event_idx_lower + 1]
462-
y_sol = coarse_solution.y[:, : event_idx_lower + 1]
463+
t_sol = coarse_solution.all_ts[0][: event_idx_lower + 1]
464+
y_sol = coarse_solution.all_ys[0][:, : event_idx_lower + 1]
463465
solution = pybamm.Solution(
464466
t_sol,
465467
y_sol,
@@ -677,6 +679,7 @@ def _run_integrator(
677679
integrator = self.integrators[model]["no grid"]
678680

679681
len_rhs = model.len_rhs
682+
len_alg = model.len_alg
680683
len_t = len(t_eval)
681684

682685
# Check y0 to see if it includes sensitivities
@@ -705,14 +708,13 @@ def _run_integrator(
705708
y_sols = x_sols
706709
else:
707710
z_sols = casadi.horzsplit(casadi_sol["zf"])
708-
y_sols = [
709-
pybamm.NoMemAllocVertcat(x, z) for x, z in zip(x_sols, z_sols)
710-
]
711+
y_sols = pybamm.NoMemAllocVertcat(x_sols, z_sols, len_rhs, len_alg)
711712
else:
712713
# Repeated calls to the integrator
713-
y_sols = [y0]
714714
x = y0_diff
715715
z = y0_alg
716+
x_sols = [x]
717+
z_sols = [z]
716718
for i in range(len_t - 1):
717719
t_min = t_eval[i]
718720
t_max = t_eval[i + 1]
@@ -724,16 +726,19 @@ def _run_integrator(
724726
integration_time = timer.time()
725727
x = casadi_sol["xf"]
726728
z = casadi_sol["zf"]
727-
if z.is_empty():
728-
y_sols.append(x)
729-
else:
730-
y_sols.append(pybamm.NoMemAllocVertcat(x, z))
729+
x_sols.append(x)
730+
if not z.is_empty():
731+
z_sols.append(z)
732+
if z.is_empty():
733+
y_sols = x_sols
734+
else:
735+
y_sols = pybamm.NoMemAllocVertcat(x_sols, z_sols, len_rhs, len_alg)
731736

732737
sol = pybamm.Solution(
733-
np.array_split(t_eval, len_t),
738+
t_eval,
734739
y_sols,
735-
[model] * len_t,
736-
[inputs_dict] * len_t,
740+
model,
741+
inputs_dict,
737742
sensitivities=extract_sensitivities_in_solution,
738743
check_solution=False,
739744
)

pybamm/solvers/processed_variable.py

Lines changed: 10 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -68,7 +68,7 @@ def __init__(self, base_variables, base_variables_casadi, solution, warn=True):
6868
# Evaluate base variable at final time
6969
self.base_eval = self.base_variables_casadi[0](
7070
self.all_ts[-1][-1], self.y_last, self.all_inputs_casadi[-1]
71-
).full()
71+
).toarray()
7272

7373
# handle 2D (in space) finite element variables differently
7474
if (
@@ -119,8 +119,8 @@ def initialise_0D(self):
119119
):
120120
for inner_idx, t in enumerate(ts):
121121
t = ts[inner_idx]
122-
y = self.getitem(ys, idx, self.out)
123-
entries[idx] = base_var_casadi(t, y, inputs).full()[0, 0]
122+
y = ys[:, inner_idx]
123+
entries[idx] = float(base_var_casadi(t, y, inputs))
124124
idx += 1
125125

126126
# set up interpolation
@@ -151,8 +151,8 @@ def initialise_1D(self, fixed_t=False):
151151
):
152152
for inner_idx, t in enumerate(ts):
153153
t = ts[inner_idx]
154-
y = self.getitem(ys, idx, self.out)
155-
entries[:, idx] = base_var_casadi(t, y, inputs).full()[:, 0]
154+
y = ys[:, inner_idx]
155+
entries[:, idx] = np.array(base_var_casadi(t, y, inputs).elements())
156156
idx += 1
157157

158158
# Get node and edge values
@@ -252,9 +252,9 @@ def initialise_2D(self):
252252
):
253253
for inner_idx, t in enumerate(ts):
254254
t = ts[inner_idx]
255-
y = self.getitem(ys, idx, self.out)
255+
y = ys[:, inner_idx]
256256
entries[:, :, idx] = np.reshape(
257-
base_var_casadi(t, y, inputs).full(),
257+
np.array(base_var_casadi(t, y, inputs).elements()),
258258
[first_dim_size, second_dim_size],
259259
order="F",
260260
)
@@ -424,9 +424,9 @@ def initialise_2D_scikit_fem(self):
424424
):
425425
for inner_idx, t in enumerate(ts):
426426
t = ts[inner_idx]
427-
y = self.getitem(ys, idx, self.out)
427+
y = ys[:, inner_idx]
428428
entries[:, :, idx] = np.reshape(
429-
base_var_casadi(t, y, inputs).full(),
429+
base_var_casadi(t, y, inputs).toarray(),
430430
[len_y, len_z],
431431
order="C",
432432
)
@@ -459,13 +459,6 @@ def initialise_2D_scikit_fem(self):
459459
bounds_error=False,
460460
)
461461

462-
def getitem(self, ys, idx, out):
463-
if self.uses_no_mem_alloc:
464-
ys.get_value(out)
465-
return out
466-
else:
467-
return ys[:, idx]
468-
469462
def __call__(self, t=None, x=None, r=None, y=None, z=None, R=None, warn=True):
470463
"""
471464
Evaluate the variable at arbitrary *dimensional* t (and x, r, y, z and/or R),
@@ -595,7 +588,7 @@ def initialise_sensitivity_explicit_forward(self):
595588
)
596589
for index, (ts, ys) in enumerate(zip(self.all_ts, self.all_ys)):
597590
for idx, t in enumerate(ts):
598-
u = ys[:, idx]
591+
u = ys[:, inner_idx]
599592
next_dvar_dy_eval = dvar_dy_func(t, u, inputs_stacked)
600593
next_dvar_dp_eval = dvar_dp_func(t, u, inputs_stacked)
601594
if index == 0 and idx == 0:

pybamm/solvers/solution.py

Lines changed: 5 additions & 30 deletions
Original file line number | Diff line number | Diff line change
@@ -97,25 +97,6 @@ def __init__(
9797
else:
9898
self.all_inputs = all_inputs
9999

100-
if not (
101-
len(self.all_ts) == len(self.all_ys)
102-
or len(self.all_ts) == 1
103-
or len(self.all_ys) == 1
104-
):
105-
raise ValueError("all_ts and all_ys must be the same length")
106-
if not (
107-
len(self.all_ts) == len(self.all_models)
108-
or len(self.all_ts) == 1
109-
or len(self.all_models) == 1
110-
):
111-
raise ValueError("all_ts and all_models must be the same length")
112-
if not (
113-
len(self.all_ts) == len(self.all_inputs)
114-
or len(self.all_ts) == 1
115-
or len(self.all_inputs) == 1
116-
):
117-
raise ValueError("all_ts and all_inputs must be the same length")
118-
119100
self.sensitivities = sensitivities
120101

121102
self._t_event = t_event
@@ -383,7 +364,7 @@ def y_last(self):
383364
except AttributeError:
384365
all_ys_last = self.all_ys[-1]
385366
if isinstance(all_ys_last, pybamm.NoMemAllocVertcat):
386-
self._y_last = all_ys_last.get_value()
367+
self._y_last = all_ys_last[:, -1]
387368
elif all_ys_last.shape[1] == 1:
388369
self._y_last = all_ys_last
389370
else:
@@ -751,19 +732,13 @@ def __add__(self, other):
751732
# Update list of sub-solutions
752733
if other.all_ts[0][0] == self.all_ts[-1][-1]:
753734
# Skip first time step if it is repeated
754-
if len(other.all_ts[0]) == 1:
755-
all_ts = self.all_ts + other.all_ts[1:]
756-
all_ys = self.all_ys + other.all_ys[1:]
757-
all_models = self.all_models + other.all_models[1:]
758-
all_inputs = self.all_inputs + other.all_inputs[1:]
759-
else:
760-
all_ts = self.all_ts + [other.all_ts[0][1:]] + other.all_ts[1:]
761-
all_ys = self.all_ys + [other.all_ys[0][:, 1:]] + other.all_ys[1:]
762-
all_models = (self.all_models + other.all_models,)
763-
all_inputs = (self.all_inputs + other.all_inputs,)
735+
all_ts = self.all_ts + [other.all_ts[0][1:]] + other.all_ts[1:]
736+
all_ys = self.all_ys + [other.all_ys[0][:, 1:]] + other.all_ys[1:]
764737
else:
765738
all_ts = self.all_ts + other.all_ts
766739
all_ys = self.all_ys + other.all_ys
740+
all_models = self.all_models + other.all_models
741+
all_inputs = self.all_inputs + other.all_inputs
767742

768743
new_sol = Solution(
769744
all_ts,

pybamm/solvers/solver_utils.py

Lines changed: 31 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -1,29 +1,43 @@
11
#
22
# Utility functions and classes for solvers
33
#
4-
import numpy as np
4+
import casadi
55

66

77
class NoMemAllocVertcat:
88
"""
99
Acts like a vertcat, but does not allocate new memory.
1010
"""
1111

12-
def __init__(self, a, b):
13-
arrays = [a, b]
14-
self.arrays = arrays
15-
self.len_a = a.shape[0]
16-
self.len_b = b.shape[0]
17-
self.len = self.len_a + self.len_b
18-
self._shape = (self.len, 1)
12+
def __init__(self, xs, ys, len_x=None, len_y=None, items=None):
13+
self.xs = xs
14+
self.ys = ys
15+
self.len_x = len_x or xs[0].shape[0]
16+
self.len_y = len_y or ys[0].shape[0]
17+
len_items = len(xs)
18+
self.shape = (self.len_x + self.len_y, len_items)
1919

20-
@property
21-
def shape(self):
22-
return self._shape
20+
if items is None:
21+
items = [None] * len_items
22+
self.items = items
2323

24-
def get_value(self, out=None):
25-
if out is None:
26-
out = np.empty((self.len, 1))
27-
out[: self.len_a] = self.arrays[0]
28-
out[self.len_a :] = self.arrays[1]
29-
return out
24+
def __getitem__(self, idx):
25+
if idx[0] != slice(None):
26+
raise NotImplementedError(
27+
"Only full slices are supported in the first entry of the index"
28+
)
29+
idx = idx[1]
30+
if isinstance(idx, slice):
31+
return NoMemAllocVertcat(
32+
self.xs[idx], self.ys[idx], self.len_x, self.len_y, self.items[idx]
33+
)
34+
else:
35+
item = self.items[idx]
36+
if item is not None:
37+
return item
38+
else:
39+
out = casadi.DM.zeros((self.shape[0], 1))
40+
out[: self.len_x] = self.xs[idx]
41+
out[self.len_x :] = self.ys[idx]
42+
self.items[idx] = out
43+
return out

0 commit comments

Comments
 (0)