Skip to content

Commit

Permalink
Create functionality to register sources and remove default source se…
Browse files Browse the repository at this point in the history
…tting code.

This change enables functionality for sources to be registered without touching TORAX internals.

PiperOrigin-RevId: 681949209
  • Loading branch information
Nush395 authored and Torax team committed Oct 15, 2024
1 parent b3b3692 commit a068757
Show file tree
Hide file tree
Showing 15 changed files with 234 additions and 229 deletions.
10 changes: 5 additions & 5 deletions torax/config/build_sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@
from torax import sim as sim_lib
from torax.config import config_args
from torax.config import runtime_params as runtime_params_lib
from torax.sources import default_sources
from torax.sources import formula_config
from torax.sources import formulas
from torax.sources import register_source
from torax.sources import runtime_params as source_runtime_params_lib
from torax.sources import source as source_lib
from torax.sources import source_models as source_models_lib
Expand Down Expand Up @@ -360,9 +360,8 @@ def _build_single_source_builder_from_config(
source_config: dict[str, Any],
) -> source_lib.SourceBuilderProtocol:
"""Builds a source builder from the input config."""
runtime_params = default_sources.get_default_runtime_params(
source_name,
)
registered_source = register_source.get_registered_source(source_name)
runtime_params = registered_source.default_runtime_params_class()
# Update the defaults with the config provided.
source_config = copy.copy(source_config)
if 'mode' in source_config:
Expand Down Expand Up @@ -395,7 +394,8 @@ def _build_single_source_builder_from_config(
kwargs = {'runtime_params': runtime_params}
if formula is not None:
kwargs['formula'] = formula
return default_sources.get_source_builder_type(source_name)(**kwargs)

return registered_source.source_builder_class(**kwargs)


def build_transport_model_builder_from_config(
Expand Down
4 changes: 2 additions & 2 deletions torax/config/tests/runtime_params_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@
from torax.config import profile_conditions as profile_conditions_lib
from torax.config import runtime_params as general_runtime_params
from torax.config import runtime_params_slice as runtime_params_slice_lib
from torax.sources import default_sources
from torax.sources import electron_density_sources
from torax.sources import external_current_source
from torax.sources import formula_config
from torax.sources import runtime_params as sources_params_lib
from torax.sources.tests import test_lib
from torax.stepper import runtime_params as stepper_params_lib
from torax.transport_model import runtime_params as transport_params_lib

Expand Down Expand Up @@ -470,7 +470,7 @@ def test_update_dynamic_slice_provider_updates_sources(
):
"""Tests that the dynamic slice provider can be updated."""
runtime_params = general_runtime_params.GeneralRuntimeParams()
source_models_builder = default_sources.get_default_sources_builder()
source_models_builder = test_lib.get_default_sources_builder()
source_models_builder.runtime_params['jext'].Iext = 1.0
geo = geometry.build_circular_geometry(n_rho=4)
provider = runtime_params_slice_lib.DynamicRuntimeParamsSliceProvider(
Expand Down
8 changes: 4 additions & 4 deletions torax/fvm/tests/fvm.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@
from torax.fvm import cell_variable
from torax.fvm import implicit_solve_block
from torax.fvm import residual_and_loss
from torax.sources import default_sources
from torax.sources import runtime_params as source_runtime_params
from torax.sources import source_models as source_models_lib
from torax.sources.tests import test_lib
from torax.stepper import runtime_params as stepper_runtime_params
from torax.tests.test_lib import torax_refs
from torax.transport_model import constant as constant_transport_model
Expand Down Expand Up @@ -397,7 +397,7 @@ def test_nonlinear_solve_block_loss_minimum(
)
)
transport_model = transport_model_builder()
source_models_builder = default_sources.get_default_sources_builder()
source_models_builder = test_lib.get_default_sources_builder()
source_models_builder.runtime_params['qei_source'].Qei_mult = 0.0
source_models_builder.runtime_params['generic_ion_el_heat_source'].Ptot = (
0.0
Expand Down Expand Up @@ -528,7 +528,7 @@ def test_implicit_solve_block_uses_updated_boundary_conditions(self):
)
)
transport_model = transport_model_builder()
source_models_builder = default_sources.get_default_sources_builder()
source_models_builder = test_lib.get_default_sources_builder()
source_models_builder.runtime_params['qei_source'].Qei_mult = 0.0
source_models_builder.runtime_params['generic_ion_el_heat_source'].Ptot = (
0.0
Expand Down Expand Up @@ -664,7 +664,7 @@ def test_theta_residual_uses_updated_boundary_conditions(self):
)
)
transport_model = transport_model_builder()
source_models_builder = default_sources.get_default_sources_builder()
source_models_builder = test_lib.get_default_sources_builder()
source_models_builder.runtime_params['qei_source'].Qei_mult = 0.0
source_models_builder.runtime_params['generic_ion_el_heat_source'].Ptot = (
0.0
Expand Down
183 changes: 0 additions & 183 deletions torax/sources/default_sources.py

This file was deleted.

Loading

0 comments on commit a068757

Please sign in to comment.