Skip to content

Commit

Permalink
Merge plotting tool and h5 output improvements
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 615800037
  • Loading branch information
Torax team committed Mar 14, 2024
1 parent 0954f55 commit 5634d3c
Show file tree
Hide file tree
Showing 10 changed files with 156 additions and 65 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def _parse_requirements(path):
version=_get_version(),
license='Apache 2.0',
author='Google DeepMind',
description='TORAX', # TODO(b/323504363): Add desc.
description='Differentiable 1D tokamak plasma transport simulator in JAX',
long_description=open(
os.path.join(_CURRENT_DIR, 'README.md')
).read(),
Expand Down
1 change: 1 addition & 0 deletions torax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from torax.sim import build_sim_from_config
from torax.sim import run_simulation
from torax.sim import Sim
from torax.state import AuxOutput
from torax.state import State
from torax.stepper.stepper import Stepper
from torax.time_step_calculator.chi_time_step_calculator import ChiTimeStepCalculator
Expand Down
154 changes: 107 additions & 47 deletions torax/plotruns.py → torax/plotting/plotruns.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,12 @@ class PlotData:
ne: np.ndarray
j: np.ndarray
johm: np.ndarray
j_bootstrap: np.ndarray
jext: np.ndarray
q: np.ndarray
s: np.ndarray
chi_i: np.ndarray
chi_e: np.ndarray
t: np.ndarray
rcell_coord: np.ndarray
rface_coord: np.ndarray
Expand All @@ -62,6 +66,9 @@ def __post_init__(self):
self.ymax_q = np.amax(self.q)
self.ymax_s = np.amax(self.s)
self.ymin_s = np.amin(self.s)
# avoid initial condition for chi ymax, since can be unphysically high
self.ymax_chi_i = np.amax(self.chi_i[1:, :])
self.ymax_chi_e = np.amax(self.chi_e[1:, :])
self.dt = min(np.diff(self.t))


Expand All @@ -87,30 +94,38 @@ def __post_init__(self):
if not args.outfile:
raise ValueError('No output file provided')

with h5py.File(args.outfile[0] + 'state_history.h5', 'r') as hf:
plotdata = PlotData(
with h5py.File(args.outfile[0] + '.h5', 'r') as hf:
plotdata1 = PlotData(
ti=hf['temp_ion'][:],
te=hf['temp_el'][:],
ne=hf['ne'][:],
j=hf['jtot'][:],
johm=hf['johm'][:],
j_bootstrap=hf['j_bootstrap'][:],
jext=hf['jext'][:],
q=hf['q_face'][:],
s=hf['s_face'][:],
chi_i=hf['chi_face_ion'][:],
chi_e=hf['chi_face_el'][:],
t=hf['t'][:],
rcell_coord=hf['r_cell_norm'][:],
rface_coord=hf['r_face_norm'][:],
)

if comp_plot:
with h5py.File(args.outfile[1] + 'state_history.h5', 'r') as hf:
with h5py.File(args.outfile[1] + '.h5', 'r') as hf:
plotdata2 = PlotData(
ti=hf['temp_ion'][:],
te=hf['temp_el'][:],
ne=hf['ne'][:],
j=hf['jtot'][:],
johm=hf['johm'][:],
j_bootstrap=hf['j_bootstrap'][:],
jext=hf['jext'][:],
q=hf['q_face'][:],
s=hf['s_face'][:],
chi_i=hf['chi_face_ion'][:],
chi_e=hf['chi_face_el'][:],
t=hf['t'][:],
rcell_coord=hf['r_cell_norm'][:],
rface_coord=hf['r_face_norm'][:],
Expand All @@ -124,46 +139,72 @@ def __post_init__(self):
ax5 = fig.add_subplot(235)
ax6 = fig.add_subplot(236)

lines = []
lines1 = []
lines2 = []

# TODO(b/323504363): improve efficiency through use of line plotting functions
if comp_plot:
ax1.set_title('(1)=' + args.outfile[0] + ', (2)=' + args.outfile[1])
ax2.set_title('(1)=' + args.outfile[0] + ', (2)=' + args.outfile[1])
else:
ax1.set_title('(1)=' + args.outfile[0])
ax2.set_title('(1)=' + args.outfile[0])
(line,) = ax1.plot(
plotdata1.rface_coord, plotdata1.chi_i[1, :], 'r', label=r'$\chi_i~(1)$'
)
lines1.append(line)
(line,) = ax1.plot(
plotdata1.rface_coord, plotdata1.chi_e[1, :], 'b', label=r'$\chi_e~(1)$'
)
lines1.append(line)
(line,) = ax2.plot(
plotdata.rcell_coord, plotdata.ti[0, :], 'r', label=r'$T_i~(1)$'
plotdata1.rcell_coord, plotdata1.ti[0, :], 'r', label=r'$T_i~(1)$'
)
lines.append(line)
lines1.append(line)
(line,) = ax2.plot(
plotdata.rcell_coord, plotdata.te[0, :], 'b', label=r'$T_e~(1)$'
plotdata1.rcell_coord, plotdata1.te[0, :], 'b', label=r'$T_e~(1)$'
)
lines.append(line)
lines1.append(line)
(line,) = ax3.plot(
plotdata.rcell_coord, plotdata.ne[0, :], 'r', label=r'$n_e~(1)$'
plotdata1.rcell_coord, plotdata1.ne[0, :], 'r', label=r'$n_e~(1)$'
)
lines.append(line)
lines1.append(line)

(line,) = ax4.plot(
plotdata.rcell_coord, plotdata.j[0, :], 'r', label=r'$j_{tot}~(1)$'
plotdata1.rcell_coord, plotdata1.j[0, :], 'r', label=r'$j_{tot}~(1)$'
)
lines1.append(line)
(line,) = ax4.plot(
plotdata1.rcell_coord, plotdata1.johm[0, :], 'b', label=r'$j_{ohm}~(1)$'
)
lines.append(line)
lines1.append(line)
(line,) = ax4.plot(
plotdata.rcell_coord, plotdata.johm[0, :], 'b', label=r'$j_{ohm}~(1)$'
plotdata1.rcell_coord,
plotdata1.j_bootstrap[0, :],
'g',
label=r'$j_{bs}~(1)$',
)
lines.append(line)
lines1.append(line)
(line,) = ax4.plot(
plotdata1.rcell_coord, plotdata1.jext[0, :], 'm', label=r'$j_{ext}~(1)$'
)
lines1.append(line)
(line,) = ax5.plot(
plotdata.rface_coord, plotdata.q[0, :], 'r', label=r'$q~(1)$'
plotdata1.rface_coord, plotdata1.q[0, :], 'r', label=r'$q~(1)$'
)
lines.append(line)
lines1.append(line)
(line,) = ax6.plot(
plotdata.rface_coord, plotdata.s[0, :], 'r', label=r'$\hat{s}~(1)$'
plotdata1.rface_coord, plotdata1.s[0, :], 'r', label=r'$\hat{s}~(1)$'
)
lines.append(line)
lines1.append(line)

# pylint: disable=undefined-variable
if comp_plot:
(line,) = ax1.plot(
plotdata2.rface_coord, plotdata2.chi_i[1, :], 'r', label=r'$\chi_i~(2)$'
)
lines2.append(line)
(line,) = ax1.plot(
plotdata2.rface_coord, plotdata2.chi_e[1, :], 'b', label=r'$\chi_e~(2)$'
)
lines2.append(line)
(line,) = ax2.plot(
plotdata2.rcell_coord, plotdata2.ti[0, :], 'r--', label=r'$T_i (2)$'
)
Expand All @@ -184,6 +225,17 @@ def __post_init__(self):
plotdata2.rcell_coord, plotdata2.johm[0, :], 'b--', label=r'$j_{ohm} (2)$'
)
lines2.append(line)
(line,) = ax4.plot(
plotdata2.rcell_coord,
plotdata2.j_bootstrap[0, :],
'g',
label=r'$j_{bs}~(2)$',
)
lines2.append(line)
(line,) = ax4.plot(
plotdata2.rcell_coord, plotdata2.jext[0, :], 'm', label=r'$j_{ext}~(2)$'
)
lines2.append(line)
(line,) = ax5.plot(
plotdata2.rface_coord, plotdata2.q[0, :], 'r--', label=r'$q (2)$'
)
Expand All @@ -194,33 +246,33 @@ def __post_init__(self):
lines2.append(line)
# pylint: enable=undefined-variable

# TODO(b/323504363): add heat conductivity to output h5
# ax1.set_ylim([0, np.max([plotdata1.ymax_chi_i, plotdata1.ymax_chi_e]) * 1.05])
ax1.set_xlabel('Normalized radius')
ax1.set_ylabel(r'Heat conductivity $[m^2/s]$')
ax1.legend()

ax2.set_ylim([0, plotdata.ymax_t * 1.05])
ax2.set_ylim([0, plotdata1.ymax_t * 1.05])
ax2.set_xlabel('Normalized radius')
ax2.set_ylabel('Temperature [keV]')
ax2.legend()

ax3.set_ylim([0, plotdata.ymax_n * 1.05])
ax3.set_ylim([0, plotdata1.ymax_n * 1.05])
ax3.set_xlabel('Normalized radius')
ax3.set_ylabel(r'Electron density $[10^{20}~m^{-3}]$')
ax3.legend()


ax4.set_ylim([min(plotdata.ymin_j * 1.05, 0), plotdata.ymax_j * 1.05])
ax4.set_ylim([min(plotdata1.ymin_j * 1.05, 0), plotdata1.ymax_j * 1.05])
ax4.set_xlabel('Normalized radius')
ax4.set_ylabel(r'Toroidal current $[A~m^{-2}]$')
ax4.legend()
ax4.legend(fontsize=10)

ax5.set_ylim([0, plotdata.ymax_q * 1.05])
ax5.set_ylim([0, plotdata1.ymax_q * 1.05])
ax5.set_xlabel('Normalized radius')
ax5.set_ylabel('Safety factor')
ax5.legend()

ax6.set_ylim([min(plotdata.ymin_s * 1.05, 0), plotdata.ymax_s * 1.05])
ax6.set_ylim([min(plotdata1.ymin_s * 1.05, 0), plotdata1.ymax_s * 1.05])
ax6.set_xlabel('Normalized radius')
ax6.set_ylabel('Magnetic shear')
ax6.legend()
Expand All @@ -231,17 +283,17 @@ def __post_init__(self):

# pylint: disable=undefined-variable
if comp_plot:
dt = min(plotdata.dt, plotdata2.dt)
dt = min(plotdata1.dt, plotdata2.dt)
else:
dt = plotdata.dt
dt = plotdata1.dt
# pylint: enable=undefined-variable

timeslider = Slider(
axslide,
'Time [s]',
plotdata.tmin,
plotdata.tmax,
valinit=plotdata.tmin,
plotdata1.tmin,
plotdata1.tmax,
valinit=plotdata1.tmin,
valstep=dt,
)

Expand All @@ -250,38 +302,46 @@ def __post_init__(self):

def update(newtime):
"""Update plots with new values following slider manipulation."""
idx = np.abs(plotdata.t - newtime).argmin() # find index closest to new time
datalist = [
plotdata.ti[idx, :],
plotdata.te[idx, :],
plotdata.ne[idx, :],
plotdata.j[idx, :],
plotdata.johm[idx, :],
plotdata.q[idx, :],
plotdata.s[idx, :],
idx = np.abs(plotdata1.t - newtime).argmin() # find index closest to new time
datalist1 = [
plotdata1.chi_i[idx, :],
plotdata1.chi_e[idx, :],
plotdata1.ti[idx, :],
plotdata1.te[idx, :],
plotdata1.ne[idx, :],
plotdata1.j[idx, :],
plotdata1.johm[idx, :],
plotdata1.j_bootstrap[idx, :],
plotdata1.jext[idx, :],
plotdata1.q[idx, :],
plotdata1.s[idx, :],
]
for plotline1, data in zip(lines, datalist):
plotline1.set_ydata(data)
for plotline1, data1 in zip(lines1, datalist1):
plotline1.set_ydata(data1)
if comp_plot:
idx = np.abs(
plotdata2.t - newtime
).argmin() # find index closest to new time
datalist2 = [
plotdata2.chi_i[idx, :],
plotdata2.chi_e[idx, :],
plotdata2.ti[idx, :],
plotdata2.te[idx, :],
plotdata2.ne[idx, :],
plotdata2.j[idx, :],
plotdata2.johm[idx, :],
plotdata2.j_bootstrap[idx, :],
plotdata2.jext[idx, :],
plotdata2.q[idx, :],
plotdata2.s[idx, :],
]
for plotline2, data in zip(lines2, datalist2):
plotline2.set_ydata(data)
for plotline2, data2 in zip(lines2, datalist2):
plotline2.set_ydata(data2)
fig.canvas.draw()


# Call update function when slider value is changed
timeslider.on_changed(update)
plt.show()

# fig.tight_layout()
fig.tight_layout()
23 changes: 18 additions & 5 deletions torax/sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,9 @@ def __init__(
stepper, static_argnames=['static_config_slice']
)
self._time_step_calculator = time_step_calculator
self._transport_model = jax_utils.jit(transport_model.__call__,)
self._transport_model = jax_utils.jit(
transport_model.__call__,
)

@property
def stepper(self) -> stepper_lib.Stepper:
Expand Down Expand Up @@ -373,16 +375,16 @@ def body_fun(

output_state = jax_utils.py_while(cond_fun, body_fun, output_state)

# Update current, q, and s profiles based on new psi
# Update total current, q, and s profiles based on new psi
output_state.state.mesh_state = physics.update_jtot_q_face_s_face(
geo=geo,
state=output_state.state.mesh_state,
Rmaj=dynamic_config_slice.Rmaj,
q_correction_factor=dynamic_config_slice.q_correction_factor,
)

# Update bootstrap current based on new state
output_state.state.mesh_state = update_j_bootstrap(
# Update ohmic and bootstrap current based on new state
output_state.state.mesh_state = update_current_distribution(
sources=self._stepper.sources,
dynamic_config_slice=dynamic_config_slice,
geo=geo,
Expand Down Expand Up @@ -915,7 +917,7 @@ def _update_spectator(
spectator.observe(key='Qei', data=output_state.aux.Qei)


def update_j_bootstrap(
def update_current_distribution(
sources: source_profiles_lib.Sources,
dynamic_config_slice: config_slice.DynamicConfigSlice,
geo: geometry.Geometry,
Expand All @@ -929,11 +931,22 @@ def update_j_bootstrap(
state=state,
)

johm = (
state.currents.jtot - bootstrap_profile.j_bootstrap - state.currents.jext
)
johm_face = (
state.currents.jtot_face
- bootstrap_profile.j_bootstrap_face
- state.currents.jext_face
)

currents = dataclasses.replace(
state.currents,
j_bootstrap=bootstrap_profile.j_bootstrap,
j_bootstrap_face=bootstrap_profile.j_bootstrap_face,
I_bootstrap=bootstrap_profile.I_bootstrap,
johm=johm,
johm_face=johm_face,
)
new_state = dataclasses.replace(
state,
Expand Down
Loading

0 comments on commit 5634d3c

Please sign in to comment.