From 777a69fe75dac10aaa9cff15b5fd0125f4d37765 Mon Sep 17 00:00:00 2001 From: Anushan Fernando Date: Wed, 2 Oct 2024 09:45:21 -0700 Subject: [PATCH] Remove SingleProfileSource. PiperOrigin-RevId: 681488167 --- torax/sources/bremsstrahlung_heat_sink.py | 2 +- torax/sources/current_density_sources.py | 8 +- torax/sources/electron_density_sources.py | 28 +++- torax/sources/ion_el_heat_sources.py | 18 +-- torax/sources/source.py | 169 +++------------------- torax/sources/source_models.py | 8 +- torax/sources/tests/formulas.py | 2 +- torax/sources/tests/source.py | 77 +++------- torax/sources/tests/source_models.py | 18 ++- torax/sources/tests/test_lib.py | 4 +- torax/tests/sim_custom_sources.py | 2 +- torax/tests/sim_output_source_profiles.py | 4 +- 12 files changed, 100 insertions(+), 240 deletions(-) diff --git a/torax/sources/bremsstrahlung_heat_sink.py b/torax/sources/bremsstrahlung_heat_sink.py index 2c6e86a6..172496f2 100644 --- a/torax/sources/bremsstrahlung_heat_sink.py +++ b/torax/sources/bremsstrahlung_heat_sink.py @@ -139,7 +139,7 @@ def bremsstrahlung_model_func( @dataclasses.dataclass(kw_only=True, frozen=True, eq=True) -class BremsstrahlungHeatSink(source.SingleProfileSource): +class BremsstrahlungHeatSink(source.Source): """Fusion heat source for both ion and electron heat.""" supported_modes: tuple[runtime_params_lib.Mode, ...] = ( runtime_params_lib.Mode.ZERO, diff --git a/torax/sources/current_density_sources.py b/torax/sources/current_density_sources.py index d3a0189c..33f0411d 100644 --- a/torax/sources/current_density_sources.py +++ b/torax/sources/current_density_sources.py @@ -31,7 +31,7 @@ @dataclasses.dataclass(kw_only=True, frozen=True, eq=True) -class ECRHCurrentSource(source.SingleProfileSource): +class ECRHCurrentSource(source.Source): """ECRH current density source for the psi equation.""" affected_core_profiles: tuple[source.AffectedCoreProfile, ...] = ( source.AffectedCoreProfile.PSI, @@ -39,7 +39,7 @@ class ECRHCurrentSource(source.SingleProfileSource): @dataclasses.dataclass(kw_only=True, frozen=True, eq=True) -class ICRHCurrentSource(source.SingleProfileSource): +class ICRHCurrentSource(source.Source): """ICRH current density source for the psi equation.""" affected_core_profiles: tuple[source.AffectedCoreProfile, ...] = ( source.AffectedCoreProfile.PSI, @@ -47,7 +47,7 @@ class ICRHCurrentSource(source.SingleProfileSource): @dataclasses.dataclass(kw_only=True, frozen=True, eq=True) -class LHCurrentSource(source.SingleProfileSource): +class LHCurrentSource(source.Source): """LH current density source for the psi equation.""" affected_core_profiles: tuple[source.AffectedCoreProfile, ...] = ( source.AffectedCoreProfile.PSI, @@ -55,7 +55,7 @@ class LHCurrentSource(source.SingleProfileSource): @dataclasses.dataclass(kw_only=True, frozen=True, eq=True) -class NBICurrentSource(source.SingleProfileSource): +class NBICurrentSource(source.Source): """NBI current density source for the psi equation.""" affected_core_profiles: tuple[source.AffectedCoreProfile, ...] = ( source.AffectedCoreProfile.PSI, diff --git a/torax/sources/electron_density_sources.py b/torax/sources/electron_density_sources.py index 140a1a16..5f9a0a36 100644 --- a/torax/sources/electron_density_sources.py +++ b/torax/sources/electron_density_sources.py @@ -95,8 +95,13 @@ def _calc_puff_source( @dataclasses.dataclass(kw_only=True, frozen=True, eq=True) -class GasPuffSource(source.SingleProfileSource): +class GasPuffSource(source.Source): """Gas puff source for the ne equation.""" + # output_shape_getter is removed from __init__ as it is fixed to this value. + output_shape_getter: source.SourceOutputShapeFunction = dataclasses.field( + init=False, + default_factory=lambda: source.get_cell_profile_shape, + ) formula: source.SourceProfileFunction = _calc_puff_source affected_core_profiles: tuple[source.AffectedCoreProfile, ...] = ( source.AffectedCoreProfile.NE, @@ -183,8 +188,13 @@ def _calc_nbi_source( @dataclasses.dataclass(kw_only=True, frozen=True, eq=True) -class NBIParticleSource(source.SingleProfileSource): +class NBIParticleSource(source.Source): """Neutral-beam injection source for the ne equation.""" + # output_shape_getter is removed from __init__ as it is fixed to this value. + output_shape_getter: source.SourceOutputShapeFunction = dataclasses.field( + init=False, + default_factory=lambda: source.get_cell_profile_shape, + ) formula: source.SourceProfileFunction = _calc_nbi_source affected_core_profiles: tuple[source.AffectedCoreProfile, ...] = ( source.AffectedCoreProfile.NE, @@ -256,8 +266,13 @@ def _calc_pellet_source( @dataclasses.dataclass(kw_only=True, frozen=True, eq=True) -class PelletSource(source.SingleProfileSource): +class PelletSource(source.Source): """Pellet source for the ne equation.""" + # output_shape_getter is removed from __init__ as it is fixed to this value. + output_shape_getter: source.SourceOutputShapeFunction = dataclasses.field( + init=False, + default_factory=lambda: source.get_cell_profile_shape, + ) formula: source.SourceProfileFunction = _calc_pellet_source affected_core_profiles: tuple[source.AffectedCoreProfile, ...] = ( source.AffectedCoreProfile.NE, @@ -274,8 +289,13 @@ class PelletSource(source.SingleProfileSource): @dataclasses.dataclass(kw_only=True, frozen=True, eq=True) -class RecombinationDensitySink(source.SingleProfileSource): +class RecombinationDensitySink(source.Source): """Recombination sink for the electron density equation.""" + # output_shape_getter is removed from __init__ as it is fixed to this value. + output_shape_getter: source.SourceOutputShapeFunction = dataclasses.field( + init=False, + default_factory=lambda: source.get_cell_profile_shape, + ) affected_core_profiles: tuple[source.AffectedCoreProfile, ...] = ( source.AffectedCoreProfile.NE, ) diff --git a/torax/sources/ion_el_heat_sources.py b/torax/sources/ion_el_heat_sources.py index 2bd8238f..05aebb8d 100644 --- a/torax/sources/ion_el_heat_sources.py +++ b/torax/sources/ion_el_heat_sources.py @@ -31,7 +31,7 @@ @dataclasses.dataclass(kw_only=True, frozen=True, eq=True) -class ChargeExchangeHeatSink(source.SingleProfileSource): +class ChargeExchangeHeatSink(source.Source): """Charge exchange loss term for the ion temp equation.""" affected_core_profiles: tuple[source.AffectedCoreProfile, ...] = ( source.AffectedCoreProfile.TEMP_ION, @@ -39,7 +39,7 @@ class ChargeExchangeHeatSink(source.SingleProfileSource): @dataclasses.dataclass(kw_only=True, frozen=True, eq=True) -class CyclotronRadiationHeatSink(source.SingleProfileSource): +class CyclotronRadiationHeatSink(source.Source): """Cyclotron radiation loss term for the electron temp equation.""" affected_core_profiles: tuple[source.AffectedCoreProfile, ...] = ( source.AffectedCoreProfile.TEMP_EL, @@ -47,7 +47,7 @@ class CyclotronRadiationHeatSink(source.SingleProfileSource): @dataclasses.dataclass(kw_only=True, frozen=True, eq=True) -class ECRHHeatSource(source.SingleProfileSource): +class ECRHHeatSource(source.Source): """ECRH heat source for the electron temp equation.""" affected_core_profiles: tuple[source.AffectedCoreProfile, ...] = ( source.AffectedCoreProfile.TEMP_EL, @@ -55,7 +55,7 @@ class ECRHHeatSource(source.SingleProfileSource): @dataclasses.dataclass(kw_only=True, frozen=True, eq=True) -class ICRHHeatSource(source.SingleProfileSource): +class ICRHHeatSource(source.Source): """ICRH heat source for the ion temp equation.""" affected_core_profiles: tuple[source.AffectedCoreProfile, ...] = ( source.AffectedCoreProfile.TEMP_ION, @@ -63,7 +63,7 @@ class ICRHHeatSource(source.SingleProfileSource): @dataclasses.dataclass(kw_only=True, frozen=True, eq=True) -class LHHeatSource(source.SingleProfileSource): +class LHHeatSource(source.Source): """LH heat source for the electron temp equation.""" affected_core_profiles: tuple[source.AffectedCoreProfile, ...] = ( source.AffectedCoreProfile.TEMP_EL, @@ -71,7 +71,7 @@ class LHHeatSource(source.SingleProfileSource): @dataclasses.dataclass(kw_only=True, frozen=True, eq=True) -class LineRadiationHeatSink(source.SingleProfileSource): +class LineRadiationHeatSink(source.Source): """Line radiation loss sink for the electron temp equation.""" affected_core_profiles: tuple[source.AffectedCoreProfile, ...] = ( source.AffectedCoreProfile.TEMP_EL, @@ -79,7 +79,7 @@ class LineRadiationHeatSink(source.SingleProfileSource): @dataclasses.dataclass(kw_only=True, frozen=True, eq=True) -class NBIElectronHeatSource(source.SingleProfileSource): +class NBIElectronHeatSource(source.Source): """NBI heat source for the electron temp equation.""" affected_core_profiles: tuple[source.AffectedCoreProfile, ...] = ( source.AffectedCoreProfile.TEMP_EL, @@ -87,7 +87,7 @@ class NBIElectronHeatSource(source.SingleProfileSource): @dataclasses.dataclass(kw_only=True, frozen=True, eq=True) -class NBIIonHeatSource(source.SingleProfileSource): +class NBIIonHeatSource(source.Source): """NBI heat source for the ion temp equation.""" affected_core_profiles: tuple[source.AffectedCoreProfile, ...] = ( source.AffectedCoreProfile.TEMP_ION, @@ -95,7 +95,7 @@ class NBIIonHeatSource(source.SingleProfileSource): @dataclasses.dataclass(kw_only=True, frozen=True, eq=True) -class RecombinationHeatSink(source.SingleProfileSource): +class RecombinationHeatSink(source.Source): """Recombination loss sink for the electron temp equation.""" affected_core_profiles: tuple[source.AffectedCoreProfile, ...] = ( source.AffectedCoreProfile.TEMP_EL, diff --git a/torax/sources/source.py b/torax/sources/source.py index d7322df1..969ad249 100644 --- a/torax/sources/source.py +++ b/torax/sources/source.py @@ -136,7 +136,12 @@ class Source: runtime_params_lib.Mode.PRESCRIBED, ) - output_shape_getter: SourceOutputShapeFunction = get_cell_profile_shape + # output_shape_getter is removed from __init__ as it is fixed to this value. + # For different output shapes, override this attribute. + output_shape_getter: SourceOutputShapeFunction = dataclasses.field( + init=False, + default_factory=lambda: get_cell_profile_shape, + ) model_func: SourceProfileFunction | None = None @@ -272,153 +277,22 @@ def get_source_profile_for_affected_core_profile( """ # Get a valid index that defaults to 0 if not present. affected_core_profile_ints = self.affected_core_profiles_ints - idx = jnp.argmax( - jnp.asarray(affected_core_profile_ints) == affected_core_profile - ) - return jnp.where( - affected_core_profile in affected_core_profile_ints, - profile[idx, ...], - jnp.zeros_like(geo.rho), - ) - - -@dataclasses.dataclass(kw_only=True, frozen=True, eq=True) -class SingleProfileSource(Source): - """Source providing a single output profile on the cell grid. - - Most sources in TORAX are instances (or subclasses) of this class. - - You can define custom sources inline when constructing the full list of - sources to use in TORAX. - - .. code-block:: python - - # Define an electron-density source with a Gaussian profile. - my_custom_source_builder = source.SingleProfileSourceBuilder( - supported_modes=( - runtime_params_lib.Mode.ZERO, - runtime_params_lib.Mode.FORMULA_BASED, - ), - affected_core_profiles=[source.AffectedCoreProfile.NE], - formula=formulas.Gaussian(my_custom_source_name), - ) - # Define its runtime parameters (this could be done in the constructor as - # well). - my_custom_source_builder.runtime_params = runtime_params_lib.RuntimeParams( - mode=runtime_params_lib.Mode.FORMULA_BASED, - formula=formula_config.Gaussian( - total=1.0, - c1=2.0, - c2=3.0, - ), - ) - all_torax_sources_builder = source_models_lib.SourceModelsBuilder( - sources_builder={ - 'my_custom_source': my_custom_source_builder, - } - ) - - If you want to create a subclass of SingleProfileSource with frozen - parameters, you can provide default implementations/attributes. This is an - example of a model-based source with a frozen custom model that cannot be - changed by a runtime_params, along with custom runtime parameters specific to - this - source: - - .. code-block:: python - - @dataclasses.dataclass(kw_only=True) - class FooRuntimeParams(runtime_params_lib.RuntimeParams): - foo_param: runtime_params_lib.TimeInterpolatedInput - bar_param: float - - def (build_dynamic_params(self, t: chex.Numeric) - -> DynamicFooRuntimeParams): - return DynamicFooRuntimeParams( - **config_args.get_init_kwargs( - input_config=self, - output_type=DynamicFooRuntimeParams, - t=t, - ) + if len(affected_core_profile_ints) == 1: + return jnp.where( + affected_core_profile in self.affected_core_profiles_ints, + profile, + jnp.zeros_like(geo.rho), ) - - @chex.dataclass(frozen=True) - class DynamicFooRuntimeParams(runtime_params_lib.DynamicRuntimeParams): - foo_param: float - bar_param: float - - def _my_foo_model( - dynamic_runtime_params_slice, - dynamic_source_runtime_params, - geo, - core_profiles, - source_models, - ) -> jax.Array: - assert isinstance(dynamic_source_runtime_params, DynamicFooRuntimeParams) - # implement your foo model. - - @dataclasses.dataclass(kw_only=True) - class FooSource(SingleProfileSource): - - # Provide a default set of params. - runtime_params: FooRuntimeParams = dataclasses.field( - default_factory=lambda: FooRuntimeParams( - foo_param={0.0: 10.0, 1.0: 20.0, 2.0: 35.0}, - bar_param: 1.234, - ) + else: + idx = jnp.argmax( + jnp.asarray(affected_core_profile_ints) == affected_core_profile ) - - # By default, FooSource's can be model-based or set to 0. - supported_modes: tuple[runtime_params_lib.Mode, ...] = ( - runtime_params_lib.Mode.ZERO, - runtime_params_lib.Mode.MODEL_BASED, - ) - - # Don't include model_func in the __init__ arguments and freeze it. - model_func: SourceProfileFunction = dataclasses.field( - init=False, - default_factory=lambda: _my_foo_model, + chex.assert_rank(profile, 2) + return jnp.where( + affected_core_profile in affected_core_profile_ints, + profile[idx, ...], + jnp.zeros_like(geo.rho), ) - """ - - # Don't include output_shape_getter in the __init__ arguments. - # Freeze this parameter so that it always outputs a single cell profile. - output_shape_getter: SourceOutputShapeFunction = dataclasses.field( - init=False, - default_factory=lambda: get_cell_profile_shape, - ) - - def get_value( - self, - dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, - dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams, - geo: geometry.Geometry, - core_profiles: state.CoreProfiles | None = None, - ) -> jax.Array: - """Returns the profile for this source during one time step.""" - output_shape = self.output_shape_getter(geo) - profile = super().get_value( - dynamic_runtime_params_slice=dynamic_runtime_params_slice, - dynamic_source_runtime_params=dynamic_source_runtime_params, - geo=geo, - core_profiles=core_profiles, - ) - assert isinstance(profile, jax.Array) - chex.assert_rank(profile, 1) - chex.assert_shape(profile, output_shape) - return profile - - def get_source_profile_for_affected_core_profile( - self, - profile: chex.ArrayTree, - affected_core_profile: int, - geo: geometry.Geometry, - ) -> jax.Array: - return jnp.where( - affected_core_profile in self.affected_core_profiles_ints, - profile, - jnp.zeros_like(geo.rho), - ) class ProfileType(enum.Enum): @@ -438,10 +312,6 @@ def get_profile_shape(self, geo: geometry.Geometry) -> tuple[int, ...]: } return profile_type_to_len[self] - def get_zero_profile(self, geo: geometry.Geometry) -> jax.Array: - """Returns a source profile with all zeros.""" - return jnp.zeros(self.get_profile_shape(geo)) - # pytype bug: 'source_models.SourceModels' not treated as a forward ref # pytype: disable=name-error @@ -779,4 +649,3 @@ def build_source(self): SourceBuilder = make_source_builder(Source) -SingleProfileSourceBuilder = make_source_builder(SingleProfileSource) diff --git a/torax/sources/source_models.py b/torax/sources/source_models.py index b1e84ee9..6a40ad73 100644 --- a/torax/sources/source_models.py +++ b/torax/sources/source_models.py @@ -630,12 +630,16 @@ def ohmic_model_func( # OhmicHeatSource is a special case and defined here to avoid circular # dependencies, since it depends on the psi sources @dataclasses.dataclass(kw_only=True, frozen=True, eq=True) -class OhmicHeatSource(source_lib.SingleProfileSource): +class OhmicHeatSource(source_lib.Source): """Ohmic heat source for electron heat equation. Pohm = jtor * psidot /(2*pi*Rmaj), related to electric power formula P = IV. """ - + # output_shape_getter is removed from __init__ as it is fixed to this value. + output_shape_getter: source_lib.SourceOutputShapeFunction = dataclasses.field( + init=False, + default_factory=lambda: source_lib.get_cell_profile_shape, + ) # Users must pass in a pointer to the complete set of sources to this object. source_models: SourceModels diff --git a/torax/sources/tests/formulas.py b/torax/sources/tests/formulas.py index 36779fb5..aaec10c0 100644 --- a/torax/sources/tests/formulas.py +++ b/torax/sources/tests/formulas.py @@ -93,7 +93,7 @@ def test_custom_exponential_source_can_replace_puff_source(self): # Add the custom source to the source_models, but keep it turned off for the # first run. source_models_builder.source_builders[custom_source_name] = ( - source.SingleProfileSourceBuilder( + source.SourceBuilder( supported_modes=( runtime_params_lib.Mode.ZERO, runtime_params_lib.Mode.FORMULA_BASED, diff --git a/torax/sources/tests/source.py b/torax/sources/tests/source.py index 0771ccf0..b8529876 100644 --- a/torax/sources/tests/source.py +++ b/torax/sources/tests/source.py @@ -28,6 +28,13 @@ from torax.sources import source_models as source_models_lib +def get_zero_profile( + profile_type: source_lib.ProfileType, geo: geometry.Geometry, +) -> jax.Array: + """Returns a source profile with all zeros.""" + return jnp.zeros(profile_type.get_profile_shape(geo)) + + class SourceTest(parameterized.TestCase): """Tests for the base class Source.""" @@ -123,7 +130,6 @@ class MySource: def test_zero_profile_works_by_default(self): """The default source impl should support profiles with all zeros.""" source_builder = source_lib.SourceBuilder( - output_shape_getter=source_lib.get_cell_profile_shape, affected_core_profiles=(source_lib.AffectedCoreProfile.PSI,), ) source_models_builder = source_models_lib.SourceModelsBuilder( @@ -157,7 +163,7 @@ def test_zero_profile_works_by_default(self): ) np.testing.assert_allclose( profile, - source_lib.ProfileType.CELL.get_zero_profile(geo), + get_zero_profile(source_lib.ProfileType.CELL, geo), ) def test_unsupported_modes_raise_errors(self): @@ -167,7 +173,6 @@ def test_unsupported_modes_raise_errors(self): # Only support formula-based profiles. runtime_params_lib.Mode.FORMULA_BASED, ), - output_shape_getter=source_lib.get_cell_profile_shape, affected_core_profiles=(source_lib.AffectedCoreProfile.NE,), ) # But set the runtime params of the source to use ZERO as the mode. @@ -212,7 +217,6 @@ def test_defaults_output_zeros(self): runtime_params_lib.Mode.FORMULA_BASED, runtime_params_lib.Mode.PRESCRIBED, ), - output_shape_getter=source_lib.get_cell_profile_shape, affected_core_profiles=(source_lib.AffectedCoreProfile.NE,), ) source_models_builder = source_models_lib.SourceModelsBuilder( @@ -259,7 +263,7 @@ def test_defaults_output_zeros(self): ) np.testing.assert_allclose( profile, - source_lib.ProfileType.CELL.get_zero_profile(geo), + get_zero_profile(source_lib.ProfileType.CELL, geo), ) with self.subTest('formula'): dynamic_runtime_params_slice = runtime_params_slice.DynamicRuntimeParamsSliceProvider( @@ -284,7 +288,7 @@ def test_defaults_output_zeros(self): ) np.testing.assert_allclose( profile, - source_lib.ProfileType.CELL.get_zero_profile(geo), + get_zero_profile(source_lib.ProfileType.CELL, geo), ) with self.subTest('prescribed'): dynamic_runtime_params_slice = runtime_params_slice.DynamicRuntimeParamsSliceProvider( @@ -309,7 +313,7 @@ def test_defaults_output_zeros(self): ) np.testing.assert_allclose( profile, - source_lib.ProfileType.CELL.get_zero_profile(geo), + get_zero_profile(source_lib.ProfileType.CELL, geo), ) def test_overriding_default_formula(self): @@ -318,7 +322,6 @@ def test_overriding_default_formula(self): output_shape = source_lib.ProfileType.CELL.get_profile_shape(geo) expected_output = jnp.ones(output_shape) source_builder = source_lib.SourceBuilder( - output_shape_getter=lambda _0: output_shape, formula=lambda _0, _1, _2, _3, _4: expected_output, affected_core_profiles=( source_lib.AffectedCoreProfile.TEMP_ION, @@ -363,7 +366,6 @@ def test_overriding_model(self): expected_output = jnp.ones(output_shape) source_builder = source_lib.SourceBuilder( supported_modes=(runtime_params_lib.Mode.MODEL_BASED,), - output_shape_getter=lambda _0: output_shape, model_func=lambda _0, _1, _2, _3, _4: expected_output, affected_core_profiles=( source_lib.AffectedCoreProfile.TEMP_ION, @@ -410,7 +412,6 @@ def test_overriding_prescribed_values(self): # Create the source source_builder = source_lib.SourceBuilder( supported_modes=(runtime_params_lib.Mode.PRESCRIBED,), - output_shape_getter=lambda _0: output_shape, affected_core_profiles=( source_lib.AffectedCoreProfile.TEMP_ION, source_lib.AffectedCoreProfile.TEMP_EL, @@ -454,10 +455,14 @@ def test_overriding_prescribed_values(self): def test_retrieving_profile_for_affected_state(self): """Grabbing the correct profile works for all mesh state attributes.""" output_shape = (2, 4) # Some arbitrary shape. + + @dataclasses.dataclass(frozen=True) + class TestSource(source_lib.Source): + output_shape_getter = lambda _0: output_shape + profile = jnp.asarray([[1, 2, 3, 4], [5, 6, 7, 8]]) # from get_value() - source = source_lib.Source( + source = TestSource( supported_modes=(runtime_params_lib.Mode.MODEL_BASED,), - output_shape_getter=lambda _0: output_shape, model_func=lambda _0, _1, _2, _3, _4: profile, affected_core_profiles=( source_lib.AffectedCoreProfile.PSI, @@ -490,8 +495,8 @@ def test_custom_formula(self): """The user-specified formula should override the default formula.""" runtime_params = general_runtime_params.GeneralRuntimeParams() geo = geometry.build_circular_geometry(n_rho=5) - expected_output = jnp.ones(5) # 5 matches the geo. - source_builder = source_lib.SingleProfileSourceBuilder( + expected_output = jnp.ones((5)) # 5 matches the geo. + source_builder = source_lib.SourceBuilder( formula=lambda _0, _1, _2, _3, _4: expected_output, affected_core_profiles=(source_lib.AffectedCoreProfile.PSI,), ) @@ -525,52 +530,10 @@ def test_custom_formula(self): ) np.testing.assert_allclose(profile, expected_output) - def test_multiple_profiles_raises_error(self): - """A formula which outputs the wrong shape will raise an error.""" - source_builder = source_lib.SingleProfileSourceBuilder( - formula=lambda _0, _1, _2, _3, _4: jnp.ones((2, 5)), - affected_core_profiles=( - source_lib.AffectedCoreProfile.TEMP_ION, - source_lib.AffectedCoreProfile.NE, - ), - ) - source_builder.runtime_params.mode = runtime_params_lib.Mode.FORMULA_BASED - source_models_builder = source_models_lib.SourceModelsBuilder( - {'foo': source_builder}, - ) - source_models = source_models_builder() - source = source_models.sources['foo'] - runtime_params = general_runtime_params.GeneralRuntimeParams() - geo = geometry.build_circular_geometry(n_rho=5) - dynamic_runtime_params_slice = ( - runtime_params_slice.DynamicRuntimeParamsSliceProvider( - runtime_params, - sources=source_models_builder.runtime_params, - torax_mesh=geo.torax_mesh, - )( - t=runtime_params.numerics.t_initial, - ) - ) - core_profiles = core_profile_setters.initial_core_profiles( - dynamic_runtime_params_slice=dynamic_runtime_params_slice, - geo=geo, - # defaults are enough for this. - source_models=source_models, - ) - with self.assertRaises(AssertionError): - source.get_value( - dynamic_runtime_params_slice=dynamic_runtime_params_slice, - dynamic_source_runtime_params=dynamic_runtime_params_slice.sources[ - 'foo' - ], - geo=geo, - core_profiles=core_profiles, - ) - def test_retrieving_profile_for_affected_state(self): """Grabbing the correct profile works for all mesh state attributes.""" profile = jnp.asarray([1, 2, 3, 4]) # from get_value() - source = source_lib.SingleProfileSource( + source = source_lib.Source( supported_modes=(runtime_params_lib.Mode.MODEL_BASED,), model_func=lambda _0, _1, _2, _3, _4: profile, affected_core_profiles=(source_lib.AffectedCoreProfile.NE,), diff --git a/torax/sources/tests/source_models.py b/torax/sources/tests/source_models.py index da2c4171..3f4832f6 100644 --- a/torax/sources/tests/source_models.py +++ b/torax/sources/tests/source_models.py @@ -14,6 +14,8 @@ """Tests for SourceModels and functions computing the source profiles.""" +import dataclasses + from absl.testing import absltest from absl.testing import parameterized import jax @@ -30,6 +32,14 @@ from torax.sources import source_profiles as source_profiles_lib +@dataclasses.dataclass(frozen=True) +class FooSource(source_lib.Source): + """A test source.""" + output_shape_getter = lambda: source_lib.get_ion_el_output_shape + +FooSourceBuilder = source_lib.make_source_builder(FooSource,) + + class SourceModelsTest(parameterized.TestCase): """Tests for SourceModels.""" @@ -140,18 +150,12 @@ def foo_formula( jnp.ones(source_lib.ProfileType.CELL.get_profile_shape(geo)), ]) - foo_source_builder = source_lib.SourceBuilder( - # Test a fake source that somehow affects both electron temp and - # electron density. + foo_source_builder = FooSourceBuilder( affected_core_profiles=( source_lib.AffectedCoreProfile.TEMP_EL, source_lib.AffectedCoreProfile.NE, ), supported_modes=(runtime_params_lib.Mode.FORMULA_BASED,), - output_shape_getter=( - lambda geo: (2,) - + source_lib.ProfileType.CELL.get_profile_shape(geo) - ), formula=foo_formula, ) # Set the source mode to FORMULA. diff --git a/torax/sources/tests/test_lib.py b/torax/sources/tests/test_lib.py index a2e0344b..00a41348 100644 --- a/torax/sources/tests/test_lib.py +++ b/torax/sources/tests/test_lib.py @@ -93,7 +93,7 @@ def test_source_value(self): source_models = source_models_builder() source = source_models.sources['foo'] source_builder.runtime_params.mode = source.supported_modes[0] - self.assertIsInstance(source, source_lib.SingleProfileSource) + self.assertIsInstance(source, source_lib.Source) geo = geometry.build_circular_geometry() dynamic_runtime_params_slice = ( runtime_params_slice.DynamicRuntimeParamsSliceProvider( @@ -131,7 +131,7 @@ def test_invalid_source_types_raise_errors(self): ) source_models = source_models_builder() source = source_models.sources['foo'] - self.assertIsInstance(source, source_lib.SingleProfileSource) + self.assertIsInstance(source, source_lib.Source) dynamic_runtime_params_slice = ( runtime_params_slice.DynamicRuntimeParamsSliceProvider( runtime_params=runtime_params, diff --git a/torax/tests/sim_custom_sources.py b/torax/tests/sim_custom_sources.py index 3b7db8c7..0223bfa2 100644 --- a/torax/tests/sim_custom_sources.py +++ b/torax/tests/sim_custom_sources.py @@ -142,7 +142,7 @@ def custom_source_formula( # Add the custom source with the correct params, but keep it turned off to # start. source_models_builder.source_builders[custom_source_name] = ( - source.SingleProfileSourceBuilder( + source.SourceBuilder( supported_modes=( runtime_params_lib.Mode.ZERO, runtime_params_lib.Mode.FORMULA_BASED, diff --git a/torax/tests/sim_output_source_profiles.py b/torax/tests/sim_output_source_profiles.py index 4b230691..a466c8a9 100644 --- a/torax/tests/sim_output_source_profiles.py +++ b/torax/tests/sim_output_source_profiles.py @@ -109,7 +109,7 @@ def custom_source_formula( # Include 2 versions of this source, one implicit and one explicit. source_models_builder = source_models_lib.SourceModelsBuilder({ - 'implicit_ne_source': source.SingleProfileSourceBuilder( + 'implicit_ne_source': source.SourceBuilder( supported_modes=( runtime_params_lib.Mode.ZERO, runtime_params_lib.Mode.FORMULA_BASED, @@ -121,7 +121,7 @@ def custom_source_formula( foo={0.0: 1.0, 1.0: 2.0, 2.0: 3.0, 3.0: 4.0}, ), ), - 'explicit_ne_source': source.SingleProfileSourceBuilder( + 'explicit_ne_source': source.SourceBuilder( supported_modes=( runtime_params_lib.Mode.ZERO, runtime_params_lib.Mode.FORMULA_BASED,