diff --git a/torax/plotting/configs/default_plot_config.py b/torax/plotting/configs/default_plot_config.py index 5d641a9e..0b2ff1f4 100644 --- a/torax/plotting/configs/default_plot_config.py +++ b/torax/plotting/configs/default_plot_config.py @@ -17,35 +17,66 @@ from torax.plotting import plotruns_lib PLOT_CONFIG = plotruns_lib.FigureProperties( - rows=2, - cols=3, + rows=3, + cols=5, + tick_fontsize=8, + axes_fontsize=8, + default_legend_fontsize=7, + figure_size_factor=5, + title_fontsize=12, axes=( # For chi, set histogram percentile for y-axis upper limit, due to # volatile nature of the data. Do not include first timepoint since # chi is defined as zero there and may unduly affect ylim. + plotruns_lib.PlotProperties( + attrs=('ti', 'te'), + labels=(r'$T_\mathrm{i}$', r'$T_\mathrm{e}$'), + ylabel='Temperature [keV]', + ), + plotruns_lib.PlotProperties( + attrs=('ne',), + labels=(r'$n_\mathrm{e}$',), + ylabel=r'Electron density $[10^{20}~m^{-3}]$', + ), plotruns_lib.PlotProperties( attrs=('chi_i', 'chi_e'), - labels=(r'$\chi_i$', r'$\chi_e$'), + labels=(r'$\chi_\mathrm{i}$', r'$\chi_\mathrm{e}$'), ylabel=r'Heat conductivity $[m^2/s]$', upper_percentile=98.0, include_first_timepoint=False, ylim_min_zero=False, ), plotruns_lib.PlotProperties( - attrs=('ti', 'te'), - labels=(r'$T_i$', r'$T_e$'), - ylabel='Temperature [keV]', + attrs=('d_e', 'v_e'), + labels=(r'$D_\mathrm{e}$', r'$V_\mathrm{e}$'), + ylabel=r'Diff $[m^2/s]$ or Conv $[m/s]$', + upper_percentile=98.0, + lower_percentile=2.0, + include_first_timepoint=False, + ylim_min_zero=False, ), plotruns_lib.PlotProperties( - attrs=('ne',), - labels=(r'$n_e$',), - ylabel=r'Electron density $[10^{20}~m^{-3}]$', + plot_type=plotruns_lib.PlotType.TIME_SERIES, + attrs=('i_total', 'i_bootstrap'), + labels=(r'$I_\mathrm{p}$', r'$I_\mathrm{bs}$'), + ylabel=r'Current [MA]', + ), + plotruns_lib.PlotProperties( + attrs=('psi',), + labels=(r'$\psi$',), + ylabel=r'Poloidal flux [Wb]', ), plotruns_lib.PlotProperties( attrs=('j', 'johm', 'j_bootstrap', 'jext'), - labels=(r'$j_{tot}$', r'$j_{ohm}$', r'$j_{bs}$', r'$j_{ext}$'), + labels=( + r'$j_\mathrm{tot}$', + r'$j_\mathrm{ohm}$', + r'$j_\mathrm{bs}$', + r'$j_\mathrm{ext}$', + ), ylabel=r'Toroidal current $[A~m^{-2}]$', - legend_fontsize=8, # Smaller fontsize for this plot + legend_fontsize=7, # Smaller fontsize for this plot + suppress_zero_values=True, # Do not plot all-zero data ), plotruns_lib.PlotProperties( attrs=('q',), @@ -57,5 +88,71 @@ labels=(r'$\hat{s}$',), ylabel='Magnetic shear', ), + plotruns_lib.PlotProperties( + plot_type=plotruns_lib.PlotType.TIME_SERIES, + attrs=('Q_fusion',), + labels=(r'$Q_\mathrm{fusion}$',), + ylabel='Fusion gain', + ), + plotruns_lib.PlotProperties( + attrs=('psidot',), + labels=(r'$\dot{\psi}$',), + ylabel='Loop voltage', + upper_percentile=98.0, + ), + plotruns_lib.PlotProperties( + attrs=( + 'q_icrh_i', + 'q_icrh_e', + 'q_nbi_i', + 'q_nbi_e', + 'q_ecrh', + 'q_gen_i', + 'q_gen_e', + ), + labels=( + r'$Q_\mathrm{ICRH,i}$', + r'$Q_\mathrm{ICRH,e}$', + r'$Q_\mathrm{NBI,i}$', + r'$Q_\mathrm{NBI,e}$', + r'$Q_\mathrm{ERCH}$', + r'$Q_\mathrm{generic,i}$', + r'$Q_\mathrm{generic,e}$', + ), + ylabel=r'External heat source density $[W~m^{-3}]$', + legend_fontsize=7, # Smaller fontsize for this plot + suppress_zero_values=True, # Do not plot all-zero data + ), + plotruns_lib.PlotProperties( + attrs=('q_alpha_i', 'q_alpha_e', 'q_ohmic', 'q_ei'), + labels=( + r'$Q_\mathrm{alpha,i}$', + r'$Q_\mathrm{alpha,e}$', + r'$Q_\mathrm{ohmic}$', + r'$Q_\mathrm{ei}$', + ), + ylabel=r'Internal heat source density $[W~m^{-3}]$', + legend_fontsize=6, # Smaller fontsize for this plot + suppress_zero_values=True, # Do not plot all-zero data + ), + plotruns_lib.PlotProperties( + attrs=('q_brems',), + labels=(r'$Q_\mathrm{brems}$',), + ylabel=r'Heat sink density $[kW~m^{-3}]$', + suppress_zero_values=True, # Do not plot all-zero data + ), + plotruns_lib.PlotProperties( + plot_type=plotruns_lib.PlotType.TIME_SERIES, + attrs=('p_auxiliary', 'p_ohmic', 'p_alpha', 'p_sink'), + labels=( + r'$P_\mathrm{aux}$', + r'$P_\mathrm{ohm}$', + r'$P_\mathrm{\alpha}$', + r'$P_\mathrm{sink}$', + ), + ylabel=r'Total heating/sink powers $[MW]$', + legend_fontsize=6, # Smaller fontsize for this plot + suppress_zero_values=True, # Do not plot all-zero data + ), ), ) diff --git a/torax/plotting/configs/simple_plot_config.py b/torax/plotting/configs/simple_plot_config.py new file mode 100644 index 00000000..5d641a9e --- /dev/null +++ b/torax/plotting/configs/simple_plot_config.py @@ -0,0 +1,61 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Default plotting configuration for Torax runs.""" + +from torax.plotting import plotruns_lib + +PLOT_CONFIG = plotruns_lib.FigureProperties( + rows=2, + cols=3, + axes=( + # For chi, set histogram percentile for y-axis upper limit, due to + # volatile nature of the data. Do not include first timepoint since + # chi is defined as zero there and may unduly affect ylim. + plotruns_lib.PlotProperties( + attrs=('chi_i', 'chi_e'), + labels=(r'$\chi_i$', r'$\chi_e$'), + ylabel=r'Heat conductivity $[m^2/s]$', + upper_percentile=98.0, + include_first_timepoint=False, + ylim_min_zero=False, + ), + plotruns_lib.PlotProperties( + attrs=('ti', 'te'), + labels=(r'$T_i$', r'$T_e$'), + ylabel='Temperature [keV]', + ), + plotruns_lib.PlotProperties( + attrs=('ne',), + labels=(r'$n_e$',), + ylabel=r'Electron density $[10^{20}~m^{-3}]$', + ), + plotruns_lib.PlotProperties( + attrs=('j', 'johm', 'j_bootstrap', 'jext'), + labels=(r'$j_{tot}$', r'$j_{ohm}$', r'$j_{bs}$', r'$j_{ext}$'), + ylabel=r'Toroidal current $[A~m^{-2}]$', + legend_fontsize=8, # Smaller fontsize for this plot + ), + plotruns_lib.PlotProperties( + attrs=('q',), + labels=(r'$q$',), + ylabel='Safety factor', + ), + plotruns_lib.PlotProperties( + attrs=('s',), + labels=(r'$\hat{s}$',), + ylabel='Magnetic shear', + ), + ), +) diff --git a/torax/plotting/plotruns_lib.py b/torax/plotting/plotruns_lib.py index dd79ef0a..a25a5774 100644 --- a/torax/plotting/plotruns_lib.py +++ b/torax/plotting/plotruns_lib.py @@ -16,10 +16,12 @@ from collections.abc import Sequence import dataclasses +import enum from os import path -from typing import Any +from typing import Any, List import matplotlib +from matplotlib import gridspec from matplotlib import widgets import matplotlib.pyplot as plt import numpy as np @@ -31,6 +33,11 @@ # first ascending in columns, then rows. +class PlotType(enum.Enum): + SPATIAL = 1 + TIME_SERIES = 2 + + @dataclasses.dataclass class PlotProperties: """Dataclass for individual plot properties.""" @@ -43,6 +50,8 @@ class PlotProperties: lower_percentile: float = 0.0 include_first_timepoint: bool = True ylim_min_zero: bool = True + plot_type: PlotType = PlotType.SPATIAL + suppress_zero_values: bool = False # If True, all-zero-data is not plotted @dataclasses.dataclass @@ -53,6 +62,9 @@ class FigureProperties: cols: int axes: tuple[PlotProperties, ...] figure_size_factor: float = 5 + tick_fontsize: int = 10 + axes_fontsize: int = 10 + title_fontsize: int = 16 default_legend_fontsize: int = 10 colors: tuple[str, ...] = ('r', 'b', 'g', 'm', 'y', 'c') @@ -68,6 +80,8 @@ class PlotData: ti: np.ndarray te: np.ndarray ne: np.ndarray + psi: np.ndarray + psidot: np.ndarray j: np.ndarray johm: np.ndarray j_bootstrap: np.ndarray @@ -76,6 +90,31 @@ class PlotData: s: np.ndarray chi_i: np.ndarray chi_e: np.ndarray + d_e: np.ndarray + v_e: np.ndarray + q_icrh_i: np.ndarray + q_icrh_e: np.ndarray + q_nbi_i: np.ndarray + q_nbi_e: np.ndarray + q_nbi_e: np.ndarray + q_gen_i: np.ndarray + q_gen_e: np.ndarray + q_ecrh: np.ndarray + q_alpha_i: np.ndarray + q_alpha_e: np.ndarray + q_ohmic: np.ndarray + q_brems: np.ndarray + q_ei: np.ndarray + Q_fusion: np.ndarray # pylint: disable=invalid-name + s_puff: np.ndarray + s_nbi: np.ndarray + s_pellet: np.ndarray + i_total: np.ndarray + i_bootstrap: np.ndarray + p_auxiliary: np.ndarray + p_ohmic: np.ndarray + p_alpha: np.ndarray + p_sink: np.ndarray t: np.ndarray rho_cell_coord: np.ndarray rho_face_coord: np.ndarray @@ -99,10 +138,27 @@ def load_data(filename: str) -> PlotData: jext = ds[output.CORE_PROFILES_JEXT].to_numpy() else: jext = ds['jext'].to_numpy() + + def get_optional_data(ds, key, grid_type): + if grid_type.lower() not in ['cell', 'face']: + raise ValueError( + f'grid_type for {key}must be either "cell" or "face", got {grid_type}' + ) + if key in ds: + return ds[key].to_numpy() + else: + return ( + np.zeros_like(ds[output.TEMP_ION].to_numpy()) + if grid_type == 'cell' + else np.zeros_like(ds[output.CHI_FACE_ION].to_numpy()) + ) + return PlotData( ti=ds[output.TEMP_ION].to_numpy(), te=ds[output.TEMP_EL].to_numpy(), ne=ds[output.NE].to_numpy(), + psi=ds[output.PSI].to_numpy(), + psidot=ds[output.PSIDOT].to_numpy(), j=ds[output.JTOT].to_numpy(), johm=ds[output.JOHM].to_numpy(), j_bootstrap=ds[output.J_BOOTSTRAP].to_numpy(), @@ -111,8 +167,32 @@ def load_data(filename: str) -> PlotData: s=ds[output.S_FACE].to_numpy(), chi_i=ds[output.CHI_FACE_ION].to_numpy(), chi_e=ds[output.CHI_FACE_EL].to_numpy(), + d_e=ds[output.D_FACE_EL].to_numpy(), + v_e=ds[output.V_FACE_EL].to_numpy(), rho_cell_coord=ds[output.RHO_CELL_NORM].to_numpy(), rho_face_coord=ds[output.RHO_FACE_NORM].to_numpy(), + q_icrh_i=get_optional_data(ds, 'icrh_heat_source_ion', 'cell'), + q_icrh_e=get_optional_data(ds, 'icrh_heat_source_el', 'cell'), + q_nbi_i=get_optional_data(ds, 'nbi_heat_source_ion', 'cell'), + q_nbi_e=get_optional_data(ds, 'nbi_heat_source_el', 'cell'), + q_gen_i=get_optional_data(ds, 'generic_ion_el_heat_source_ion', 'cell'), + q_gen_e=get_optional_data(ds, 'generic_ion_el_heat_source_el', 'cell'), + q_ecrh=get_optional_data(ds, 'ecrh_heat_source', 'cell'), + q_alpha_i=get_optional_data(ds, 'fusion_heat_source_ion', 'cell'), + q_alpha_e=get_optional_data(ds, 'fusion_heat_source_el', 'cell'), + q_ohmic=get_optional_data(ds, 'ohmic_heat_source', 'cell'), + q_brems=get_optional_data(ds, 'bremsstrahlung_heat_sink', 'cell'), + q_ei=ds['qei_source'].to_numpy(), # ion heating/sink + Q_fusion=ds['Q_fusion'].to_numpy(), # pylint: disable=invalid-name + s_puff=get_optional_data(ds, 'gas_puff_source', 'cell'), + s_nbi=get_optional_data(ds, 'nbi_particle_source', 'cell'), + s_pellet=get_optional_data(ds, 'pellet_source', 'cell'), + i_total=ds[output.IP].to_numpy(), + i_bootstrap=ds[output.I_BOOTSTRAP].to_numpy() / 1e6, + p_ohmic=ds['P_ohmic'].to_numpy()/1e6, + p_auxiliary=(ds['P_external_tot']-ds['P_ohmic']).to_numpy()/1e6, + p_alpha=ds['P_alpha_tot'].to_numpy()/1e6, + p_sink=ds['P_brems'].to_numpy()/1e6, t=t, ) @@ -139,7 +219,7 @@ def plot_run( f"Attribute '{attr}' in plot_config does not exist in PlotData" ) - fig, axes = create_figure(plot_config) + fig, axes, slider_ax = create_figure(plot_config) # Title handling: title_lines = [f'(1)={outfile}'] @@ -155,17 +235,19 @@ def plot_run( ) format_plots(plot_config, plotdata1, plotdata2, axes) - timeslider = create_slider(plotdata1, plotdata2) + timeslider = create_slider(slider_ax, plotdata1, plotdata2) fig.canvas.draw() - update = lambda newtime: _update( - newtime, plot_config, plotdata1, lines1, plotdata2, lines2 - ) - # Call update function when slider value is changed. + def update(newtime): + """Update plots with new values following slider manipulation.""" + fig.constrained_layout = False + _update(newtime, plot_config, plotdata1, lines1, plotdata2, lines2) + fig.constrained_layout = True + fig.canvas.draw_idle() + timeslider.on_changed(update) fig.canvas.draw() plt.show() - fig.tight_layout() def _update( @@ -182,8 +264,13 @@ def update_lines(plotdata, lines): idx = np.abs(plotdata.t - newtime).argmin() line_idx = 0 for cfg in plot_config.axes: # Iterate through axes based on plot_config + if cfg.plot_type == PlotType.TIME_SERIES: + continue # Time series plots do not need to be updated for attr in cfg.attrs: # Update all lines in current subplot. - lines[line_idx].set_ydata(getattr(plotdata, attr)[idx, :]) + data = getattr(plotdata, attr) + if cfg.suppress_zero_values and np.all(data == 0): + continue + lines[line_idx].set_ydata(data[idx, :]) line_idx += 1 update_lines(plotdata1, lines1) @@ -192,13 +279,11 @@ def update_lines(plotdata, lines): def create_slider( + ax: matplotlib.axes.Axes, plotdata1: PlotData, plotdata2: PlotData | None = None, ) -> widgets.Slider: """Create a slider tool for the plot.""" - plt.subplots_adjust(bottom=0.2) - axslide = plt.axes([0.12, 0.05, 0.75, 0.05]) - tmin = ( min(plotdata1.t) if plotdata2 is None @@ -217,7 +302,7 @@ def create_slider( ) return widgets.Slider( - axslide, + ax, 'Time [s]', tmin, tmax, @@ -230,7 +315,7 @@ def format_plots( plot_config: FigureProperties, plotdata1: PlotData, plotdata2: PlotData | None, - axes: tuple[Any, ...], + axes: List[Any], ): """Sets up plot formatting.""" @@ -248,7 +333,12 @@ def get_limit(plotdata, attrs, percentile, include_first_timepoint): return np.percentile(values, percentile) for ax, cfg in zip(axes, plot_config.axes): - ax.set_xlabel('Normalized radius') + if cfg.plot_type == PlotType.SPATIAL: + ax.set_xlabel('Normalized radius') + elif cfg.plot_type == PlotType.TIME_SERIES: + ax.set_xlabel('Time [s]') + else: + raise ValueError(f'Unknown plot type: {cfg.plot_type}') ax.set_ylabel(cfg.ylabel) # Get limits for y-axis based on percentile values. @@ -308,7 +398,7 @@ def get_rho( def get_lines( plot_config: FigureProperties, plotdata: PlotData, - axes: tuple[Any, ...], + axes: List[Any], comp_plot: bool = False, ): """Gets lines for all plots.""" @@ -319,16 +409,35 @@ def get_lines( for ax, cfg in zip(axes, plot_config.axes): line_idx = 0 # Reset color selection cycling for each plot. - for attr, label in zip(cfg.attrs, cfg.labels): - rho = get_rho(plotdata, attr) - (line,) = ax.plot( - rho, - getattr(plotdata, attr)[0, :], # Plot data at time zero - plot_config.colors[line_idx % len(plot_config.colors)] + dashed, - label=f'{label}{suffix}', - ) - lines.append(line) - line_idx += 1 + if cfg.plot_type == PlotType.SPATIAL: + for attr, label in zip(cfg.attrs, cfg.labels): + data = getattr(plotdata, attr) + if cfg.suppress_zero_values and np.all(data == 0): + continue + rho = get_rho(plotdata, attr) + (line,) = ax.plot( + rho, + data[0, :], # Plot data at time zero + plot_config.colors[line_idx % len(plot_config.colors)] + dashed, + label=f'{label}{suffix}', + ) + lines.append(line) + line_idx += 1 + elif cfg.plot_type == PlotType.TIME_SERIES: + for attr, label in zip(cfg.attrs, cfg.labels): + data = getattr(plotdata, attr) + if cfg.suppress_zero_values and np.all(data == 0): + continue + # No need to return a line since this will not need to be updated. + _ = ax.plot( + plotdata.t, + data, # Plot entire time series + plot_config.colors[line_idx % len(plot_config.colors)] + dashed, + label=f'{label}{suffix}', + ) + line_idx += 1 + else: + raise ValueError(f'Unknown plot type: {cfg.plot_type}') return lines @@ -337,21 +446,27 @@ def create_figure(plot_config: FigureProperties): """Creates the figure and axes.""" rows = plot_config.rows cols = plot_config.cols - figsize = ( - cols * plot_config.figure_size_factor, - rows * plot_config.figure_size_factor, + matplotlib.rc('xtick', labelsize=plot_config.tick_fontsize) + matplotlib.rc('ytick', labelsize=plot_config.tick_fontsize) + matplotlib.rc('axes', labelsize=plot_config.axes_fontsize) + matplotlib.rc('figure', titlesize=plot_config.title_fontsize) + fig = plt.figure( + figsize=( + cols * plot_config.figure_size_factor, + rows * plot_config.figure_size_factor, + ), + constrained_layout=True, ) - fig, axes = plt.subplots(rows, cols, figsize=figsize) - # Flatten axes array if necessary (for consistent indexing) - if isinstance( - axes, np.ndarray - ): # Check if it's a NumPy array before flattening - axes = axes.flatten() - elif rows > 1 or cols > 1: # This shouldn't happen, but added as a safety net - raise ValueError( - f'Axes is not a numpy array, but should be one since rows={rows},' - f' cols={cols}' - ) - else: - axes = [axes] # Make axes iterable if only one subplot - return fig, axes + # Create the GridSpec - leave space for the slider at the bottom + gs = gridspec.GridSpec( + rows + 1, cols, figure=fig, height_ratios=[1] * rows + [0.2] + ) # Adjust 0.2 for slider height + + axes = [] + for i in range(rows * cols): + row = i // cols + col = i % cols + axes.append(fig.add_subplot(gs[row, col])) # Add subplots to the grid + # slider spans all columns in the last row + slider_ax = fig.add_subplot(gs[rows, :]) + return fig, axes, slider_ax