Skip to content

Commit

Permalink
Introduce named constants for sources.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 686485722
  • Loading branch information
Nush395 authored and Torax team committed Oct 16, 2024
1 parent 69f0313 commit 1af8cc8
Show file tree
Hide file tree
Showing 14 changed files with 185 additions and 111 deletions.
120 changes: 67 additions & 53 deletions torax/config/tests/runtime_params_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,16 +168,20 @@ def test_source_formula_config_has_time_dependent_params(self):
dcs = runtime_params_slice_lib.DynamicRuntimeParamsSliceProvider(
runtime_params=runtime_params,
sources={
'gas_puff_source': electron_density_sources.GasPuffRuntimeParams(
puff_decay_length={0.0: 0.0, 1.0: 4.0},
S_puff_tot={0.0: 0.0, 1.0: 5.0},
electron_density_sources.GAS_PUFF_SOURCE_NAME: (
electron_density_sources.GasPuffRuntimeParams(
puff_decay_length={0.0: 0.0, 1.0: 4.0},
S_puff_tot={0.0: 0.0, 1.0: 5.0},
)
),
'pellet_source': electron_density_sources.PelletRuntimeParams(
pellet_width={0.0: 0.0, 1.0: 1.0},
pellet_deposition_location={0.0: 0.0, 1.0: 2.0},
S_pellet_tot={0.0: 0.0, 1.0: 3.0},
electron_density_sources.PELLET_SOURCE_NAME: (
electron_density_sources.PelletRuntimeParams(
pellet_width={0.0: 0.0, 1.0: 1.0},
pellet_deposition_location={0.0: 0.0, 1.0: 2.0},
S_pellet_tot={0.0: 0.0, 1.0: 3.0},
)
),
'nbi_particle_source': (
electron_density_sources.GENERIC_PARTICLE_SOURCE_NAME: (
electron_density_sources.NBIParticleRuntimeParams(
nbi_particle_width={0.0: 0.0, 1.0: 6.0},
nbi_deposition_location={0.0: 0.0, 1.0: 7.0},
Expand All @@ -189,9 +193,13 @@ def test_source_formula_config_has_time_dependent_params(self):
)(
t=0.5,
)
pellet_source = dcs.sources['pellet_source']
gas_puff_source = dcs.sources['gas_puff_source']
nbi_particle_source = dcs.sources['nbi_particle_source']
pellet_source = dcs.sources[electron_density_sources.PELLET_SOURCE_NAME]
gas_puff_source = dcs.sources[
electron_density_sources.GAS_PUFF_SOURCE_NAME
]
nbi_particle_source = dcs.sources[
electron_density_sources.GENERIC_PARTICLE_SOURCE_NAME
]
assert isinstance(
pellet_source,
electron_density_sources.DynamicPelletRuntimeParams,
Expand Down Expand Up @@ -222,59 +230,57 @@ def test_source_formula_config_has_time_dependent_params(self):
dcs = runtime_params_slice_lib.DynamicRuntimeParamsSliceProvider(
runtime_params=runtime_params,
sources={
'gas_puff_source': sources_params_lib.RuntimeParams(
formula=formula_config.Exponential(
total={0.0: 0.0, 1.0: 1.0},
c1={0.0: 0.0, 1.0: 2.0},
c2={0.0: 0.0, 1.0: 3.0},
electron_density_sources.GAS_PUFF_SOURCE_NAME: (
sources_params_lib.RuntimeParams(
formula=formula_config.Exponential(
total={0.0: 0.0, 1.0: 1.0},
c1={0.0: 0.0, 1.0: 2.0},
c2={0.0: 0.0, 1.0: 3.0},
)
)
),
},
torax_mesh=self._geo.torax_mesh,
)(
t=0.25,
)
gas_puff_source = dcs.sources['gas_puff_source']
gas_puff_source = dcs.sources[
electron_density_sources.GAS_PUFF_SOURCE_NAME
]
assert isinstance(
gas_puff_source.formula,
formula_config.DynamicExponential,
)
np.testing.assert_allclose(
gas_puff_source.formula.total, 0.25
)
np.testing.assert_allclose(gas_puff_source.formula.total, 0.25)
np.testing.assert_allclose(gas_puff_source.formula.c1, 0.5)
np.testing.assert_allclose(
gas_puff_source.formula.c2, 0.75
)
np.testing.assert_allclose(gas_puff_source.formula.c2, 0.75)

with self.subTest('gaussian_formula'):
runtime_params = general_runtime_params.GeneralRuntimeParams()
dcs = runtime_params_slice_lib.DynamicRuntimeParamsSliceProvider(
runtime_params=runtime_params,
sources={
'gas_puff_source': sources_params_lib.RuntimeParams(
formula=formula_config.Gaussian(
total={0.0: 0.0, 1.0: 1.0},
c1={0.0: 0.0, 1.0: 2.0},
c2={0.0: 0.0, 1.0: 3.0},
electron_density_sources.GAS_PUFF_SOURCE_NAME: (
sources_params_lib.RuntimeParams(
formula=formula_config.Gaussian(
total={0.0: 0.0, 1.0: 1.0},
c1={0.0: 0.0, 1.0: 2.0},
c2={0.0: 0.0, 1.0: 3.0},
)
)
),
},
torax_mesh=self._geo.torax_mesh,
)(
t=0.25,
)
gas_puff_source = dcs.sources['gas_puff_source']
assert isinstance(
gas_puff_source.formula, formula_config.DynamicGaussian
)
np.testing.assert_allclose(
gas_puff_source.formula.total, 0.25
)
gas_puff_source = dcs.sources[
electron_density_sources.GAS_PUFF_SOURCE_NAME
]
assert isinstance(gas_puff_source.formula, formula_config.DynamicGaussian)
np.testing.assert_allclose(gas_puff_source.formula.total, 0.25)
np.testing.assert_allclose(gas_puff_source.formula.c1, 0.5)
np.testing.assert_allclose(
gas_puff_source.formula.c2, 0.75
)
np.testing.assert_allclose(gas_puff_source.formula.c2, 0.75)

def test_wext_in_dynamic_runtime_params_cannot_be_negative(self):
"""Tests that wext cannot be negative."""
Expand All @@ -283,30 +289,34 @@ def test_wext_in_dynamic_runtime_params_cannot_be_negative(self):
runtime_params=runtime_params,
transport=transport_params_lib.RuntimeParams(),
sources={
'jext': external_current_source.RuntimeParams(
wext={0.0: 1.0, 1.0: -1.0}
external_current_source.SOURCE_NAME: (
external_current_source.RuntimeParams(
wext={0.0: 1.0, 1.0: -1.0}
)
),
},
stepper=stepper_params_lib.RuntimeParams(),
torax_mesh=self._geo.torax_mesh,
)
# While wext is positive, this should be fine.
dcs = dcs_provider(t=0.0,)
jext = dcs.sources['jext']
assert isinstance(
jext, external_current_source.DynamicRuntimeParams
dcs = dcs_provider(
t=0.0,
)
jext = dcs.sources[external_current_source.SOURCE_NAME]
assert isinstance(jext, external_current_source.DynamicRuntimeParams)
np.testing.assert_allclose(jext.wext, 1.0)
# Even 0 should be fine.
dcs = dcs_provider(t=0.5,)
jext = dcs.sources['jext']
assert isinstance(
jext, external_current_source.DynamicRuntimeParams
dcs = dcs_provider(
t=0.5,
)
jext = dcs.sources[external_current_source.SOURCE_NAME]
assert isinstance(jext, external_current_source.DynamicRuntimeParams)
np.testing.assert_allclose(jext.wext, 0.0)
# But negative values will cause an error.
with self.assertRaises(RuntimeError):
dcs_provider(t=1.0,)
dcs_provider(
t=1.0,
)

@parameterized.parameters(
(
Expand Down Expand Up @@ -471,7 +481,9 @@ def test_update_dynamic_slice_provider_updates_sources(
"""Tests that the dynamic slice provider can be updated."""
runtime_params = general_runtime_params.GeneralRuntimeParams()
source_models_builder = default_sources.get_default_sources_builder()
source_models_builder.runtime_params['jext'].Iext = 1.0
source_models_builder.runtime_params[
external_current_source.SOURCE_NAME
].Iext = 1.0
geo = geometry.build_circular_geometry(n_rho=4)
provider = runtime_params_slice_lib.DynamicRuntimeParamsSliceProvider(
runtime_params=runtime_params,
Expand All @@ -485,15 +497,17 @@ def test_update_dynamic_slice_provider_updates_sources(
self.assertIn(key, dcs.sources)

# Update an interpolated variable.
source_models_builder.runtime_params['jext'].Iext = 2.0
source_models_builder.runtime_params[
external_current_source.SOURCE_NAME
].Iext = 2.0

# Check pre-update that nothing has changed.
dcs = provider(
t=0.0,
)
for key in source_models_builder.runtime_params.keys():
self.assertIn(key, dcs.sources)
jext_source = dcs.sources['jext']
jext_source = dcs.sources[external_current_source.SOURCE_NAME]
assert isinstance(jext_source, external_current_source.DynamicRuntimeParams)
self.assertEqual(jext_source.Iext, 1.0)

Expand All @@ -508,7 +522,7 @@ def test_update_dynamic_slice_provider_updates_sources(
)
for key in source_models_builder.runtime_params.keys():
self.assertIn(key, dcs.sources)
jext_source = dcs.sources['jext']
jext_source = dcs.sources[external_current_source.SOURCE_NAME]
assert isinstance(jext_source, external_current_source.DynamicRuntimeParams)
self.assertEqual(jext_source.Iext, 2.0)

Expand Down
5 changes: 4 additions & 1 deletion torax/sources/bootstrap_current_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@
from torax.sources import source_profiles


SOURCE_NAME = 'j_bootstrap'


@dataclasses.dataclass(kw_only=True)
class RuntimeParams(runtime_params_lib.RuntimeParams):
"""Configuration parameters for the bootstrap current source."""
Expand Down Expand Up @@ -184,7 +187,7 @@ def get_source_profile_for_affected_core_profile(
) -> jax.Array:
return jnp.where(
affected_core_profile in self.affected_core_profiles_ints,
profile['j_bootstrap'],
profile[SOURCE_NAME],
jnp.zeros_like(geo.rho),
)

Expand Down
3 changes: 3 additions & 0 deletions torax/sources/bremsstrahlung_heat_sink.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@
from torax.sources import source_models


SOURCE_NAME = 'bremsstrahlung_heat_sink'


@dataclasses.dataclass(kw_only=True)
class RuntimeParams(runtime_params_lib.RuntimeParams):
use_relativistic_correction: bool = False
Expand Down
9 changes: 9 additions & 0 deletions torax/sources/electron_density_sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,9 @@ def _calc_puff_source(
)


GAS_PUFF_SOURCE_NAME = 'gas_puff_source'


@dataclasses.dataclass(kw_only=True, frozen=True, eq=True)
class GasPuffSource(source.Source):
"""Gas puff source for the ne equation."""
Expand Down Expand Up @@ -183,6 +186,9 @@ def _calc_nbi_source(
)


GENERIC_PARTICLE_SOURCE_NAME = 'nbi_particle_source'


@dataclasses.dataclass(kw_only=True, frozen=True, eq=True)
class NBIParticleSource(source.Source):
"""Neutral-beam injection source for the ne equation."""
Expand Down Expand Up @@ -257,6 +263,9 @@ def _calc_pellet_source(
)


PELLET_SOURCE_NAME = 'pellet_source'


@dataclasses.dataclass(kw_only=True, frozen=True, eq=True)
class PelletSource(source.Source):
"""Pellet source for the ne equation."""
Expand Down
1 change: 1 addition & 0 deletions torax/sources/external_current_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from torax.sources import source


SOURCE_NAME = 'jext'
# pylint: disable=invalid-name


Expand Down
3 changes: 3 additions & 0 deletions torax/sources/fusion_heat_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@
from torax.sources import source


SOURCE_NAME = 'fusion_heat_source'


def calc_fusion(
geo: geometry.Geometry,
core_profiles: state.CoreProfiles,
Expand Down
1 change: 1 addition & 0 deletions torax/sources/generic_ion_el_heat_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from torax.sources import source


SOURCE_NAME = 'generic_ion_el_heat_source'
# Many variables throughout this function are capitalized based on physics
# notational conventions rather than on Google Python style
# pylint: disable=invalid-name
Expand Down
3 changes: 3 additions & 0 deletions torax/sources/ohmic_heat_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@
from torax.sources import source_models as source_models_lib


SOURCE_NAME = 'ohmic_heat_source'


@functools.partial(
jax_utils.jit,
static_argnames=[
Expand Down
1 change: 1 addition & 0 deletions torax/sources/qei_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from torax.sources import source_profiles


SOURCE_NAME = 'qei_source'
# pylint: disable=invalid-name


Expand Down
20 changes: 10 additions & 10 deletions torax/sources/register_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,53 +109,53 @@ def get_registered_source(source_name: str) -> RegisteredSource:
def register_torax_sources():
"""Register a set of sources commonly used in TORAX."""
register_new_source(
'j_bootstrap',
bootstrap_current_source.SOURCE_NAME,
source_class=bootstrap_current_source.BootstrapCurrentSource,
default_runtime_params_class=bootstrap_current_source.RuntimeParams,
)
register_new_source(
'jext',
external_current_source.SOURCE_NAME,
external_current_source.ExternalCurrentSource,
default_runtime_params_class=external_current_source.RuntimeParams,
)
register_new_source(
'nbi_particle_source',
electron_density_sources.GENERIC_PARTICLE_SOURCE_NAME,
electron_density_sources.NBIParticleSource,
default_runtime_params_class=electron_density_sources.NBIParticleRuntimeParams,
)
register_new_source(
'gas_puff_source',
electron_density_sources.GAS_PUFF_SOURCE_NAME,
electron_density_sources.GasPuffSource,
default_runtime_params_class=electron_density_sources.GasPuffRuntimeParams,
)
register_new_source(
'pellet_source',
electron_density_sources.PELLET_SOURCE_NAME,
electron_density_sources.PelletSource,
default_runtime_params_class=electron_density_sources.PelletRuntimeParams,
)
register_new_source(
'generic_ion_el_heat_source',
ion_el_heat.SOURCE_NAME,
ion_el_heat.GenericIonElectronHeatSource,
default_runtime_params_class=ion_el_heat.RuntimeParams,
)
register_new_source(
'fusion_heat_source',
fusion_heat_source.SOURCE_NAME,
fusion_heat_source.FusionHeatSource,
default_runtime_params_class=fusion_heat_source.FusionHeatSourceRuntimeParams
)
register_new_source(
'qei_source',
qei_source.SOURCE_NAME,
qei_source.QeiSource,
default_runtime_params_class=qei_source.RuntimeParams,
)
register_new_source(
'ohmic_heat_source',
ohmic_heat_source.SOURCE_NAME,
ohmic_heat_source.OhmicHeatSource,
default_runtime_params_class=ohmic_heat_source.OhmicRuntimeParams,
links_back=True,
)
register_new_source(
'bremsstrahlung_heat_sink',
bremsstrahlung_heat_sink.SOURCE_NAME,
bremsstrahlung_heat_sink.BremsstrahlungHeatSink,
default_runtime_params_class=bremsstrahlung_heat_sink.RuntimeParams,
)
Expand Down
Loading

0 comments on commit 1af8cc8

Please sign in to comment.