diff --git a/run_simulation_main.py b/run_simulation_main.py index fb0b681b..70e5c24a 100644 --- a/run_simulation_main.py +++ b/run_simulation_main.py @@ -115,6 +115,12 @@ 'If provided, overrides the default output directory.', ) +_PLOT_CONFIG_MODULE = flags.DEFINE_string( + 'plot_config', + 'torax.plotting.configs.default_plot_config', # Default + 'Module path to the plot config.', +) + jax.config.parse_flags_with_absl() @@ -254,10 +260,8 @@ def change_config( new_runtime_params = build_sim.build_runtime_params_from_config( sim_config['runtime_params'] ) - new_geo_provider = ( - build_sim.build_geometry_provider_from_config( - sim_config['geometry'], - ) + new_geo_provider = build_sim.build_geometry_provider_from_config( + sim_config['geometry'], ) new_transport_model_builder = ( build_sim.build_transport_model_builder_from_config( @@ -426,17 +430,28 @@ def _post_run_plotting( color=simulation_app.AnsiColors.RED, ) return + try: + plot_config = config_loader.import_module( + _PLOT_CONFIG_MODULE.value + ).PLOT_CONFIG + except (ModuleNotFoundError, AttributeError) as e: + logging.exception( + 'Error loading plot config module %s: %s', _PLOT_CONFIG_MODULE.value, e + ) + return match input_text: case '0': - return plotruns_lib.plot_run(output_files[-1]) + return plotruns_lib.plot_run(plot_config, output_files[-1]) case '1': if len(output_files) == 1: simulation_app.log_to_stdout( 'Only one output run file found, only plotting the last run.', color=simulation_app.AnsiColors.RED, ) - return plotruns_lib.plot_run(output_files[-1]) - return plotruns_lib.plot_run(output_files[-1], output_files[-2]) + return plotruns_lib.plot_run(plot_config, output_files[-1]) + return plotruns_lib.plot_run( + plot_config, output_files[-1], output_files[-2] + ) case '2': reference_run = _REFERENCE_RUN.value if reference_run is None: @@ -444,7 +459,7 @@ def _post_run_plotting( 'No reference run provided, only plotting the last run.', color=simulation_app.AnsiColors.RED, ) - return plotruns_lib.plot_run(output_files[-1], reference_run) + return plotruns_lib.plot_run(plot_config, output_files[-1], reference_run) case _: raise ValueError('Unknown command') diff --git a/torax/plotting/configs/default_plot_config.py b/torax/plotting/configs/default_plot_config.py new file mode 100644 index 00000000..5d641a9e --- /dev/null +++ b/torax/plotting/configs/default_plot_config.py @@ -0,0 +1,61 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Default plotting configuration for Torax runs.""" + +from torax.plotting import plotruns_lib + +PLOT_CONFIG = plotruns_lib.FigureProperties( + rows=2, + cols=3, + axes=( + # For chi, set histogram percentile for y-axis upper limit, due to + # volatile nature of the data. Do not include first timepoint since + # chi is defined as zero there and may unduly affect ylim. + plotruns_lib.PlotProperties( + attrs=('chi_i', 'chi_e'), + labels=(r'$\chi_i$', r'$\chi_e$'), + ylabel=r'Heat conductivity $[m^2/s]$', + upper_percentile=98.0, + include_first_timepoint=False, + ylim_min_zero=False, + ), + plotruns_lib.PlotProperties( + attrs=('ti', 'te'), + labels=(r'$T_i$', r'$T_e$'), + ylabel='Temperature [keV]', + ), + plotruns_lib.PlotProperties( + attrs=('ne',), + labels=(r'$n_e$',), + ylabel=r'Electron density $[10^{20}~m^{-3}]$', + ), + plotruns_lib.PlotProperties( + attrs=('j', 'johm', 'j_bootstrap', 'jext'), + labels=(r'$j_{tot}$', r'$j_{ohm}$', r'$j_{bs}$', r'$j_{ext}$'), + ylabel=r'Toroidal current $[A~m^{-2}]$', + legend_fontsize=8, # Smaller fontsize for this plot + ), + plotruns_lib.PlotProperties( + attrs=('q',), + labels=(r'$q$',), + ylabel='Safety factor', + ), + plotruns_lib.PlotProperties( + attrs=('s',), + labels=(r'$\hat{s}$',), + ylabel='Magnetic shear', + ), + ), +) diff --git a/torax/plotting/plotruns.py b/torax/plotting/plotruns.py index 66ec4acf..429f97b7 100644 --- a/torax/plotting/plotruns.py +++ b/torax/plotting/plotruns.py @@ -16,15 +16,11 @@ Includes a time slider. Reads output files with xarray data or legacy h5 data. -Plots: -(1) chi_i, chi_e (transport coefficients) -(2) Ti, Te (temperatures) -(3) ne (density) -(4) jtot, johm (total and ohmic plasma current) -(5) q (safety factor) -(6) s (magnetic shear) +Plots are configured by a plot_config module. """ +import importlib from absl import app +from absl import logging from absl.flags import argparse_flags import matplotlib from torax.plotting import plotruns_lib @@ -34,6 +30,7 @@ def parse_flags(_): + """Parse flags for the plotting tool.""" parser = argparse_flags.ArgumentParser(description='Plot finished run') parser.add_argument( '--outfile', @@ -43,15 +40,28 @@ def parse_flags(_): ' comparison is done)' ), ) - parser.set_defaults(normalized=True) + parser.add_argument( + '--plot_config', + default='torax.plotting.configs.default_plot_config', + help='Name of the plot config module.', + ) return parser.parse_args() def main(args): + plot_config_module_path = args.plot_config + try: + plot_config_module = importlib.import_module(plot_config_module_path) + plot_config = plot_config_module.PLOT_CONFIG + except (ModuleNotFoundError, AttributeError) as e: + logging.exception( + 'Error loading plot config: %s: %s', plot_config_module_path, e + ) + raise if len(args.outfile) == 1: - plotruns_lib.plot_run(args.outfile[0]) + plotruns_lib.plot_run(plot_config, args.outfile[0]) else: - plotruns_lib.plot_run(args.outfile[0], args.outfile[1]) + plotruns_lib.plot_run(plot_config, args.outfile[0], args.outfile[1]) if __name__ == '__main__': diff --git a/torax/plotting/plotruns_lib.py b/torax/plotting/plotruns_lib.py index 70710296..dd79ef0a 100644 --- a/torax/plotting/plotruns_lib.py +++ b/torax/plotting/plotruns_lib.py @@ -16,7 +16,6 @@ from collections.abc import Sequence import dataclasses -import functools from os import path from typing import Any @@ -27,6 +26,40 @@ from torax import output import xarray as xr +# Constants for figure setup, plot labels, and formatting. +# The axes are designed to be plotted in the order they appear in the list, +# first ascending in columns, then rows. + + +@dataclasses.dataclass +class PlotProperties: + """Dataclass for individual plot properties.""" + + attrs: tuple[str, ...] + labels: tuple[str, ...] + ylabel: str + legend_fontsize: int | None = None # None reverts to default matplotlib value + upper_percentile: float = 100.0 + lower_percentile: float = 0.0 + include_first_timepoint: bool = True + ylim_min_zero: bool = True + + +@dataclasses.dataclass +class FigureProperties: + """Dataclass for all figure related data.""" + + rows: int + cols: int + axes: tuple[PlotProperties, ...] + figure_size_factor: float = 5 + default_legend_fontsize: int = 10 + colors: tuple[str, ...] = ('r', 'b', 'g', 'm', 'y', 'c') + + def __post_init__(self): + if len(self.axes) > self.rows * self.cols: + raise ValueError('len(axes) in plot_config is more than rows * columns.') + @dataclasses.dataclass class PlotData: @@ -47,60 +80,86 @@ class PlotData: rho_cell_coord: np.ndarray rho_face_coord: np.ndarray - def __post_init__(self): - self.tmin = min(self.t) - self.tmax = max(self.t) - self.ymax_t = np.amax([self.ti, self.te]) - self.ymax_n = np.amax(self.ne) - self.ymax_j = np.amax([np.amax(self.j), np.amax(self.johm)]) - self.ymin_j = np.amin([np.amin(self.j), np.amin(self.johm)]) - self.ymin_j = np.amin(self.j) - self.ymax_q = np.amax(self.q) - self.ymax_s = np.amax(self.s) - self.ymin_s = np.amin(self.s) - # avoid initial condition for chi ymax, since can be unphysically high - self.ymax_chi_i = np.amax(self.chi_i[1:, :]) - self.ymax_chi_e = np.amax(self.chi_e[1:, :]) - self.dt = min(np.diff(self.t)) - - -def plot_run(outfile: str, outfile2: str | None = None): + +def load_data(filename: str) -> PlotData: + """Loads an xr.Dataset from a file, handling potential coordinate name changes.""" + ds = xr.open_dataset(filename) + # Handle potential time coordinate name variations + t = ds['time'].to_numpy() if 'time' in ds else ds['t'].to_numpy() + # Rename coordinates if they exist, ensuring compatibility with older datasets + if 'r_cell' in ds: + ds = ds.rename({ + 'r_cell': 'rho_cell', + 'r_face': 'rho_face', + 'r_cell_norm': 'rho_cell_norm', + 'r_face_norm': 'rho_face_norm', + }) + # Handle potential jext coordinate name variations + if output.CORE_PROFILES_JEXT in ds: + jext = ds[output.CORE_PROFILES_JEXT].to_numpy() + else: + jext = ds['jext'].to_numpy() + return PlotData( + ti=ds[output.TEMP_ION].to_numpy(), + te=ds[output.TEMP_EL].to_numpy(), + ne=ds[output.NE].to_numpy(), + j=ds[output.JTOT].to_numpy(), + johm=ds[output.JOHM].to_numpy(), + j_bootstrap=ds[output.J_BOOTSTRAP].to_numpy(), + jext=jext, + q=ds[output.Q_FACE].to_numpy(), + s=ds[output.S_FACE].to_numpy(), + chi_i=ds[output.CHI_FACE_ION].to_numpy(), + chi_e=ds[output.CHI_FACE_EL].to_numpy(), + rho_cell_coord=ds[output.RHO_CELL_NORM].to_numpy(), + rho_face_coord=ds[output.RHO_FACE_NORM].to_numpy(), + t=t, + ) + + +def plot_run( + plot_config: FigureProperties, outfile: str, outfile2: str | None = None +): """Plots a single run or comparison of two runs.""" - filename1, filename2 = outfile, outfile2 if not path.exists(outfile): raise ValueError(f'File {outfile} does not exist.') if outfile2 is not None and not path.exists(outfile2): raise ValueError(f'File {outfile2} does not exist.') plotdata1 = load_data(outfile) - plotdata2 = None - if outfile2 is not None: - plotdata2 = load_data(outfile2) - - fig, subfigures = create_figure() - ax2 = subfigures[1] - if outfile2 is not None: - ax2.set_title('(1)=' + filename1 + ', (2)=' + filename2) - else: - ax2.set_title('(1)=' + filename1) - - lines1 = get_lines( - plotdata1, - subfigures, + plotdata2 = load_data(outfile2) if outfile2 else None + + # Attribute check. Sufficient to check one PlotData object. + plotdata_attrs = set( + plotdata1.__dataclass_fields__ + ) # Get PlotData attributes + for cfg in plot_config.axes: + for attr in cfg.attrs: + if attr not in plotdata_attrs: + raise ValueError( + f"Attribute '{attr}' in plot_config does not exist in PlotData" + ) + + fig, axes = create_figure(plot_config) + + # Title handling: + title_lines = [f'(1)={outfile}'] + if outfile2: + title_lines.append(f'(2)={outfile2}') + fig.suptitle('\n'.join(title_lines)) + + lines1 = get_lines(plot_config, plotdata1, axes) + lines2 = ( + get_lines(plot_config, plotdata2, axes, comp_plot=True) + if plotdata2 + else None ) - lines2 = None - if plotdata2 is not None: - lines2 = get_lines(plotdata2, subfigures, comp_plot=True) - format_plots(plotdata1, subfigures) + format_plots(plot_config, plotdata1, plotdata2, axes) timeslider = create_slider(plotdata1, plotdata2) fig.canvas.draw() - update = functools.partial( - _update, - plotdata1=plotdata1, - plotdata2=plotdata2, - lines1=lines1, - lines2=lines2, + update = lambda newtime: _update( + newtime, plot_config, plotdata1, lines1, plotdata2, lines2 ) # Call update function when slider value is changed. timeslider.on_changed(update) @@ -111,49 +170,25 @@ def plot_run(outfile: str, outfile2: str | None = None): def _update( newtime, + plot_config: FigureProperties, plotdata1: PlotData, lines1: Sequence[matplotlib.lines.Line2D], plotdata2: PlotData | None = None, lines2: Sequence[matplotlib.lines.Line2D] | None = None, ): """Update plots with new values following slider manipulation.""" - idx = np.abs(plotdata1.t - newtime).argmin() # find index closest to new time - # pytype: disable=attribute-error - datalist1 = [ - plotdata1.chi_i[idx, :], - plotdata1.chi_e[idx, :], - plotdata1.ti[idx, :], - plotdata1.te[idx, :], - plotdata1.ne[idx, :], - plotdata1.j[idx, :], - plotdata1.johm[idx, :], - plotdata1.j_bootstrap[idx, :], - plotdata1.jext[idx, :], - plotdata1.q[idx, :], - plotdata1.s[idx, :], - ] - for plotline1, data1 in zip(lines1, datalist1): - plotline1.set_ydata(data1) - if plotdata2 is not None and lines2 is not None: - idx = np.abs( - plotdata2.t - newtime - ).argmin() # find index closest to new time - datalist2 = [ - plotdata2.chi_i[idx, :], - plotdata2.chi_e[idx, :], - plotdata2.ti[idx, :], - plotdata2.te[idx, :], - plotdata2.ne[idx, :], - plotdata2.j[idx, :], - plotdata2.johm[idx, :], - plotdata2.j_bootstrap[idx, :], - plotdata2.jext[idx, :], - plotdata2.q[idx, :], - plotdata2.s[idx, :], - ] - for plotline2, data2 in zip(lines2, datalist2): - plotline2.set_ydata(data2) - # pytype: enable=attribute-error + + def update_lines(plotdata, lines): + idx = np.abs(plotdata.t - newtime).argmin() + line_idx = 0 + for cfg in plot_config.axes: # Iterate through axes based on plot_config + for attr in cfg.attrs: # Update all lines in current subplot. + lines[line_idx].set_ydata(getattr(plotdata, attr)[idx, :]) + line_idx += 1 + + update_lines(plotdata1, lines1) + if plotdata2 and lines2: + update_lines(plotdata2, lines2) def create_slider( @@ -164,56 +199,94 @@ def create_slider( plt.subplots_adjust(bottom=0.2) axslide = plt.axes([0.12, 0.05, 0.75, 0.05]) - # pytype: disable=attribute-error - if plotdata2 is not None: - dt = min(plotdata1.dt, plotdata2.dt) - else: - dt = plotdata1.dt + tmin = ( + min(plotdata1.t) + if plotdata2 is None + else min(min(plotdata1.t), min(plotdata2.t)) + ) + tmax = ( + max(plotdata1.t) + if plotdata2 is None + else max(max(plotdata1.t), max(plotdata2.t)) + ) + + dt = ( + min(np.diff(plotdata1.t)) + if plotdata2 is None + else min(min(np.diff(plotdata1.t)), min(np.diff(plotdata2.t))) + ) return widgets.Slider( axslide, 'Time [s]', - plotdata1.tmin, - plotdata1.tmax, - valinit=plotdata1.tmin, + tmin, + tmax, + valinit=tmin, valstep=dt, ) -def format_plots(plotdata: PlotData, subfigures: tuple[Any, ...]): +def format_plots( + plot_config: FigureProperties, + plotdata1: PlotData, + plotdata2: PlotData | None, + axes: tuple[Any, ...], +): """Sets up plot formatting.""" - ax1, ax2, ax3, ax4, ax5, ax6 = subfigures - - # pytype: disable=attribute-error - ax1.set_xlabel('Normalized radius') - ax1.set_ylabel(r'Heat conductivity $[m^2/s]$') - ax1.legend() - - ax2.set_ylim([0, plotdata.ymax_t * 1.05]) - ax2.set_xlabel('Normalized radius') - ax2.set_ylabel('Temperature [keV]') - ax2.legend() - - ax3.set_ylim([0, plotdata.ymax_n * 1.05]) - ax3.set_xlabel('Normalized radius') - ax3.set_ylabel(r'Electron density $[10^{20}~m^{-3}]$') - ax3.legend() - ax4.set_ylim([min(plotdata.ymin_j * 1.05, 0), plotdata.ymax_j * 1.05]) - ax4.set_xlabel('Normalized radius') - ax4.set_ylabel(r'Toroidal current $[A~m^{-2}]$') - ax4.legend(fontsize=10) - - ax5.set_ylim([0, plotdata.ymax_q * 1.05]) - ax5.set_xlabel('Normalized radius') - ax5.set_ylabel('Safety factor') - ax5.legend() + # Set default legend fontsize for legends + matplotlib.rc('legend', fontsize=plot_config.default_legend_fontsize) + + def get_limit(plotdata, attrs, percentile, include_first_timepoint): + """Gets the limit for a set of attributes based a histogram percentile.""" + if include_first_timepoint: + values = np.concatenate([getattr(plotdata, attr) for attr in attrs]) + else: + values = np.concatenate( + [getattr(plotdata, attr)[1:, :] for attr in attrs] + ) + return np.percentile(values, percentile) + + for ax, cfg in zip(axes, plot_config.axes): + ax.set_xlabel('Normalized radius') + ax.set_ylabel(cfg.ylabel) + + # Get limits for y-axis based on percentile values. + # 0.0 or 100.0 are special cases for simple min/max values. + ymin = get_limit( + plotdata1, cfg.attrs, cfg.lower_percentile, cfg.include_first_timepoint + ) + ymax = get_limit( + plotdata1, cfg.attrs, cfg.upper_percentile, cfg.include_first_timepoint + ) - ax6.set_ylim([min(plotdata.ymin_s * 1.05, 0), plotdata.ymax_s * 1.05]) - ax6.set_xlabel('Normalized radius') - ax6.set_ylabel('Magnetic shear') - ax6.legend() - # pytype: enable=attribute-error + if plotdata2: + ymin = min( + ymin, + get_limit( + plotdata2, + cfg.attrs, + cfg.lower_percentile, + cfg.include_first_timepoint, + ), + ) + ymax = max( + ymax, + get_limit( + plotdata2, + cfg.attrs, + cfg.upper_percentile, + cfg.include_first_timepoint, + ), + ) + + lower_bound = ymin / 1.05 if ymin > 0 else ymin * 1.05 + if cfg.ylim_min_zero: + ax.set_ylim([min(lower_bound, 0), ymax * 1.05]) + else: + ax.set_ylim([lower_bound, ymax * 1.05]) + + ax.legend(fontsize=cfg.legend_fontsize) def get_rho( @@ -233,146 +306,52 @@ def get_rho( def get_lines( + plot_config: FigureProperties, plotdata: PlotData, - subfigures: tuple[Any, ...], + axes: tuple[Any, ...], comp_plot: bool = False, ): """Gets lines for all plots.""" lines = [] # If comparison, first lines labeled (1) and solid, second set (2) and dashed. - if not comp_plot: - suffix = '~(1)' - dashed = '' - else: - suffix = '~(2)' - dashed = '--' - - ax1, ax2, ax3, ax4, ax5, ax6 = subfigures - - (line,) = ax1.plot( - get_rho(plotdata, 'chi_i'), - plotdata.chi_i[1, :], - 'r' + dashed, - label=rf'$\chi_i{suffix}$', - ) - lines.append(line) - (line,) = ax1.plot( - get_rho(plotdata, 'chi_e'), - plotdata.chi_e[1, :], - 'b' + dashed, - label=rf'$\chi_e{suffix}$', - ) - lines.append(line) - (line,) = ax2.plot( - get_rho(plotdata, 'ti'), - plotdata.ti[0, :], - 'r' + dashed, - label=rf'$T_i{suffix}$', - ) - lines.append(line) - (line,) = ax2.plot( - get_rho(plotdata, 'te'), - plotdata.te[0, :], - 'b' + dashed, - label=rf'$T_e{suffix}$', - ) - lines.append(line) - (line,) = ax3.plot( - get_rho(plotdata, 'ne'), - plotdata.ne[0, :], - 'r' + dashed, - label=rf'$n_e{suffix}$', - ) - lines.append(line) - (line,) = ax4.plot( - get_rho(plotdata, 'j'), - plotdata.j[0, :], - 'r' + dashed, - label=rf'$j_{{tot}}{suffix}$', - ) - lines.append(line) - (line,) = ax4.plot( - get_rho(plotdata, 'johm'), - plotdata.johm[0, :], - 'b' + dashed, - label=rf'$j_{{ohm}}{suffix}$', - ) - lines.append(line) - (line,) = ax4.plot( - get_rho(plotdata, 'j_bootstrap'), - plotdata.j_bootstrap[0, :], - 'g' + dashed, - label=rf'$j_{{bs}}{suffix}$', - ) - lines.append(line) - (line,) = ax4.plot( - get_rho(plotdata, 'jext'), - plotdata.jext[0, :], - 'm' + dashed, - label=rf'$j_{{ext}}{suffix}$', - ) - lines.append(line) - (line,) = ax5.plot( - get_rho(plotdata, 'q'), - plotdata.q[0, :], - 'r' + dashed, - label=rf'$q{suffix}$', - ) - lines.append(line) - (line,) = ax6.plot( - get_rho(plotdata, 's'), - plotdata.s[0, :], - 'r' + dashed, - label=rf'$\hat{{s}}{suffix}$', - ) - lines.append(line) + suffix = f' ({1 if not comp_plot else 2})' + dashed = '--' if comp_plot else '' + + for ax, cfg in zip(axes, plot_config.axes): + line_idx = 0 # Reset color selection cycling for each plot. + for attr, label in zip(cfg.attrs, cfg.labels): + rho = get_rho(plotdata, attr) + (line,) = ax.plot( + rho, + getattr(plotdata, attr)[0, :], # Plot data at time zero + plot_config.colors[line_idx % len(plot_config.colors)] + dashed, + label=f'{label}{suffix}', + ) + lines.append(line) + line_idx += 1 return lines -def load_data(filename: str) -> PlotData: - """Loads an xr.Dataset from a file, handling potential coordinate name changes.""" - ds = xr.open_dataset(filename) - # Handle potential time coordinate name variations - t = ds['time'].to_numpy() if 'time' in ds else ds['t'].to_numpy() - # Rename coordinates if they exist, ensuring compatibility with older datasets - if 'r_cell' in ds: - ds = ds.rename({ - 'r_cell': 'rho_cell', - 'r_face': 'rho_face', - 'r_cell_norm': 'rho_cell_norm', - 'r_face_norm': 'rho_face_norm', - }) - # Handle potential jext coordinate name variations - if output.CORE_PROFILES_JEXT in ds: - jext = ds[output.CORE_PROFILES_JEXT].to_numpy() - else: - jext = ds['jext'].to_numpy() - return PlotData( - ti=ds[output.TEMP_ION].to_numpy(), - te=ds[output.TEMP_EL].to_numpy(), - ne=ds[output.NE].to_numpy(), - j=ds[output.JTOT].to_numpy(), - johm=ds[output.JOHM].to_numpy(), - j_bootstrap=ds[output.J_BOOTSTRAP].to_numpy(), - jext=jext, - q=ds[output.Q_FACE].to_numpy(), - s=ds[output.S_FACE].to_numpy(), - chi_i=ds[output.CHI_FACE_ION].to_numpy(), - chi_e=ds[output.CHI_FACE_EL].to_numpy(), - rho_cell_coord=ds[output.RHO_CELL_NORM].to_numpy(), - rho_face_coord=ds[output.RHO_FACE_NORM].to_numpy(), - t=t, +def create_figure(plot_config: FigureProperties): + """Creates the figure and axes.""" + rows = plot_config.rows + cols = plot_config.cols + figsize = ( + cols * plot_config.figure_size_factor, + rows * plot_config.figure_size_factor, ) - - -def create_figure(): - fig = plt.figure(figsize=(15, 10)) - ax1 = fig.add_subplot(231) - ax2 = fig.add_subplot(232) - ax3 = fig.add_subplot(233) - ax4 = fig.add_subplot(234) - ax5 = fig.add_subplot(235) - ax6 = fig.add_subplot(236) - subfigures = (ax1, ax2, ax3, ax4, ax5, ax6) - return fig, subfigures + fig, axes = plt.subplots(rows, cols, figsize=figsize) + # Flatten axes array if necessary (for consistent indexing) + if isinstance( + axes, np.ndarray + ): # Check if it's a NumPy array before flattening + axes = axes.flatten() + elif rows > 1 or cols > 1: # This shouldn't happen, but added as a safety net + raise ValueError( + f'Axes is not a numpy array, but should be one since rows={rows},' + f' cols={cols}' + ) + else: + axes = [axes] # Make axes iterable if only one subplot + return fig, axes