From 080d91e169542506ad99d141dda6e8da82284b0b Mon Sep 17 00:00:00 2001 From: "Eric G. Kratz" Date: Thu, 21 Sep 2023 10:02:18 -0400 Subject: [PATCH] Update quickplot to catch empty lists (#3359) * Check for empty list * Minor clean-up * Update change log and test name * Update CHANGELOG.md Co-authored-by: Ferran Brosa Planella --------- Co-authored-by: Ferran Brosa Planella --- CHANGELOG.md | 1 + pybamm/plotting/quick_plot.py | 42 +++++++++++++-------- tests/unit/test_plotting/test_quick_plot.py | 8 +++- 3 files changed, 34 insertions(+), 17 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f2d94a1dcf..8a93addf65 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ ## Bug fixes +- Fixed a bug where empty lists passed to QuickPlot resulted in an IndexError and did not return a meaningful error message ([#3359](https://github.com/pybamm-team/PyBaMM/pull/3359)) - Fixed a bug where there was a missing thermal conductivity in the thermal pouch cell models ([#3330](https://github.com/pybamm-team/PyBaMM/pull/3330)) - Fixed a bug that caused incorrect results of “{Domain} electrode thickness change [m]” due to the absence of dimension for the variable `electrode_thickness_change`([#3329](https://github.com/pybamm-team/PyBaMM/pull/3329)). - Fixed a bug that occured in `check_ys_are_not_too_large` when trying to reference `y-slice` where the referenced variable was not a `pybamm.StateVector` ([#3313](https://github.com/pybamm-team/PyBaMM/pull/3313) diff --git a/pybamm/plotting/quick_plot.py b/pybamm/plotting/quick_plot.py index 03bfeeccd4..d6828ce18a 100644 --- a/pybamm/plotting/quick_plot.py +++ b/pybamm/plotting/quick_plot.py @@ -106,21 +106,7 @@ def __init__( spatial_unit="um", variable_limits="fixed", ): - input_solutions = solutions - solutions = [] - if not isinstance(input_solutions, (pybamm.Solution, pybamm.Simulation, list)): - raise TypeError( - "solutions must be 'pybamm.Solution' or 'pybamm.Simulation' or list" - ) - elif not isinstance(input_solutions, list): - input_solutions = [input_solutions] - for sim_or_sol in input_solutions: - if isinstance(sim_or_sol, pybamm.Simulation): - # 'sim_or_sol' is actually a 'Simulation' object here so it has a - # 'Solution' attribute - solutions.append(sim_or_sol.solution) - elif isinstance(sim_or_sol, pybamm.Solution): - solutions.append(sim_or_sol) + solutions = self.preprocess_solutions(solutions) models = [solution.all_models[0] for solution in solutions] @@ -242,6 +228,32 @@ def __init__( self.set_output_variables(output_variable_tuples, solutions) self.reset_axis() + @staticmethod + def preprocess_solutions(solutions): + input_solutions = QuickPlot.check_input_validity(solutions) + processed_solutions = [] + for sim_or_sol in input_solutions: + if isinstance(sim_or_sol, pybamm.Simulation): + # 'sim_or_sol' is actually a 'Simulation' object here, so it has a + # 'Solution' attribute + processed_solutions.append(sim_or_sol.solution) + elif isinstance(sim_or_sol, pybamm.Solution): + processed_solutions.append(sim_or_sol) + return processed_solutions + + @staticmethod + def check_input_validity(input_solutions): + if not isinstance(input_solutions, (pybamm.Solution, pybamm.Simulation, list)): + raise TypeError( + "Solutions must be 'pybamm.Solution' or 'pybamm.Simulation' or list" + ) + elif not isinstance(input_solutions, list): + input_solutions = [input_solutions] + else: + if not input_solutions: + raise TypeError("QuickPlot requires at least 1 solution or simulation.") + return input_solutions + def set_output_variables(self, output_variables, solutions): # Set up output variables self.variables = {} diff --git a/tests/unit/test_plotting/test_quick_plot.py b/tests/unit/test_plotting/test_quick_plot.py index db7de574dc..3415777ee8 100644 --- a/tests/unit/test_plotting/test_quick_plot.py +++ b/tests/unit/test_plotting/test_quick_plot.py @@ -463,10 +463,14 @@ def test_plot_2plus1D_spm(self): pybamm.close_plots() - def test_failure(self): - with self.assertRaisesRegex(TypeError, "solutions must be"): + def test_invalid_input_type_failure(self): + with self.assertRaisesRegex(TypeError, "Solutions must be"): pybamm.QuickPlot(1) + def test_empty_list_failure(self): + with self.assertRaisesRegex(TypeError, "QuickPlot requires at least 1"): + pybamm.QuickPlot([]) + def test_model_with_inputs(self): parameter_values = pybamm.ParameterValues("Chen2020") model = pybamm.lithium_ion.SPMe()