Skip to content

Commit

Permalink
Remove SingleProfileSource.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 681488167
  • Loading branch information
Nush395 authored and Torax team committed Oct 4, 2024
1 parent 4352bba commit 777a69f
Show file tree
Hide file tree
Showing 12 changed files with 100 additions and 240 deletions.
2 changes: 1 addition & 1 deletion torax/sources/bremsstrahlung_heat_sink.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def bremsstrahlung_model_func(


@dataclasses.dataclass(kw_only=True, frozen=True, eq=True)
class BremsstrahlungHeatSink(source.SingleProfileSource):
class BremsstrahlungHeatSink(source.Source):
"""Fusion heat source for both ion and electron heat."""
supported_modes: tuple[runtime_params_lib.Mode, ...] = (
runtime_params_lib.Mode.ZERO,
Expand Down
8 changes: 4 additions & 4 deletions torax/sources/current_density_sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,31 +31,31 @@


@dataclasses.dataclass(kw_only=True, frozen=True, eq=True)
class ECRHCurrentSource(source.SingleProfileSource):
class ECRHCurrentSource(source.Source):
"""ECRH current density source for the psi equation."""
affected_core_profiles: tuple[source.AffectedCoreProfile, ...] = (
source.AffectedCoreProfile.PSI,
)


@dataclasses.dataclass(kw_only=True, frozen=True, eq=True)
class ICRHCurrentSource(source.SingleProfileSource):
class ICRHCurrentSource(source.Source):
"""ICRH current density source for the psi equation."""
affected_core_profiles: tuple[source.AffectedCoreProfile, ...] = (
source.AffectedCoreProfile.PSI,
)


@dataclasses.dataclass(kw_only=True, frozen=True, eq=True)
class LHCurrentSource(source.SingleProfileSource):
class LHCurrentSource(source.Source):
"""LH current density source for the psi equation."""
affected_core_profiles: tuple[source.AffectedCoreProfile, ...] = (
source.AffectedCoreProfile.PSI,
)


@dataclasses.dataclass(kw_only=True, frozen=True, eq=True)
class NBICurrentSource(source.SingleProfileSource):
class NBICurrentSource(source.Source):
"""NBI current density source for the psi equation."""
affected_core_profiles: tuple[source.AffectedCoreProfile, ...] = (
source.AffectedCoreProfile.PSI,
Expand Down
28 changes: 24 additions & 4 deletions torax/sources/electron_density_sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,13 @@ def _calc_puff_source(


@dataclasses.dataclass(kw_only=True, frozen=True, eq=True)
class GasPuffSource(source.SingleProfileSource):
class GasPuffSource(source.Source):
"""Gas puff source for the ne equation."""
# output_shape_getter is removed from __init__ as it is fixed to this value.
output_shape_getter: source.SourceOutputShapeFunction = dataclasses.field(
init=False,
default_factory=lambda: source.get_cell_profile_shape,
)
formula: source.SourceProfileFunction = _calc_puff_source
affected_core_profiles: tuple[source.AffectedCoreProfile, ...] = (
source.AffectedCoreProfile.NE,
Expand Down Expand Up @@ -183,8 +188,13 @@ def _calc_nbi_source(


@dataclasses.dataclass(kw_only=True, frozen=True, eq=True)
class NBIParticleSource(source.SingleProfileSource):
class NBIParticleSource(source.Source):
"""Neutral-beam injection source for the ne equation."""
# output_shape_getter is removed from __init__ as it is fixed to this value.
output_shape_getter: source.SourceOutputShapeFunction = dataclasses.field(
init=False,
default_factory=lambda: source.get_cell_profile_shape,
)
formula: source.SourceProfileFunction = _calc_nbi_source
affected_core_profiles: tuple[source.AffectedCoreProfile, ...] = (
source.AffectedCoreProfile.NE,
Expand Down Expand Up @@ -256,8 +266,13 @@ def _calc_pellet_source(


@dataclasses.dataclass(kw_only=True, frozen=True, eq=True)
class PelletSource(source.SingleProfileSource):
class PelletSource(source.Source):
"""Pellet source for the ne equation."""
# output_shape_getter is removed from __init__ as it is fixed to this value.
output_shape_getter: source.SourceOutputShapeFunction = dataclasses.field(
init=False,
default_factory=lambda: source.get_cell_profile_shape,
)
formula: source.SourceProfileFunction = _calc_pellet_source
affected_core_profiles: tuple[source.AffectedCoreProfile, ...] = (
source.AffectedCoreProfile.NE,
Expand All @@ -274,8 +289,13 @@ class PelletSource(source.SingleProfileSource):


@dataclasses.dataclass(kw_only=True, frozen=True, eq=True)
class RecombinationDensitySink(source.SingleProfileSource):
class RecombinationDensitySink(source.Source):
"""Recombination sink for the electron density equation."""
# output_shape_getter is removed from __init__ as it is fixed to this value.
output_shape_getter: source.SourceOutputShapeFunction = dataclasses.field(
init=False,
default_factory=lambda: source.get_cell_profile_shape,
)
affected_core_profiles: tuple[source.AffectedCoreProfile, ...] = (
source.AffectedCoreProfile.NE,
)
Expand Down
18 changes: 9 additions & 9 deletions torax/sources/ion_el_heat_sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,71 +31,71 @@


@dataclasses.dataclass(kw_only=True, frozen=True, eq=True)
class ChargeExchangeHeatSink(source.SingleProfileSource):
class ChargeExchangeHeatSink(source.Source):
"""Charge exchange loss term for the ion temp equation."""
affected_core_profiles: tuple[source.AffectedCoreProfile, ...] = (
source.AffectedCoreProfile.TEMP_ION,
)


@dataclasses.dataclass(kw_only=True, frozen=True, eq=True)
class CyclotronRadiationHeatSink(source.SingleProfileSource):
class CyclotronRadiationHeatSink(source.Source):
"""Cyclotron radiation loss term for the electron temp equation."""
affected_core_profiles: tuple[source.AffectedCoreProfile, ...] = (
source.AffectedCoreProfile.TEMP_EL,
)


@dataclasses.dataclass(kw_only=True, frozen=True, eq=True)
class ECRHHeatSource(source.SingleProfileSource):
class ECRHHeatSource(source.Source):
"""ECRH heat source for the electron temp equation."""
affected_core_profiles: tuple[source.AffectedCoreProfile, ...] = (
source.AffectedCoreProfile.TEMP_EL,
)


@dataclasses.dataclass(kw_only=True, frozen=True, eq=True)
class ICRHHeatSource(source.SingleProfileSource):
class ICRHHeatSource(source.Source):
"""ICRH heat source for the ion temp equation."""
affected_core_profiles: tuple[source.AffectedCoreProfile, ...] = (
source.AffectedCoreProfile.TEMP_ION,
)


@dataclasses.dataclass(kw_only=True, frozen=True, eq=True)
class LHHeatSource(source.SingleProfileSource):
class LHHeatSource(source.Source):
"""LH heat source for the electron temp equation."""
affected_core_profiles: tuple[source.AffectedCoreProfile, ...] = (
source.AffectedCoreProfile.TEMP_EL,
)


@dataclasses.dataclass(kw_only=True, frozen=True, eq=True)
class LineRadiationHeatSink(source.SingleProfileSource):
class LineRadiationHeatSink(source.Source):
"""Line radiation loss sink for the electron temp equation."""
affected_core_profiles: tuple[source.AffectedCoreProfile, ...] = (
source.AffectedCoreProfile.TEMP_EL,
)


@dataclasses.dataclass(kw_only=True, frozen=True, eq=True)
class NBIElectronHeatSource(source.SingleProfileSource):
class NBIElectronHeatSource(source.Source):
"""NBI heat source for the electron temp equation."""
affected_core_profiles: tuple[source.AffectedCoreProfile, ...] = (
source.AffectedCoreProfile.TEMP_EL,
)


@dataclasses.dataclass(kw_only=True, frozen=True, eq=True)
class NBIIonHeatSource(source.SingleProfileSource):
class NBIIonHeatSource(source.Source):
"""NBI heat source for the ion temp equation."""
affected_core_profiles: tuple[source.AffectedCoreProfile, ...] = (
source.AffectedCoreProfile.TEMP_ION,
)


@dataclasses.dataclass(kw_only=True, frozen=True, eq=True)
class RecombinationHeatSink(source.SingleProfileSource):
class RecombinationHeatSink(source.Source):
"""Recombination loss sink for the electron temp equation."""
affected_core_profiles: tuple[source.AffectedCoreProfile, ...] = (
source.AffectedCoreProfile.TEMP_EL,
Expand Down
169 changes: 19 additions & 150 deletions torax/sources/source.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,12 @@ class Source:
runtime_params_lib.Mode.PRESCRIBED,
)

output_shape_getter: SourceOutputShapeFunction = get_cell_profile_shape
# output_shape_getter is removed from __init__ as it is fixed to this value.
# For different output shapes, override this attribute.
output_shape_getter: SourceOutputShapeFunction = dataclasses.field(
init=False,
default_factory=lambda: get_cell_profile_shape,
)

model_func: SourceProfileFunction | None = None

Expand Down Expand Up @@ -272,153 +277,22 @@ def get_source_profile_for_affected_core_profile(
"""
# Get a valid index that defaults to 0 if not present.
affected_core_profile_ints = self.affected_core_profiles_ints
idx = jnp.argmax(
jnp.asarray(affected_core_profile_ints) == affected_core_profile
)
return jnp.where(
affected_core_profile in affected_core_profile_ints,
profile[idx, ...],
jnp.zeros_like(geo.rho),
)


@dataclasses.dataclass(kw_only=True, frozen=True, eq=True)
class SingleProfileSource(Source):
"""Source providing a single output profile on the cell grid.
Most sources in TORAX are instances (or subclasses) of this class.
You can define custom sources inline when constructing the full list of
sources to use in TORAX.
.. code-block:: python
# Define an electron-density source with a Gaussian profile.
my_custom_source_builder = source.SingleProfileSourceBuilder(
supported_modes=(
runtime_params_lib.Mode.ZERO,
runtime_params_lib.Mode.FORMULA_BASED,
),
affected_core_profiles=[source.AffectedCoreProfile.NE],
formula=formulas.Gaussian(my_custom_source_name),
)
# Define its runtime parameters (this could be done in the constructor as
# well).
my_custom_source_builder.runtime_params = runtime_params_lib.RuntimeParams(
mode=runtime_params_lib.Mode.FORMULA_BASED,
formula=formula_config.Gaussian(
total=1.0,
c1=2.0,
c2=3.0,
),
)
all_torax_sources_builder = source_models_lib.SourceModelsBuilder(
sources_builder={
'my_custom_source': my_custom_source_builder,
}
)
If you want to create a subclass of SingleProfileSource with frozen
parameters, you can provide default implementations/attributes. This is an
example of a model-based source with a frozen custom model that cannot be
changed by a runtime_params, along with custom runtime parameters specific to
this
source:
.. code-block:: python
@dataclasses.dataclass(kw_only=True)
class FooRuntimeParams(runtime_params_lib.RuntimeParams):
foo_param: runtime_params_lib.TimeInterpolatedInput
bar_param: float
def (build_dynamic_params(self, t: chex.Numeric)
-> DynamicFooRuntimeParams):
return DynamicFooRuntimeParams(
**config_args.get_init_kwargs(
input_config=self,
output_type=DynamicFooRuntimeParams,
t=t,
)
if len(affected_core_profile_ints) == 1:
return jnp.where(
affected_core_profile in self.affected_core_profiles_ints,
profile,
jnp.zeros_like(geo.rho),
)
@chex.dataclass(frozen=True)
class DynamicFooRuntimeParams(runtime_params_lib.DynamicRuntimeParams):
foo_param: float
bar_param: float
def _my_foo_model(
dynamic_runtime_params_slice,
dynamic_source_runtime_params,
geo,
core_profiles,
source_models,
) -> jax.Array:
assert isinstance(dynamic_source_runtime_params, DynamicFooRuntimeParams)
# implement your foo model.
@dataclasses.dataclass(kw_only=True)
class FooSource(SingleProfileSource):
# Provide a default set of params.
runtime_params: FooRuntimeParams = dataclasses.field(
default_factory=lambda: FooRuntimeParams(
foo_param={0.0: 10.0, 1.0: 20.0, 2.0: 35.0},
bar_param: 1.234,
)
else:
idx = jnp.argmax(
jnp.asarray(affected_core_profile_ints) == affected_core_profile
)
# By default, FooSource's can be model-based or set to 0.
supported_modes: tuple[runtime_params_lib.Mode, ...] = (
runtime_params_lib.Mode.ZERO,
runtime_params_lib.Mode.MODEL_BASED,
)
# Don't include model_func in the __init__ arguments and freeze it.
model_func: SourceProfileFunction = dataclasses.field(
init=False,
default_factory=lambda: _my_foo_model,
chex.assert_rank(profile, 2)
return jnp.where(
affected_core_profile in affected_core_profile_ints,
profile[idx, ...],
jnp.zeros_like(geo.rho),
)
"""

# Don't include output_shape_getter in the __init__ arguments.
# Freeze this parameter so that it always outputs a single cell profile.
output_shape_getter: SourceOutputShapeFunction = dataclasses.field(
init=False,
default_factory=lambda: get_cell_profile_shape,
)

def get_value(
self,
dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams,
geo: geometry.Geometry,
core_profiles: state.CoreProfiles | None = None,
) -> jax.Array:
"""Returns the profile for this source during one time step."""
output_shape = self.output_shape_getter(geo)
profile = super().get_value(
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
dynamic_source_runtime_params=dynamic_source_runtime_params,
geo=geo,
core_profiles=core_profiles,
)
assert isinstance(profile, jax.Array)
chex.assert_rank(profile, 1)
chex.assert_shape(profile, output_shape)
return profile

def get_source_profile_for_affected_core_profile(
self,
profile: chex.ArrayTree,
affected_core_profile: int,
geo: geometry.Geometry,
) -> jax.Array:
return jnp.where(
affected_core_profile in self.affected_core_profiles_ints,
profile,
jnp.zeros_like(geo.rho),
)


class ProfileType(enum.Enum):
Expand All @@ -438,10 +312,6 @@ def get_profile_shape(self, geo: geometry.Geometry) -> tuple[int, ...]:
}
return profile_type_to_len[self]

def get_zero_profile(self, geo: geometry.Geometry) -> jax.Array:
"""Returns a source profile with all zeros."""
return jnp.zeros(self.get_profile_shape(geo))


# pytype bug: 'source_models.SourceModels' not treated as a forward ref
# pytype: disable=name-error
Expand Down Expand Up @@ -779,4 +649,3 @@ def build_source(self):


SourceBuilder = make_source_builder(Source)
SingleProfileSourceBuilder = make_source_builder(SingleProfileSource)
Loading

0 comments on commit 777a69f

Please sign in to comment.