Skip to content

Commit

Permalink
#4776 handle current collector case
Browse files Browse the repository at this point in the history
  • Loading branch information
rtimms committed Jan 27, 2025
1 parent 122ca4c commit 05c984e
Show file tree
Hide file tree
Showing 3 changed files with 127 additions and 48 deletions.
75 changes: 35 additions & 40 deletions src/pybamm/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -859,53 +859,48 @@ def set_initial_conditions_from(self, solution, inplace=True, return_type="model
initial_conditions = {}
if isinstance(solution, pybamm.Solution):
solution = solution.last_state

def get_final_state_eval(final_state):
if isinstance(solution, pybamm.Solution):
final_state = final_state.data

if final_state.ndim == 0:
return np.array([final_state])
elif final_state.ndim == 1:
return final_state[-1:]
elif final_state.ndim == 2:
return final_state[:, -1]
elif final_state.ndim == 3:
return final_state[:, :, -1].flatten(order="F")
elif final_state.ndim == 4:
return final_state[:, :, :, -1].flatten(order="F")
else:
raise NotImplementedError("Variable must be 0D, 1D, 2D, or 3D")

def get_variable_state(var_name):
try:
return solution[var_name]
except KeyError as e:
raise pybamm.ModelError(
"To update a model from a solution, each variable in "
"model.initial_conditions must appear in the solution with "
"the same key as the variable name. In the solution provided, "
f"'{e.args[0]}' was not found."
) from e

for var in self.initial_conditions:
if isinstance(var, pybamm.Variable):
try:
final_state = solution[var.name]
except KeyError as e:
raise pybamm.ModelError(
"To update a model from a solution, each variable in "
"model.initial_conditions must appear in the solution with "
"the same key as the variable name. In the solution provided, "
f"'{e.args[0]}' was not found."
) from e
if isinstance(solution, pybamm.Solution):
final_state = final_state.data
if final_state.ndim == 0:
final_state_eval = np.array([final_state])
elif final_state.ndim == 1:
final_state_eval = final_state[-1:]
elif final_state.ndim == 2:
final_state_eval = final_state[:, -1]
elif final_state.ndim == 3:
final_state_eval = final_state[:, :, -1].flatten(order="F")
elif final_state.ndim == 4:
final_state_eval = final_state[:, :, :, -1].flatten(order="F")
else:
raise NotImplementedError("Variable must be 0D, 1D, 2D, or 3D")
final_state = get_variable_state(var.name)
final_state_eval = get_final_state_eval(final_state)

elif isinstance(var, pybamm.Concatenation):
children = []
for child in var.orphans:
try:
final_state = solution[child.name]
except KeyError as e:
raise pybamm.ModelError(
"To update a model from a solution, each variable in "
"model.initial_conditions must appear in the solution with "
"the same key as the variable name. In the solution "
f"provided, {e.args[0]}"
) from e
if isinstance(solution, pybamm.Solution):
final_state = final_state.data
if final_state.ndim == 2:
final_state_eval = final_state[:, -1]
else:
raise NotImplementedError(
"Variable in concatenation must be 1D"
)
final_state = get_variable_state(child.name)
final_state_eval = get_final_state_eval(final_state)
children.append(final_state_eval)
final_state_eval = np.concatenate(children)

else:
raise NotImplementedError(
"Variable must have type 'Variable' or 'Concatenation'"
Expand Down
91 changes: 83 additions & 8 deletions src/pybamm/solvers/processed_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -920,8 +920,10 @@ def _observe_raw_python(self):
Initialise a 3D object that depends on x, y, and z or x, r, and R.
"""
pybamm.logger.debug("Observing the variable raw data in Python")
first_dim_size, second_dim_size, t_size = self._shape(self.t_pts)
entries = np.empty((first_dim_size, second_dim_size, t_size))
first_dim_size, second_dim_size, third_dim_size, t_size = self._shape(
self.t_pts
)
entries = np.empty((first_dim_size, second_dim_size, third_dim_size, t_size))

# Evaluate the base_variable index-by-index
idx = 0
Expand All @@ -931,9 +933,9 @@ def _observe_raw_python(self):
for inner_idx, t in enumerate(ts):
t = ts[inner_idx]
y = ys[:, inner_idx]
entries[:, :, idx] = np.reshape(
entries[:, :, :, idx] = np.reshape(
base_var_casadi(t, y, inputs).full(),
[first_dim_size, second_dim_size],
[first_dim_size, second_dim_size, third_dim_size],
order="F",
)
idx += 1
Expand Down Expand Up @@ -1075,6 +1077,82 @@ def _shape(self, t):
return [first_dim_size, second_dim_size, third_dim_size, t_size]


class ProcessedVariable3DSciKitFEM(ProcessedVariable3D):
"""
An object that can be evaluated at arbitrary (scalars or vectors) t and x, and
returns the (interpolated) value of the base variable at that t and x.
Parameters
----------
base_variables : list of :class:`pybamm.Symbol`
A list of base variables with a method `evaluate(t,y)`, each entry of which
returns the value of that variable for that particular sub-solution.
A Solution can be comprised of sub-solutions which are the solutions of
different models.
Note that this can be any kind of node in the expression tree, not
just a :class:`pybamm.Variable`.
When evaluated, returns an array of size (m,n)
base_variables_casadi : list of :class:`casadi.Function`
A list of casadi functions. When evaluated, returns the same thing as
`base_Variable.evaluate` (but more efficiently).
solution : :class:`pybamm.Solution`
The solution object to be used to create the processed variables
"""

def __init__(
self,
base_variables,
base_variables_casadi,
solution,
time_integral: Optional[pybamm.ProcessedVariableTimeIntegral] = None,
):
self.dimensions = 3
super(ProcessedVariable3D, self).__init__(
base_variables,
base_variables_casadi,
solution,
time_integral=time_integral,
)
x_nodes = self.mesh.nodes
x_edges = self.mesh.edges
y_sol = self.base_variables[0].secondary_mesh.edges["y"]
z_sol = self.base_variables[0].secondary_mesh.edges["z"]
if self.base_eval_size // (len(y_sol) * len(z_sol)) == len(x_nodes):
x_sol = x_nodes
elif self.base_eval_size // (len(y_sol) * len(z_sol)) == len(x_edges):
x_sol = x_edges

self.first_dim_size = len(x_sol)
self.second_dim_size = len(y_sol)
self.third_dim_size = len(z_sol)

def _interp_setup(self, entries, t):
x_nodes = self.mesh.nodes
x_edges = self.mesh.edges
y_sol = self.base_variables[0].secondary_mesh.edges["y"]
z_sol = self.base_variables[0].secondary_mesh.edges["z"]
if self.base_eval_size // (len(y_sol) * len(z_sol)) == len(x_nodes):
x_sol = x_nodes
elif self.base_eval_size // (len(y_sol) * len(z_sol)) == len(x_edges):
x_sol = x_edges

# assign attributes for reference
self.x_sol = x_sol
self.y_sol = y_sol
self.z_sol = z_sol
self.first_dimension = "x"
self.second_dimension = "y"
self.third_dimension = "z"
self.first_dim_pts = x_sol
self.second_dim_pts = y_sol
self.third_dim_pts = z_sol

# save attributes for interpolation
coords_for_interp = {"x": x_sol, "y": y_sol, "z": z_sol, "t": t}

return entries, coords_for_interp


def process_variable(base_variables, *args, **kwargs):
mesh = base_variables[0].mesh
domain = base_variables[0].domain
Expand All @@ -1095,10 +1173,7 @@ def process_variable(base_variables, *args, **kwargs):
and "current collector" in base_variables[0].domains["secondary"]
and isinstance(base_variables[0].secondary_mesh, pybamm.ScikitSubMesh2D)
):
raise NotImplementedError(
"3D variables with secondary domain 'current collector' using the ScikitFEM"
" discretisation are not supported as processed variables"
)
return ProcessedVariable3DSciKitFEM(base_variables, *args, **kwargs)

# check variable shape
if len(base_eval_shape) == 0 or base_eval_shape[0] == 1:
Expand Down
9 changes: 9 additions & 0 deletions src/pybamm/solvers/processed_variable_computed.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,6 +546,15 @@ def initialise_3D(self):
self.r_sol = first_dim_pts
self.R_sol = second_dim_pts
self.x_sol = third_dim_pts
elif self.domain[0].endswith("electrode") and self.domains["secondary"] == [
"current collector"
]:
self.first_dimension = "x"
self.second_dimension = "y"
self.third_dimension = "z"
self.x_sol = first_dim_pts
self.y_sol = second_dim_pts
self.z_sol = third_dim_pts
else: # pragma: no cover
raise pybamm.DomainError(
f"Cannot process 3D object with domains '{self.domains}'."
Expand Down

0 comments on commit 05c984e

Please sign in to comment.