Skip to content

Commit

Permalink
Refactor plotruns_lib in preparation for plotting extensions.
Browse files Browse the repository at this point in the history
1. Streamline plotruns_lib for more abstraction. Specify data and formatting details in new PLOT_CONFIG object, which can be extended without needing to touch lib methods.

2. Remove redundant post_init for PlotData

3. Various cleanups, e.g.
(i) cleaner figure title handling
(ii) in comparison plots, the max and min are taken from the combined plots, as opposed to only the first dataset provided
(iii) attempt to set a better default ymax for the volatile chi plots

Correctness verified by eye through direct comparison with the legacy version.

PiperOrigin-RevId: 682055943
  • Loading branch information
jcitrin authored and Torax team committed Oct 4, 2024
1 parent 4352bba commit 248b7b6
Show file tree
Hide file tree
Showing 4 changed files with 335 additions and 270 deletions.
31 changes: 23 additions & 8 deletions run_simulation_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,12 @@
'If provided, overrides the default output directory.',
)

_PLOT_CONFIG_MODULE = flags.DEFINE_string(
'plot_config',
'torax.plotting.configs.default_plot_config', # Default
'Module path to the plot config.',
)

jax.config.parse_flags_with_absl()


Expand Down Expand Up @@ -254,10 +260,8 @@ def change_config(
new_runtime_params = build_sim.build_runtime_params_from_config(
sim_config['runtime_params']
)
new_geo_provider = (
build_sim.build_geometry_provider_from_config(
sim_config['geometry'],
)
new_geo_provider = build_sim.build_geometry_provider_from_config(
sim_config['geometry'],
)
new_transport_model_builder = (
build_sim.build_transport_model_builder_from_config(
Expand Down Expand Up @@ -426,25 +430,36 @@ def _post_run_plotting(
color=simulation_app.AnsiColors.RED,
)
return
try:
plot_config = config_loader.import_module(
_PLOT_CONFIG_MODULE.value
).PLOT_CONFIG
except (ModuleNotFoundError, AttributeError) as e:
logging.exception(
'Error loading plot config module %s: %s', _PLOT_CONFIG_MODULE.value, e
)
return
match input_text:
case '0':
return plotruns_lib.plot_run(output_files[-1])
return plotruns_lib.plot_run(plot_config, output_files[-1])
case '1':
if len(output_files) == 1:
simulation_app.log_to_stdout(
'Only one output run file found, only plotting the last run.',
color=simulation_app.AnsiColors.RED,
)
return plotruns_lib.plot_run(output_files[-1])
return plotruns_lib.plot_run(output_files[-1], output_files[-2])
return plotruns_lib.plot_run(plot_config, output_files[-1])
return plotruns_lib.plot_run(
plot_config, output_files[-1], output_files[-2]
)
case '2':
reference_run = _REFERENCE_RUN.value
if reference_run is None:
simulation_app.log_to_stdout(
'No reference run provided, only plotting the last run.',
color=simulation_app.AnsiColors.RED,
)
return plotruns_lib.plot_run(output_files[-1], reference_run)
return plotruns_lib.plot_run(plot_config, output_files[-1], reference_run)
case _:
raise ValueError('Unknown command')

Expand Down
61 changes: 61 additions & 0 deletions torax/plotting/configs/default_plot_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Copyright 2024 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Default plotting configuration for Torax runs."""

from torax.plotting import plotruns_lib

PLOT_CONFIG = plotruns_lib.FigureProperties(
rows=2,
cols=3,
axes=(
# For chi, set histogram percentile for y-axis upper limit, due to
# volatile nature of the data. Do not include first timepoint since
# chi is defined as zero there and may unduly affect ylim.
plotruns_lib.PlotProperties(
attrs=('chi_i', 'chi_e'),
labels=(r'$\chi_i$', r'$\chi_e$'),
ylabel=r'Heat conductivity $[m^2/s]$',
upper_percentile=98.0,
include_first_timepoint=False,
ylim_min_zero=False,
),
plotruns_lib.PlotProperties(
attrs=('ti', 'te'),
labels=(r'$T_i$', r'$T_e$'),
ylabel='Temperature [keV]',
),
plotruns_lib.PlotProperties(
attrs=('ne',),
labels=(r'$n_e$',),
ylabel=r'Electron density $[10^{20}~m^{-3}]$',
),
plotruns_lib.PlotProperties(
attrs=('j', 'johm', 'j_bootstrap', 'jext'),
labels=(r'$j_{tot}$', r'$j_{ohm}$', r'$j_{bs}$', r'$j_{ext}$'),
ylabel=r'Toroidal current $[A~m^{-2}]$',
legend_fontsize=8, # Smaller fontsize for this plot
),
plotruns_lib.PlotProperties(
attrs=('q',),
labels=(r'$q$',),
ylabel='Safety factor',
),
plotruns_lib.PlotProperties(
attrs=('s',),
labels=(r'$\hat{s}$',),
ylabel='Magnetic shear',
),
),
)
30 changes: 20 additions & 10 deletions torax/plotting/plotruns.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,11 @@
Includes a time slider. Reads output files with xarray data or legacy h5 data.
Plots:
(1) chi_i, chi_e (transport coefficients)
(2) Ti, Te (temperatures)
(3) ne (density)
(4) jtot, johm (total and ohmic plasma current)
(5) q (safety factor)
(6) s (magnetic shear)
Plots are configured by a plot_config module.
"""
import importlib
from absl import app
from absl import logging
from absl.flags import argparse_flags
import matplotlib
from torax.plotting import plotruns_lib
Expand All @@ -34,6 +30,7 @@


def parse_flags(_):
"""Parse flags for the plotting tool."""
parser = argparse_flags.ArgumentParser(description='Plot finished run')
parser.add_argument(
'--outfile',
Expand All @@ -43,15 +40,28 @@ def parse_flags(_):
' comparison is done)'
),
)
parser.set_defaults(normalized=True)
parser.add_argument(
'--plot_config',
default='torax.plotting.configs.default_plot_config',
help='Name of the plot config module.',
)
return parser.parse_args()


def main(args):
plot_config_module_path = args.plot_config
try:
plot_config_module = importlib.import_module(plot_config_module_path)
plot_config = plot_config_module.PLOT_CONFIG
except (ModuleNotFoundError, AttributeError) as e:
logging.exception(
'Error loading plot config: %s: %s', plot_config_module_path, e
)
raise
if len(args.outfile) == 1:
plotruns_lib.plot_run(args.outfile[0])
plotruns_lib.plot_run(plot_config, args.outfile[0])
else:
plotruns_lib.plot_run(args.outfile[0], args.outfile[1])
plotruns_lib.plot_run(plot_config, args.outfile[0], args.outfile[1])


if __name__ == '__main__':
Expand Down
Loading

0 comments on commit 248b7b6

Please sign in to comment.