From 33b8c2ff615c81034d89220d485a8a16f42a336f Mon Sep 17 00:00:00 2001 From: Anushan Fernando Date: Wed, 9 Oct 2024 05:20:05 -0700 Subject: [PATCH] Move Ohmic heat source into a separate module. PiperOrigin-RevId: 683996904 --- torax/core_profile_setters.py | 3 +- torax/sim.py | 3 +- torax/sources/default_sources.py | 7 +- torax/sources/ohmic_heat_source.py | 219 +++++++++++++++++++++++++++++ torax/sources/source_models.py | 188 ------------------------- 5 files changed, 227 insertions(+), 193 deletions(-) create mode 100644 torax/sources/ohmic_heat_source.py diff --git a/torax/core_profile_setters.py b/torax/core_profile_setters.py index 4fe98852..9795b775 100644 --- a/torax/core_profile_setters.py +++ b/torax/core_profile_setters.py @@ -31,6 +31,7 @@ from torax.fvm import cell_variable from torax.geometry import Geometry # pylint: disable=g-importing-member from torax.sources import external_current_source +from torax.sources import ohmic_heat_source from torax.sources import source_models as source_models_lib from torax.sources import source_profiles as source_profiles_lib @@ -781,7 +782,7 @@ def initial_core_profiles( # phibdot calculation. psidot = dataclasses.replace( psidot, - value=source_models_lib.calc_psidot( + value=ohmic_heat_source.calc_psidot( dynamic_runtime_params_slice, geo, core_profiles, diff --git a/torax/sim.py b/torax/sim.py index 357a3c1e..278d3c34 100644 --- a/torax/sim.py +++ b/torax/sim.py @@ -48,6 +48,7 @@ from torax.config import runtime_params as general_runtime_params from torax.config import runtime_params_slice from torax.fvm import cell_variable +from torax.sources import ohmic_heat_source from torax.sources import source_models as source_models_lib from torax.sources import source_profiles as source_profiles_lib from torax.spectators import spectator as spectator_lib @@ -1487,7 +1488,7 @@ def update_psidot( psidot = dataclasses.replace( core_profiles.psidot, - value=source_models_lib.calc_psidot( + value=ohmic_heat_source.calc_psidot( dynamic_runtime_params_slice, geo, core_profiles, diff --git a/torax/sources/default_sources.py b/torax/sources/default_sources.py index 880f94b9..1ba50698 100644 --- a/torax/sources/default_sources.py +++ b/torax/sources/default_sources.py @@ -28,6 +28,7 @@ from torax.sources import external_current_source from torax.sources import fusion_heat_source from torax.sources import generic_ion_el_heat_source as ion_el_heat +from torax.sources import ohmic_heat_source from torax.sources import qei_source from torax.sources import runtime_params as runtime_params_lib from torax.sources import source @@ -56,7 +57,7 @@ def get_default_runtime_params( case 'qei_source': return qei_source.RuntimeParams() case 'ohmic_heat_source': - return source_models_lib.OhmicRuntimeParams() + return ohmic_heat_source.OhmicRuntimeParams() case 'bremsstrahlung_heat_sink': return bremsstrahlung_heat_sink.RuntimeParams() case _: @@ -83,7 +84,7 @@ def get_source_type(source_name: str) -> type[source.Source]: case 'qei_source': return qei_source.QeiSource case 'ohmic_heat_source': - return source_models_lib.OhmicHeatSource + return ohmic_heat_source.OhmicHeatSource case 'bremsstrahlung_heat_sink': return bremsstrahlung_heat_sink.BremsstrahlungHeatSink case _: @@ -173,7 +174,7 @@ def get_source_builder_type(source_name: str) -> Any: return qei_source.QeiSourceBuilder case 'ohmic_heat_source': - return source_models_lib.OhmicHeatSourceBuilder + return ohmic_heat_source.OhmicHeatSourceBuilder case 'bremsstrahlung_heat_sink': return bremsstrahlung_heat_sink.BremsstrahlungHeatSinkBuilder diff --git a/torax/sources/ohmic_heat_source.py b/torax/sources/ohmic_heat_source.py new file mode 100644 index 00000000..da469ab4 --- /dev/null +++ b/torax/sources/ohmic_heat_source.py @@ -0,0 +1,219 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Ohmic heat source.""" + +from __future__ import annotations +import dataclasses +import functools +import jax +import jax.numpy as jnp +from torax import constants +from torax import geometry +from torax import jax_utils +from torax import physics +from torax import state +from torax.config import runtime_params_slice +from torax.fvm import convection_terms +from torax.fvm import diffusion_terms +from torax.sources import runtime_params as runtime_params_lib +from torax.sources import source as source_lib +from torax.sources import source_models as source_models_lib + + +@functools.partial( + jax_utils.jit, + static_argnames=[ + 'source_models', + ], +) +def calc_psidot( + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, + geo: geometry.Geometry, + core_profiles: state.CoreProfiles, + source_models: source_models_lib.SourceModels, +) -> jax.Array: + r"""Calculates psidot (loop voltage). Used for the Ohmic electron heat source. + + psidot is an interesting TORAX output, and is thus also saved in + core_profiles. + + psidot = \partial psi / \partial t, and is derived from the same components + that form the psi block in the coupled PDE equations. Thus, a similar + (but abridged) formulation as in sim.calc_coeffs and fvm._calc_c is used here + + Args: + dynamic_runtime_params_slice: Simulation configuration at this timestep + geo: Torus geometry + core_profiles: Core plasma profiles. + source_models: All TORAX source/sinks. + + Returns: + psidot: on cell grid + """ + consts = constants.CONSTANTS + + psi_sources, sigma, sigma_face = source_models_lib.calc_and_sum_sources_psi( + dynamic_runtime_params_slice, + geo, + core_profiles, + source_models, + ) + # Calculate transient term + toc_psi = ( + 1.0 + / dynamic_runtime_params_slice.numerics.resistivity_mult + * geo.rho_norm + * sigma + * consts.mu0 + * 16 + * jnp.pi**2 + * geo.Phib**2 + / geo.F**2 + ) + # Calculate diffusion term coefficient + d_face_psi = geo.g2g3_over_rhon_face + # Add phibdot terms to poloidal flux convection + v_face_psi = ( + -8.0 + * jnp.pi**2 + * consts.mu0 + * geo.Phibdot + * geo.Phib + * sigma_face + * geo.rho_face_norm**2 + / geo.F_face**2 + ) + + # Add effective phibdot poloidal flux source term + ddrnorm_sigma_rnorm2_over_f2 = jnp.gradient( + sigma * geo.rho_norm**2 / geo.F**2, geo.rho_norm + ) + + psi_sources += ( + -8.0 + * jnp.pi**2 + * consts.mu0 + * geo.Phibdot + * geo.Phib + * ddrnorm_sigma_rnorm2_over_f2 + ) + + diffusion_mat, diffusion_vec = diffusion_terms.make_diffusion_terms( + d_face_psi, core_profiles.psi + ) + + # Set the psi convection term for psidot used in ohmic power, always with + # the default 'ghost' mode. Impact of different modes would mildly impact + # Ohmic power at the LCFS which has negligible impact on simulations. + # Allowing it to be configurable introduces more complexity in the code by + # needing to pass in the mode from the static_runtime_params across multiple + # functions. + conv_mat, conv_vec = convection_terms.make_convection_terms( + v_face_psi, + d_face_psi, + core_profiles.psi, + ) + + c_mat = diffusion_mat + conv_mat + c = diffusion_vec + conv_vec + + c += psi_sources + + psidot = (jnp.dot(c_mat, core_profiles.psi.value) + c) / toc_psi + + return psidot + + +def ohmic_model_func( + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, + dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams, + geo: geometry.Geometry, + core_profiles: state.CoreProfiles, + source_models: source_models_lib.SourceModels | None = None, +) -> jax.Array: + """Returns the Ohmic source for electron heat equation.""" + del dynamic_source_runtime_params + + if source_models is None: + raise TypeError('source_models is a required argument for ohmic_model_func') + + jtot, _ = physics.calc_jtot_from_psi( + geo, + core_profiles.psi, + ) + + psidot = calc_psidot( + dynamic_runtime_params_slice, + geo, + core_profiles, + source_models, + ) + + pohm = jtot * psidot / (2 * jnp.pi * geo.Rmaj) + return pohm + + +@dataclasses.dataclass +class OhmicRuntimeParams(runtime_params_lib.RuntimeParams): + """Runtime params for OhmicHeatSource.""" + + mode: runtime_params_lib.Mode = runtime_params_lib.Mode.MODEL_BASED + + +# OhmicHeatSource is a special case and defined here to avoid circular +# dependencies, since it depends on the psi sources +@dataclasses.dataclass(kw_only=True, frozen=True, eq=True) +class OhmicHeatSource(source_lib.Source): + """Ohmic heat source for electron heat equation. + + Pohm = jtor * psidot /(2*pi*Rmaj), related to electric power formula P = IV. + """ + + # output_shape_getter is removed from __init__ as it is fixed to this value. + output_shape_getter: source_lib.SourceOutputShapeFunction = dataclasses.field( + init=False, + default_factory=lambda: source_lib.get_cell_profile_shape, + ) + # Users must pass in a pointer to the complete set of sources to this object. + source_models: source_models_lib.SourceModels + + supported_modes: tuple[runtime_params_lib.Mode, ...] = ( + runtime_params_lib.Mode.ZERO, + runtime_params_lib.Mode.MODEL_BASED, + ) + + # Freeze these params and do not include them in the __init__. + affected_core_profiles: tuple[source_lib.AffectedCoreProfile, ...] = ( + dataclasses.field( + init=False, + default=(source_lib.AffectedCoreProfile.TEMP_EL,), + ) + ) + + # The model function is fixed to ohmic_model_func because that is the only + # supported implementation of this source. + # However, since this is a param in the parent dataclass, we need to (a) + # remove the parameter from the init args and (b) set the default to the + # desired value. + model_func: source_lib.SourceProfileFunction | None = dataclasses.field( + init=False, + default_factory=lambda: ohmic_model_func, + ) + + +OhmicHeatSourceBuilder = source_lib.make_source_builder( + OhmicHeatSource, + links_back=True, + runtime_params_type=OhmicRuntimeParams, +) diff --git a/torax/sources/source_models.py b/torax/sources/source_models.py index edc724cf..fc5ae3c8 100644 --- a/torax/sources/source_models.py +++ b/torax/sources/source_models.py @@ -16,7 +16,6 @@ from __future__ import annotations -import dataclasses import functools import jax @@ -24,11 +23,8 @@ from torax import constants from torax import geometry from torax import jax_utils -from torax import physics from torax import state from torax.config import runtime_params_slice -from torax.fvm import convection_terms -from torax.fvm import diffusion_terms from torax.sources import bootstrap_current_source from torax.sources import external_current_source from torax.sources import qei_source as qei_source_lib @@ -494,190 +490,6 @@ def calc_and_sum_sources_psi( ) -@functools.partial( - jax_utils.jit, - static_argnames=[ - 'source_models', - ], -) -def calc_psidot( - dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, - geo: geometry.Geometry, - core_profiles: state.CoreProfiles, - source_models: SourceModels, -) -> jax.Array: - r"""Calculates psidot (loop voltage). Used for the Ohmic electron heat source. - - psidot is an interesting TORAX output, and is thus also saved in - core_profiles. - - psidot = \partial psi / \partial t, and is derived from the same components - that form the psi block in the coupled PDE equations. Thus, a similar - (but abridged) formulation as in sim.calc_coeffs and fvm._calc_c is used here - - Args: - dynamic_runtime_params_slice: Simulation configuration at this timestep - geo: Torus geometry - core_profiles: Core plasma profiles. - source_models: All TORAX source/sinks. - - Returns: - psidot: on cell grid - """ - consts = constants.CONSTANTS - - psi_sources, sigma, sigma_face = calc_and_sum_sources_psi( - dynamic_runtime_params_slice, - geo, - core_profiles, - source_models, - ) - # Calculate transient term - toc_psi = ( - 1.0 - / dynamic_runtime_params_slice.numerics.resistivity_mult - * geo.rho_norm - * sigma - * consts.mu0 - * 16 - * jnp.pi**2 - * geo.Phib**2 - / geo.F**2 - ) - # Calculate diffusion term coefficient - d_face_psi = geo.g2g3_over_rhon_face - # Add phibdot terms to poloidal flux convection - v_face_psi = ( - -8.0 - * jnp.pi**2 - * consts.mu0 - * geo.Phibdot - * geo.Phib - * sigma_face - * geo.rho_face_norm**2 - / geo.F_face**2 - ) - - # Add effective phibdot poloidal flux source term - ddrnorm_sigma_rnorm2_over_f2 = jnp.gradient( - sigma * geo.rho_norm**2 / geo.F**2, geo.rho_norm - ) - - psi_sources += ( - -8.0 - * jnp.pi**2 - * consts.mu0 - * geo.Phibdot - * geo.Phib - * ddrnorm_sigma_rnorm2_over_f2 - ) - - diffusion_mat, diffusion_vec = diffusion_terms.make_diffusion_terms( - d_face_psi, core_profiles.psi - ) - - # Set the psi convection term for psidot used in ohmic power, always with - # the default 'ghost' mode. Impact of different modes would mildly impact - # Ohmic power at the LCFS which has negligible impact on simulations. - # Allowing it to be configurable introduces more complexity in the code by - # needing to pass in the mode from the static_runtime_params across multiple - # functions. - conv_mat, conv_vec = convection_terms.make_convection_terms( - v_face_psi, - d_face_psi, - core_profiles.psi, - ) - - c_mat = diffusion_mat + conv_mat - c = diffusion_vec + conv_vec - - c += psi_sources - - psidot = (jnp.dot(c_mat, core_profiles.psi.value) + c) / toc_psi - - return psidot - - -def ohmic_model_func( - dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, - dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams, - geo: geometry.Geometry, - core_profiles: state.CoreProfiles, - source_models: SourceModels | None = None, -) -> jax.Array: - """Returns the Ohmic source for electron heat equation.""" - del dynamic_source_runtime_params - - if source_models is None: - raise TypeError('source_models is a required argument for ohmic_model_func') - - jtot, _ = physics.calc_jtot_from_psi( - geo, - core_profiles.psi, - ) - - psidot = calc_psidot( - dynamic_runtime_params_slice, - geo, - core_profiles, - source_models, - ) - - pohm = jtot * psidot / (2 * jnp.pi * geo.Rmaj) - return pohm - - -# OhmicHeatSource is a special case and defined here to avoid circular -# dependencies, since it depends on the psi sources -@dataclasses.dataclass(kw_only=True, frozen=True, eq=True) -class OhmicHeatSource(source_lib.Source): - """Ohmic heat source for electron heat equation. - - Pohm = jtor * psidot /(2*pi*Rmaj), related to electric power formula P = IV. - """ - # output_shape_getter is removed from __init__ as it is fixed to this value. - output_shape_getter: source_lib.SourceOutputShapeFunction = dataclasses.field( - init=False, - default_factory=lambda: source_lib.get_cell_profile_shape, - ) - # Users must pass in a pointer to the complete set of sources to this object. - source_models: SourceModels - - supported_modes: tuple[runtime_params_lib.Mode, ...] = ( - runtime_params_lib.Mode.ZERO, - runtime_params_lib.Mode.MODEL_BASED, - ) - - # Freeze these params and do not include them in the __init__. - affected_core_profiles: tuple[source_lib.AffectedCoreProfile, ...] = ( - dataclasses.field( - init=False, - default=(source_lib.AffectedCoreProfile.TEMP_EL,), - ) - ) - - # The model function is fixed to ohmic_model_func because that is the only - # supported implementation of this source. - # However, since this is a param in the parent dataclass, we need to (a) - # remove the parameter from the init args and (b) set the default to the - # desired value. - model_func: source_lib.SourceProfileFunction | None = dataclasses.field( - init=False, - default_factory=lambda: ohmic_model_func, - ) - - -@dataclasses.dataclass -class OhmicRuntimeParams(runtime_params_lib.RuntimeParams): - """Runtime params for OhmicHeatSource.""" - mode: runtime_params_lib.Mode = runtime_params_lib.Mode.MODEL_BASED - - -OhmicHeatSourceBuilder = source_lib.make_source_builder( - OhmicHeatSource, links_back=True, runtime_params_type=OhmicRuntimeParams, -) - - class SourceModels: """Source/sink models for the different equations being evolved in Torax.