Skip to content

Commit

Permalink
Update quickplot to catch empty lists (#3359)
Browse files Browse the repository at this point in the history
* Check for empty list

* Minor clean-up

* Update change log and test name

* Update CHANGELOG.md

Co-authored-by: Ferran Brosa Planella <[email protected]>

---------

Co-authored-by: Ferran Brosa Planella <[email protected]>
  • Loading branch information
kratman and brosaplanella authored Sep 21, 2023
1 parent 5fbede6 commit 080d91e
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 17 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

## Bug fixes

- Fixed a bug where empty lists passed to QuickPlot resulted in an IndexError and did not return a meaningful error message ([#3359](https://github.com/pybamm-team/PyBaMM/pull/3359))
- Fixed a bug where there was a missing thermal conductivity in the thermal pouch cell models ([#3330](https://github.com/pybamm-team/PyBaMM/pull/3330))
- Fixed a bug that caused incorrect results of “{Domain} electrode thickness change [m]” due to the absence of dimension for the variable `electrode_thickness_change`([#3329](https://github.com/pybamm-team/PyBaMM/pull/3329)).
- Fixed a bug that occured in `check_ys_are_not_too_large` when trying to reference `y-slice` where the referenced variable was not a `pybamm.StateVector` ([#3313](https://github.com/pybamm-team/PyBaMM/pull/3313)
Expand Down
42 changes: 27 additions & 15 deletions pybamm/plotting/quick_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,21 +106,7 @@ def __init__(
spatial_unit="um",
variable_limits="fixed",
):
input_solutions = solutions
solutions = []
if not isinstance(input_solutions, (pybamm.Solution, pybamm.Simulation, list)):
raise TypeError(
"solutions must be 'pybamm.Solution' or 'pybamm.Simulation' or list"
)
elif not isinstance(input_solutions, list):
input_solutions = [input_solutions]
for sim_or_sol in input_solutions:
if isinstance(sim_or_sol, pybamm.Simulation):
# 'sim_or_sol' is actually a 'Simulation' object here so it has a
# 'Solution' attribute
solutions.append(sim_or_sol.solution)
elif isinstance(sim_or_sol, pybamm.Solution):
solutions.append(sim_or_sol)
solutions = self.preprocess_solutions(solutions)

models = [solution.all_models[0] for solution in solutions]

Expand Down Expand Up @@ -242,6 +228,32 @@ def __init__(
self.set_output_variables(output_variable_tuples, solutions)
self.reset_axis()

@staticmethod
def preprocess_solutions(solutions):
input_solutions = QuickPlot.check_input_validity(solutions)
processed_solutions = []
for sim_or_sol in input_solutions:
if isinstance(sim_or_sol, pybamm.Simulation):
# 'sim_or_sol' is actually a 'Simulation' object here, so it has a
# 'Solution' attribute
processed_solutions.append(sim_or_sol.solution)
elif isinstance(sim_or_sol, pybamm.Solution):
processed_solutions.append(sim_or_sol)
return processed_solutions

@staticmethod
def check_input_validity(input_solutions):
if not isinstance(input_solutions, (pybamm.Solution, pybamm.Simulation, list)):
raise TypeError(
"Solutions must be 'pybamm.Solution' or 'pybamm.Simulation' or list"
)
elif not isinstance(input_solutions, list):
input_solutions = [input_solutions]
else:
if not input_solutions:
raise TypeError("QuickPlot requires at least 1 solution or simulation.")
return input_solutions

def set_output_variables(self, output_variables, solutions):
# Set up output variables
self.variables = {}
Expand Down
8 changes: 6 additions & 2 deletions tests/unit/test_plotting/test_quick_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,10 +463,14 @@ def test_plot_2plus1D_spm(self):

pybamm.close_plots()

def test_failure(self):
with self.assertRaisesRegex(TypeError, "solutions must be"):
def test_invalid_input_type_failure(self):
with self.assertRaisesRegex(TypeError, "Solutions must be"):
pybamm.QuickPlot(1)

def test_empty_list_failure(self):
with self.assertRaisesRegex(TypeError, "QuickPlot requires at least 1"):
pybamm.QuickPlot([])

def test_model_with_inputs(self):
parameter_values = pybamm.ParameterValues("Chen2020")
model = pybamm.lithium_ion.SPMe()
Expand Down

0 comments on commit 080d91e

Please sign in to comment.