Skip to content

Commit

Permalink
#4776 processed var tests
Browse files Browse the repository at this point in the history
  • Loading branch information
rtimms committed Jan 30, 2025
1 parent 048121a commit 04fdb3d
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 33 deletions.
11 changes: 5 additions & 6 deletions src/pybamm/solvers/processed_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -1168,12 +1168,11 @@ def process_variable(base_variables, *args, **kwargs):
and isinstance(mesh, pybamm.ScikitSubMesh2D)
):
return ProcessedVariable2DSciKitFEM(base_variables, *args, **kwargs)
if (
base_variables[0].secondary_mesh
and "current collector" in base_variables[0].domains["secondary"]
and isinstance(base_variables[0].secondary_mesh, pybamm.ScikitSubMesh2D)
):
return ProcessedVariable3DSciKitFEM(base_variables, *args, **kwargs)
if hasattr(base_variables[0], "secondary_mesh"):
if "current collector" in base_variables[0].domains["secondary"] and isinstance(
base_variables[0].secondary_mesh, pybamm.ScikitSubMesh2D
):
return ProcessedVariable3DSciKitFEM(base_variables, *args, **kwargs)

# check variable shape
if len(base_eval_shape) == 0 or base_eval_shape[0] == 1:
Expand Down
45 changes: 37 additions & 8 deletions src/pybamm/solvers/processed_variable_computed.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,17 +99,23 @@ def __init__(
first_dim_nodes = self.mesh.nodes
first_dim_edges = self.mesh.edges
second_dim_pts = self.base_variables[0].secondary_mesh.nodes
if self.base_eval_size // len(second_dim_pts) not in [
if self.base_eval_size // len(second_dim_pts) in [
len(first_dim_nodes),
len(first_dim_edges),
]:
# Raise error for 3D variable
raise NotImplementedError(
f"Shape not recognized for {base_variables[0]} "
+ "(note processing of 3D variables is not yet implemented)"
)
self.initialise_2D()
return

# Try some shapes that could make the variable a 3D variable
tertiary_pts = self.base_variables[0].tertiary_mesh.nodes
if self.base_eval_size // (len(second_dim_pts) * len(tertiary_pts)) in [
len(first_dim_nodes),
len(first_dim_edges),
]:
self.initialise_3D()
return

self.initialise_2D()
raise NotImplementedError(f"Shape not recognized for {base_variables[0]}")

def add_sensitivity(self, param, data):
# unroll from sparse representation into n-d matrix
Expand Down Expand Up @@ -179,15 +185,38 @@ def unroll_2D(self, realdata=None, n_dim1=None, n_dim2=None, axis_swaps=None):
entries = np.moveaxis(entries, a, b)
return entries

def unroll_3D(
self, realdata=None, n_dim1=None, n_dim2=None, n_dim3=None, axis_swaps=None
):
if axis_swaps is None:
axis_swaps = []
if not self.unroll_params:
self.unroll_params["n_dim1"] = n_dim1
self.unroll_params["n_dim2"] = n_dim2
self.unroll_params["n_dim3"] = n_dim3
self.unroll_params["axis_swaps"] = axis_swaps
if n_dim1 is None:
n_dim1 = self.unroll_params["n_dim1"]
n_dim2 = self.unroll_params["n_dim2"]
n_dim3 = self.unroll_params["n_dim3"]
axis_swaps = self.unroll_params["axis_swaps"]
entries = np.concatenate(self._unroll_nnz(realdata), axis=0).reshape(
(len(self.t_pts), n_dim1, n_dim2, n_dim3)
)
for a, b in axis_swaps:
entries = np.moveaxis(entries, a, b)
return entries

def unroll(self, realdata=None):
if self.dimensions == 0:
return self.unroll_0D(realdata=realdata)
elif self.dimensions == 1:
return self.unroll_1D(realdata=realdata)
elif self.dimensions == 2:
return self.unroll_2D(realdata=realdata)
elif self.dimensions == 3:
return self.unroll_3D(realdata=realdata)
else:
# Raise error for 3D variable
raise NotImplementedError(f"Unsupported data dimension: {self.dimensions}")

def initialise_0D(self):
Expand Down
48 changes: 29 additions & 19 deletions tests/unit/test_solvers/test_processed_variable_computed.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,11 +446,6 @@ def test_processed_variable_2D_space_only(self):
# Check unroll function (2D)
np.testing.assert_array_equal(processed_var.unroll(), y_sol.reshape(10, 40, 1))

# Check unroll function (3D)
with pytest.raises(NotImplementedError):
processed_var.dimensions = 3
processed_var.unroll()

def test_processed_variable_2D_fixed_t_scikit(self):
var = pybamm.Variable("var", domain=["current collector"])

Expand All @@ -474,24 +469,39 @@ def test_processed_variable_2D_fixed_t_scikit(self):
processed_var.entries, np.reshape(u_sol, [len(y), len(z), len(t_sol)])
)

def test_3D_raises_error(self):
def test_processed_variable_3D(self):
var = pybamm.Variable(
"var",
domain=["negative electrode"],
auxiliary_domains={"secondary": ["current collector"]},
domain=["negative particle"],
auxiliary_domains={
"secondary": ["negative particle size"],
"tertiary": ["negative electrode"],
},
)

disc = tests.get_2p1d_discretisation_for_testing()
disc = tests.get_size_distribution_disc_for_testing(xpts=3, rpts=4, Rpts=5)
disc.set_variable_slices([var])
x_sol = disc.mesh["negative electrode"].nodes
R_sol = disc.mesh["negative particle size"].nodes
r_sol = disc.mesh["negative particle"].nodes
var_sol = disc.process_symbol(var)
t_sol = np.array([0, 1, 2])
u_sol = np.ones(var_sol.shape[0] * 3)[:, np.newaxis]
t_sol = np.linspace(0, 1, 2)
u_sol = np.ones(len(x_sol) * len(R_sol) * len(r_sol))[:, np.newaxis] * t_sol

var_casadi = to_casadi(var_sol, u_sol)
geometry_options = {"options": {"particle size": "distribution"}}
model = tests.get_base_model_with_battery_geometry(**geometry_options)
processed_var = pybamm.ProcessedVariableComputed(
[var_sol],
[var_casadi],
[u_sol],
pybamm.Solution(t_sol, u_sol, model, {}),
)

# Check shape (prim, sec, ter, time)
np.testing.assert_array_equal(
processed_var.entries,
np.reshape(u_sol, [len(r_sol), len(R_sol), len(x_sol), len(t_sol)]),
)

with pytest.raises(NotImplementedError, match="Shape not recognized"):
pybamm.ProcessedVariableComputed(
[var_sol],
[var_casadi],
[u_sol],
pybamm.Solution(t_sol, u_sol, pybamm.BaseModel(), {}),
)
# Check unroll function (3D)
np.testing.assert_array_equal(processed_var.unroll(), u_sol.reshape(6, 7, 5, 2))

0 comments on commit 04fdb3d

Please sign in to comment.