diff --git a/src/pybamm/models/base_model.py b/src/pybamm/models/base_model.py index b34e177302..094a7eef34 100644 --- a/src/pybamm/models/base_model.py +++ b/src/pybamm/models/base_model.py @@ -859,53 +859,48 @@ def set_initial_conditions_from(self, solution, inplace=True, return_type="model initial_conditions = {} if isinstance(solution, pybamm.Solution): solution = solution.last_state + + def get_final_state_eval(final_state): + if isinstance(solution, pybamm.Solution): + final_state = final_state.data + + if final_state.ndim == 0: + return np.array([final_state]) + elif final_state.ndim == 1: + return final_state[-1:] + elif final_state.ndim == 2: + return final_state[:, -1] + elif final_state.ndim == 3: + return final_state[:, :, -1].flatten(order="F") + elif final_state.ndim == 4: + return final_state[:, :, :, -1].flatten(order="F") + else: + raise NotImplementedError("Variable must be 0D, 1D, 2D, or 3D") + + def get_variable_state(var_name): + try: + return solution[var_name] + except KeyError as e: + raise pybamm.ModelError( + "To update a model from a solution, each variable in " + "model.initial_conditions must appear in the solution with " + "the same key as the variable name. In the solution provided, " + f"'{e.args[0]}' was not found." + ) from e + for var in self.initial_conditions: if isinstance(var, pybamm.Variable): - try: - final_state = solution[var.name] - except KeyError as e: - raise pybamm.ModelError( - "To update a model from a solution, each variable in " - "model.initial_conditions must appear in the solution with " - "the same key as the variable name. In the solution provided, " - f"'{e.args[0]}' was not found." - ) from e - if isinstance(solution, pybamm.Solution): - final_state = final_state.data - if final_state.ndim == 0: - final_state_eval = np.array([final_state]) - elif final_state.ndim == 1: - final_state_eval = final_state[-1:] - elif final_state.ndim == 2: - final_state_eval = final_state[:, -1] - elif final_state.ndim == 3: - final_state_eval = final_state[:, :, -1].flatten(order="F") - elif final_state.ndim == 4: - final_state_eval = final_state[:, :, :, -1].flatten(order="F") - else: - raise NotImplementedError("Variable must be 0D, 1D, 2D, or 3D") + final_state = get_variable_state(var.name) + final_state_eval = get_final_state_eval(final_state) + elif isinstance(var, pybamm.Concatenation): children = [] for child in var.orphans: - try: - final_state = solution[child.name] - except KeyError as e: - raise pybamm.ModelError( - "To update a model from a solution, each variable in " - "model.initial_conditions must appear in the solution with " - "the same key as the variable name. In the solution " - f"provided, {e.args[0]}" - ) from e - if isinstance(solution, pybamm.Solution): - final_state = final_state.data - if final_state.ndim == 2: - final_state_eval = final_state[:, -1] - else: - raise NotImplementedError( - "Variable in concatenation must be 1D" - ) + final_state = get_variable_state(child.name) + final_state_eval = get_final_state_eval(final_state) children.append(final_state_eval) final_state_eval = np.concatenate(children) + else: raise NotImplementedError( "Variable must have type 'Variable' or 'Concatenation'" diff --git a/src/pybamm/solvers/processed_variable.py b/src/pybamm/solvers/processed_variable.py index 1e49426734..4757774065 100644 --- a/src/pybamm/solvers/processed_variable.py +++ b/src/pybamm/solvers/processed_variable.py @@ -920,8 +920,10 @@ def _observe_raw_python(self): Initialise a 3D object that depends on x, y, and z or x, r, and R. """ pybamm.logger.debug("Observing the variable raw data in Python") - first_dim_size, second_dim_size, t_size = self._shape(self.t_pts) - entries = np.empty((first_dim_size, second_dim_size, t_size)) + first_dim_size, second_dim_size, third_dim_size, t_size = self._shape( + self.t_pts + ) + entries = np.empty((first_dim_size, second_dim_size, third_dim_size, t_size)) # Evaluate the base_variable index-by-index idx = 0 @@ -931,9 +933,9 @@ def _observe_raw_python(self): for inner_idx, t in enumerate(ts): t = ts[inner_idx] y = ys[:, inner_idx] - entries[:, :, idx] = np.reshape( + entries[:, :, :, idx] = np.reshape( base_var_casadi(t, y, inputs).full(), - [first_dim_size, second_dim_size], + [first_dim_size, second_dim_size, third_dim_size], order="F", ) idx += 1 @@ -1075,6 +1077,82 @@ def _shape(self, t): return [first_dim_size, second_dim_size, third_dim_size, t_size] +class ProcessedVariable3DSciKitFEM(ProcessedVariable3D): + """ + An object that can be evaluated at arbitrary (scalars or vectors) t and x, and + returns the (interpolated) value of the base variable at that t and x. + + Parameters + ---------- + base_variables : list of :class:`pybamm.Symbol` + A list of base variables with a method `evaluate(t,y)`, each entry of which + returns the value of that variable for that particular sub-solution. + A Solution can be comprised of sub-solutions which are the solutions of + different models. + Note that this can be any kind of node in the expression tree, not + just a :class:`pybamm.Variable`. + When evaluated, returns an array of size (m,n) + base_variables_casadi : list of :class:`casadi.Function` + A list of casadi functions. When evaluated, returns the same thing as + `base_Variable.evaluate` (but more efficiently). + solution : :class:`pybamm.Solution` + The solution object to be used to create the processed variables + """ + + def __init__( + self, + base_variables, + base_variables_casadi, + solution, + time_integral: Optional[pybamm.ProcessedVariableTimeIntegral] = None, + ): + self.dimensions = 3 + super(ProcessedVariable3D, self).__init__( + base_variables, + base_variables_casadi, + solution, + time_integral=time_integral, + ) + x_nodes = self.mesh.nodes + x_edges = self.mesh.edges + y_sol = self.base_variables[0].secondary_mesh.edges["y"] + z_sol = self.base_variables[0].secondary_mesh.edges["z"] + if self.base_eval_size // (len(y_sol) * len(z_sol)) == len(x_nodes): + x_sol = x_nodes + elif self.base_eval_size // (len(y_sol) * len(z_sol)) == len(x_edges): + x_sol = x_edges + + self.first_dim_size = len(x_sol) + self.second_dim_size = len(y_sol) + self.third_dim_size = len(z_sol) + + def _interp_setup(self, entries, t): + x_nodes = self.mesh.nodes + x_edges = self.mesh.edges + y_sol = self.base_variables[0].secondary_mesh.edges["y"] + z_sol = self.base_variables[0].secondary_mesh.edges["z"] + if self.base_eval_size // (len(y_sol) * len(z_sol)) == len(x_nodes): + x_sol = x_nodes + elif self.base_eval_size // (len(y_sol) * len(z_sol)) == len(x_edges): + x_sol = x_edges + + # assign attributes for reference + self.x_sol = x_sol + self.y_sol = y_sol + self.z_sol = z_sol + self.first_dimension = "x" + self.second_dimension = "y" + self.third_dimension = "z" + self.first_dim_pts = x_sol + self.second_dim_pts = y_sol + self.third_dim_pts = z_sol + + # save attributes for interpolation + coords_for_interp = {"x": x_sol, "y": y_sol, "z": z_sol, "t": t} + + return entries, coords_for_interp + + def process_variable(base_variables, *args, **kwargs): mesh = base_variables[0].mesh domain = base_variables[0].domain @@ -1095,10 +1173,7 @@ def process_variable(base_variables, *args, **kwargs): and "current collector" in base_variables[0].domains["secondary"] and isinstance(base_variables[0].secondary_mesh, pybamm.ScikitSubMesh2D) ): - raise NotImplementedError( - "3D variables with secondary domain 'current collector' using the ScikitFEM" - " discretisation are not supported as processed variables" - ) + return ProcessedVariable3DSciKitFEM(base_variables, *args, **kwargs) # check variable shape if len(base_eval_shape) == 0 or base_eval_shape[0] == 1: diff --git a/src/pybamm/solvers/processed_variable_computed.py b/src/pybamm/solvers/processed_variable_computed.py index 90be679ca3..71f67c9d66 100644 --- a/src/pybamm/solvers/processed_variable_computed.py +++ b/src/pybamm/solvers/processed_variable_computed.py @@ -546,6 +546,15 @@ def initialise_3D(self): self.r_sol = first_dim_pts self.R_sol = second_dim_pts self.x_sol = third_dim_pts + elif self.domain[0].endswith("electrode") and self.domains["secondary"] == [ + "current collector" + ]: + self.first_dimension = "x" + self.second_dimension = "y" + self.third_dimension = "z" + self.x_sol = first_dim_pts + self.y_sol = second_dim_pts + self.z_sol = third_dim_pts else: # pragma: no cover raise pybamm.DomainError( f"Cannot process 3D object with domains '{self.domains}'."