Skip to content

Commit

Permalink
remove references to PINT
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 614659095
  • Loading branch information
Torax team committed Mar 14, 2024
1 parent 9bd9430 commit d9ca762
Show file tree
Hide file tree
Showing 13 changed files with 158 additions and 557 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ on: [push, workflow_dispatch]
jobs:
pytest-job:
runs-on: ubuntu-latest
timeout-minutes: 60
timeout-minutes: 80

concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
Expand Down
34 changes: 17 additions & 17 deletions torax/fvm/tests/fvm.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,20 +26,20 @@
from torax import fvm
from torax.fvm import implicit_solve_block
from torax.fvm import residual_and_loss
from torax.tests.test_lib import pint_ref
from torax.tests.test_lib import torax_refs


class FVMTest(pint_ref.ReferenceValueTest):
class FVMTest(torax_refs.ReferenceValueTest):
"""Unit tests for the `torax.fvm` module."""

@parameterized.parameters([
dict(references_getter=pint_ref.circular_references),
dict(references_getter=pint_ref.chease_references_Ip_from_chease),
dict(references_getter=pint_ref.chease_references_Ip_from_config),
dict(references_getter=torax_refs.circular_references),
dict(references_getter=torax_refs.chease_references_Ip_from_chease),
dict(references_getter=torax_refs.chease_references_Ip_from_config),
])
def test_face_grad(
self,
references_getter: Callable[[], pint_ref.References],
references_getter: Callable[[], torax_refs.References],
):
"""Test that CellVariable.face_grad matches reference values."""
references = references_getter()
Expand All @@ -49,13 +49,13 @@ def test_face_grad(
np.testing.assert_allclose(face_grad_jax, references.psi_face_grad)

@parameterized.parameters([
dict(references_getter=pint_ref.circular_references),
dict(references_getter=pint_ref.chease_references_Ip_from_chease),
dict(references_getter=pint_ref.chease_references_Ip_from_config),
dict(references_getter=torax_refs.circular_references),
dict(references_getter=torax_refs.chease_references_Ip_from_chease),
dict(references_getter=torax_refs.chease_references_Ip_from_config),
])
def test_underconstrained(
self,
references_getter: Callable[[], pint_ref.References],
references_getter: Callable[[], torax_refs.References],
):
"""Test that CellVariable raises for underconstrained problems."""
references = references_getter()
Expand All @@ -79,13 +79,13 @@ def test_underconstrained(
)

@parameterized.parameters([
dict(references_getter=pint_ref.circular_references),
dict(references_getter=pint_ref.chease_references_Ip_from_chease),
dict(references_getter=pint_ref.chease_references_Ip_from_config),
dict(references_getter=torax_refs.circular_references),
dict(references_getter=torax_refs.chease_references_Ip_from_chease),
dict(references_getter=torax_refs.chease_references_Ip_from_config),
])
def test_overconstrained(
self,
references_getter: Callable[[], pint_ref.References],
references_getter: Callable[[], torax_refs.References],
):
"""Test that CellVariable raises for overconstrained problems."""
references = references_getter()
Expand All @@ -109,15 +109,15 @@ def test_overconstrained(
@parameterized.parameters([
dict(
seed=20221114,
references_getter=pint_ref.circular_references,
references_getter=torax_refs.circular_references,
),
dict(
seed=20221114,
references_getter=pint_ref.chease_references_Ip_from_chease,
references_getter=torax_refs.chease_references_Ip_from_chease,
),
dict(
seed=20221114,
references_getter=pint_ref.chease_references_Ip_from_config,
references_getter=torax_refs.chease_references_Ip_from_config,
),
])
def test_face_grad_constraints(self, seed, references_getter):
Expand Down
7 changes: 3 additions & 4 deletions torax/initial_states.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,7 @@ def initial_currents(
johmform_face = (1 - geo.r_face_norm**2) ** config.nu
Cohm = Iohm * 1e6 / _trapz(johmform_face * geo.spr_face, geo.r_face)
johm_face = Cohm * johmform_face # ohmic current profile on face grid
johm = geometry.face_to_cell(johm_face) # TODO see if can be removed
johm = geometry.face_to_cell(johm_face)

# calculate "External" current profile (e.g. ECCD)
# form of external current on face grid
Expand All @@ -423,9 +423,8 @@ def initial_currents(
johm_hires = Cohm_hires * johmform_hires

# calculate "External" current profile (e.g. ECCD) on cell grid.
# TODO(b/323504363): Remove ad-hoc circular equilibrium and hires
# logic. Try doing something more similar to RAPTOR's analytical circular
# equilibrium.
# TODO(b/323504363): Replace ad-hoc circular equilibrium
# with more accurate analytical equilibrium
jext_hires = jext_source.jext_hires(
source_type=dynamic_config_slice.sources[jext_source.name].source_type,
dynamic_config_slice=dynamic_config_slice,
Expand Down
29 changes: 12 additions & 17 deletions torax/sources/tests/fusion_heat_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from torax.sources import source_config
from torax.sources import source_profiles
from torax.sources.tests import test_lib
from torax.tests.test_lib import pint_ref
from torax.tests.test_lib import torax_refs


class FusionHeatSourceTest(test_lib.IonElSourceTestCase):
Expand All @@ -46,12 +46,12 @@ def setUpClass(cls):
)

@parameterized.parameters([
dict(references_getter=pint_ref.circular_references),
dict(references_getter=pint_ref.chease_references_Ip_from_chease),
dict(references_getter=pint_ref.chease_references_Ip_from_config),
dict(references_getter=torax_refs.circular_references),
dict(references_getter=torax_refs.chease_references_Ip_from_chease),
dict(references_getter=torax_refs.chease_references_Ip_from_config),
])
def test_calc_fusion(
self, references_getter: Callable[[], pint_ref.References]
self, references_getter: Callable[[], torax_refs.References]
):
"""Compare `calc_fusion` function to a reference implementation."""
references = references_getter()
Expand All @@ -72,11 +72,11 @@ def test_calc_fusion(
nref,
)

def calculate_fusion(config, geo, profiles):
"""Reference implementation from pyntegrated_model."""
# pyntegrated_model doesn't follow Google style
def calculate_fusion(config, geo, state):
"""Reference implementation from PINT. We still use TORAX state here."""
# PINT doesn't follow Google style
# pylint:disable=invalid-name
T = profiles.Ti.faceValue()
T = state.temp_ion.face_value()
consts = constants.CONSTANTS

# P [W/m^3] = Efus *1/4 * n^2 * <sigma*v>.
Expand Down Expand Up @@ -105,20 +105,15 @@ def calculate_fusion(config, geo, profiles):
) # units of m^3/s

Pfus = (
Efus * 0.25 * (profiles.ni.faceValue() * config.nref) ** 2 * sigmav
Efus * 0.25 * (state.ni.face_value() * config.nref) ** 2 * sigmav
) # [W/m^3]
# Modification from raw pyntegrated_model: we use geo.r_face here,
# rather than a call to geo.rface(), which in pyntegrated_model is FiPy
# FaceVariable.
Ptot = np.trapz(Pfus * geo.vpr_face, geo.r_face) / 1e6 # [MW]

return Ptot

profiles = pint_ref.state_to_profiles(state)
fusion_pint = calculate_fusion(config, geo, state)

fusion_pyntegrated = calculate_fusion(config, geo, profiles)

np.testing.assert_allclose(fusion_jax, fusion_pyntegrated)
np.testing.assert_allclose(fusion_jax, fusion_pint)


if __name__ == '__main__':
Expand Down
2 changes: 1 addition & 1 deletion torax/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class Currents:
j_bootstrap: jax.Array
j_bootstrap_face: jax.Array
# pylint: disable=invalid-name
# Using PINT / physics notation naming convention
# Using physics notation naming convention
I_bootstrap: jax.Array
sigma: jax.Array

Expand Down
Loading

0 comments on commit d9ca762

Please sign in to comment.