diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index b5c14f20..20aa423b 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -6,7 +6,7 @@ on: [push, workflow_dispatch] jobs: pytest-job: runs-on: ubuntu-latest - timeout-minutes: 60 + timeout-minutes: 80 concurrency: group: ${{ github.workflow }}-${{ github.ref }} diff --git a/torax/fvm/tests/fvm.py b/torax/fvm/tests/fvm.py index b313dca4..bab24c7b 100644 --- a/torax/fvm/tests/fvm.py +++ b/torax/fvm/tests/fvm.py @@ -26,20 +26,20 @@ from torax import fvm from torax.fvm import implicit_solve_block from torax.fvm import residual_and_loss -from torax.tests.test_lib import pint_ref +from torax.tests.test_lib import torax_refs -class FVMTest(pint_ref.ReferenceValueTest): +class FVMTest(torax_refs.ReferenceValueTest): """Unit tests for the `torax.fvm` module.""" @parameterized.parameters([ - dict(references_getter=pint_ref.circular_references), - dict(references_getter=pint_ref.chease_references_Ip_from_chease), - dict(references_getter=pint_ref.chease_references_Ip_from_config), + dict(references_getter=torax_refs.circular_references), + dict(references_getter=torax_refs.chease_references_Ip_from_chease), + dict(references_getter=torax_refs.chease_references_Ip_from_config), ]) def test_face_grad( self, - references_getter: Callable[[], pint_ref.References], + references_getter: Callable[[], torax_refs.References], ): """Test that CellVariable.face_grad matches reference values.""" references = references_getter() @@ -49,13 +49,13 @@ def test_face_grad( np.testing.assert_allclose(face_grad_jax, references.psi_face_grad) @parameterized.parameters([ - dict(references_getter=pint_ref.circular_references), - dict(references_getter=pint_ref.chease_references_Ip_from_chease), - dict(references_getter=pint_ref.chease_references_Ip_from_config), + dict(references_getter=torax_refs.circular_references), + dict(references_getter=torax_refs.chease_references_Ip_from_chease), + dict(references_getter=torax_refs.chease_references_Ip_from_config), ]) def test_underconstrained( self, - references_getter: Callable[[], pint_ref.References], + references_getter: Callable[[], torax_refs.References], ): """Test that CellVariable raises for underconstrained problems.""" references = references_getter() @@ -79,13 +79,13 @@ def test_underconstrained( ) @parameterized.parameters([ - dict(references_getter=pint_ref.circular_references), - dict(references_getter=pint_ref.chease_references_Ip_from_chease), - dict(references_getter=pint_ref.chease_references_Ip_from_config), + dict(references_getter=torax_refs.circular_references), + dict(references_getter=torax_refs.chease_references_Ip_from_chease), + dict(references_getter=torax_refs.chease_references_Ip_from_config), ]) def test_overconstrained( self, - references_getter: Callable[[], pint_ref.References], + references_getter: Callable[[], torax_refs.References], ): """Test that CellVariable raises for overconstrained problems.""" references = references_getter() @@ -109,15 +109,15 @@ def test_overconstrained( @parameterized.parameters([ dict( seed=20221114, - references_getter=pint_ref.circular_references, + references_getter=torax_refs.circular_references, ), dict( seed=20221114, - references_getter=pint_ref.chease_references_Ip_from_chease, + references_getter=torax_refs.chease_references_Ip_from_chease, ), dict( seed=20221114, - references_getter=pint_ref.chease_references_Ip_from_config, + references_getter=torax_refs.chease_references_Ip_from_config, ), ]) def test_face_grad_constraints(self, seed, references_getter): diff --git a/torax/initial_states.py b/torax/initial_states.py index 9afab22e..16b3e0a3 100644 --- a/torax/initial_states.py +++ b/torax/initial_states.py @@ -396,7 +396,7 @@ def initial_currents( johmform_face = (1 - geo.r_face_norm**2) ** config.nu Cohm = Iohm * 1e6 / _trapz(johmform_face * geo.spr_face, geo.r_face) johm_face = Cohm * johmform_face # ohmic current profile on face grid - johm = geometry.face_to_cell(johm_face) # TODO see if can be removed + johm = geometry.face_to_cell(johm_face) # calculate "External" current profile (e.g. ECCD) # form of external current on face grid @@ -423,9 +423,8 @@ def initial_currents( johm_hires = Cohm_hires * johmform_hires # calculate "External" current profile (e.g. ECCD) on cell grid. - # TODO(b/323504363): Remove ad-hoc circular equilibrium and hires - # logic. Try doing something more similar to RAPTOR's analytical circular - # equilibrium. + # TODO(b/323504363): Replace ad-hoc circular equilibrium + # with more accurate analytical equilibrium jext_hires = jext_source.jext_hires( source_type=dynamic_config_slice.sources[jext_source.name].source_type, dynamic_config_slice=dynamic_config_slice, diff --git a/torax/sources/tests/fusion_heat_source.py b/torax/sources/tests/fusion_heat_source.py index 93e07ef9..2e73bd3d 100644 --- a/torax/sources/tests/fusion_heat_source.py +++ b/torax/sources/tests/fusion_heat_source.py @@ -26,7 +26,7 @@ from torax.sources import source_config from torax.sources import source_profiles from torax.sources.tests import test_lib -from torax.tests.test_lib import pint_ref +from torax.tests.test_lib import torax_refs class FusionHeatSourceTest(test_lib.IonElSourceTestCase): @@ -46,12 +46,12 @@ def setUpClass(cls): ) @parameterized.parameters([ - dict(references_getter=pint_ref.circular_references), - dict(references_getter=pint_ref.chease_references_Ip_from_chease), - dict(references_getter=pint_ref.chease_references_Ip_from_config), + dict(references_getter=torax_refs.circular_references), + dict(references_getter=torax_refs.chease_references_Ip_from_chease), + dict(references_getter=torax_refs.chease_references_Ip_from_config), ]) def test_calc_fusion( - self, references_getter: Callable[[], pint_ref.References] + self, references_getter: Callable[[], torax_refs.References] ): """Compare `calc_fusion` function to a reference implementation.""" references = references_getter() @@ -72,11 +72,11 @@ def test_calc_fusion( nref, ) - def calculate_fusion(config, geo, profiles): - """Reference implementation from pyntegrated_model.""" - # pyntegrated_model doesn't follow Google style + def calculate_fusion(config, geo, state): + """Reference implementation from PINT. We still use TORAX state here.""" + # PINT doesn't follow Google style # pylint:disable=invalid-name - T = profiles.Ti.faceValue() + T = state.temp_ion.face_value() consts = constants.CONSTANTS # P [W/m^3] = Efus *1/4 * n^2 * . @@ -105,20 +105,15 @@ def calculate_fusion(config, geo, profiles): ) # units of m^3/s Pfus = ( - Efus * 0.25 * (profiles.ni.faceValue() * config.nref) ** 2 * sigmav + Efus * 0.25 * (state.ni.face_value() * config.nref) ** 2 * sigmav ) # [W/m^3] - # Modification from raw pyntegrated_model: we use geo.r_face here, - # rather than a call to geo.rface(), which in pyntegrated_model is FiPy - # FaceVariable. Ptot = np.trapz(Pfus * geo.vpr_face, geo.r_face) / 1e6 # [MW] return Ptot - profiles = pint_ref.state_to_profiles(state) + fusion_pint = calculate_fusion(config, geo, state) - fusion_pyntegrated = calculate_fusion(config, geo, profiles) - - np.testing.assert_allclose(fusion_jax, fusion_pyntegrated) + np.testing.assert_allclose(fusion_jax, fusion_pint) if __name__ == '__main__': diff --git a/torax/state.py b/torax/state.py index 618c075e..2af8ca4d 100644 --- a/torax/state.py +++ b/torax/state.py @@ -52,7 +52,7 @@ class Currents: j_bootstrap: jax.Array j_bootstrap_face: jax.Array # pylint: disable=invalid-name - # Using PINT / physics notation naming convention + # Using physics notation naming convention I_bootstrap: jax.Array sigma: jax.Array diff --git a/torax/tests/fipy.py b/torax/tests/fipy.py deleted file mode 100644 index 081754b4..00000000 --- a/torax/tests/fipy.py +++ /dev/null @@ -1,353 +0,0 @@ -# Copyright 2024 DeepMind Technologies Limited -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Unit tests for torax.fvm.""" -from typing import Optional, Sequence - -from absl.testing import absltest -from absl.testing import parameterized -import fipy -from jax import numpy as jnp -import numpy as np -from torax import config as config_lib -from torax import config_slice -from torax import fvm -from torax.fvm import implicit_solve_block -from torax.tests.test_lib import pint_ref - - -class FiPyTest(pint_ref.ReferenceValueTest): - """Tests that torax.fvm matches FiPy.""" - - @parameterized.parameters([ - dict(seed=202303151, left_grad=False, right_grad=False, dim=3), - dict(seed=202303152, left_grad=False, right_grad=True, dim=4), - dict(seed=202303153, left_grad=True, right_grad=False, dim=5), - dict(seed=202303154, left_grad=True, right_grad=True, dim=6), - ]) - def test_grad_and_face_grad( - self, seed: int, left_grad: bool, right_grad: bool, dim: int - ): - """Test that CellVariable.face_grad matches a FiPy equivalent. - - Args: - seed: Numpy RNG seed - left_grad: if True, use a gradient constraint on leftmost face, else use a - value constraint. - right_grad: if True, use a gradient constraint on rightmost face, else use - a value constraint. - dim: Size ofthe CellVariable - """ - - # Define the problem - rng = np.random.RandomState(seed) - value = rng.randn(dim) - eps = 1e-8 - dr = np.abs(rng.randn()) + eps - left_face_constraint = rng.randn() if not left_grad else None - left_face_grad_constraint = rng.randn() if left_grad else None - right_face_constraint = rng.randn() if not right_grad else None - right_face_grad_constraint = rng.randn() if right_grad else None - - # Torax solution - convert = lambda x: None if x is None else jnp.array(x) - cell_var_torax = fvm.CellVariable( - value=jnp.array(value), - dr=jnp.array(dr), - left_face_constraint=convert(left_face_constraint), - left_face_grad_constraint=convert(left_face_grad_constraint), - right_face_constraint=convert(right_face_constraint), - right_face_grad_constraint=convert(right_face_grad_constraint), - ) - grad_torax = cell_var_torax.grad() - face_grad_torax = cell_var_torax.face_grad() - - # FiPy solution - mesh = fipy.Grid1D( - nx=dim, - dx=dr, - ) - cell_var_fipy = fipy.CellVariable( - mesh=mesh, - value=value, - ) - if left_grad: - cell_var_fipy.faceGrad.constrain( - left_face_grad_constraint, where=mesh.facesLeft - ) - else: - cell_var_fipy.constrain(left_face_constraint, where=mesh.facesLeft) - if right_grad: - cell_var_fipy.faceGrad.constrain( - right_face_grad_constraint, where=mesh.facesRight - ) - else: - cell_var_fipy.constrain(right_face_constraint, where=mesh.facesRight) - grad_fipy = np.squeeze(cell_var_fipy.grad()) - face_grad_fipy = np.squeeze(cell_var_fipy.faceGrad()) - - # Check that the two solutions match - np.testing.assert_allclose(grad_torax, grad_fipy) - np.testing.assert_allclose(face_grad_torax, face_grad_fipy) - - @parameterized.parameters([ - dict(seed=202303155, left_grad=False, right_grad=False, dim=6), - dict(seed=202303156, left_grad=False, right_grad=True, dim=5), - dict(seed=202303157, left_grad=True, right_grad=False, dim=4), - dict(seed=202303158, left_grad=True, right_grad=True, dim=3), - ]) - def test_transient_diffusion( - self, - seed: int, - left_grad: bool, - right_grad: bool, - dim: int, - ): - """Test that implicit method with a transient term and diffusion term matches a FiPy equivalent. - - Args: - seed: Numpy RNG seed - left_grad: if True, use a gradient constraint on leftmost face, else use a - value constraint. - right_grad: if True, use a gradient constraint on rightmost face, else use - a value constraint. - dim: Size ofthe CellVariable - """ - - # Define the problem - rng = np.random.RandomState(seed) - init_x = rng.randn(dim) - eps = 1e-8 - dr = np.abs(rng.randn()) + eps - dt = np.abs(rng.randn()) + eps - tc_cell = np.abs(rng.randn(dim)) + eps - d_face = np.abs(rng.randn(dim + 1)) + eps - left_face_constraint = rng.randn() if not left_grad else None - left_face_grad_constraint = rng.randn() if left_grad else None - right_face_constraint = rng.randn() if not right_grad else None - right_face_grad_constraint = rng.randn() if right_grad else None - - # Torax solution - convert = lambda x: None if x is None else jnp.array(x) - init_x_torax = fvm.CellVariable( - value=jnp.array(init_x), - dr=jnp.array(dr), - left_face_constraint=convert(left_face_constraint), - left_face_grad_constraint=convert(left_face_grad_constraint), - right_face_constraint=convert(right_face_constraint), - right_face_grad_constraint=convert(right_face_grad_constraint), - ) - coeffs = fvm.Block1DCoeffs( - transient_out_cell=(tc_cell,), - transient_in_cell=(jnp.ones_like(tc_cell),), - d_face=(jnp.array(d_face),), - ) - config = config_lib.Config(nr=dim) - dynamic_config_slice = config_slice.build_dynamic_config_slice(config) - - (final_x_torax,), _ = implicit_solve_block.implicit_solve_block( - x_old=(init_x_torax,), - # Use the original x as the initial "guess" for x_new. Used when - # computing the coefficients for time t + dt. - x_new_vec_guess=init_x_torax.value, - x_new_update_fns=tuple([lambda cv: cv]), # no-op - dt=dt, - coeffs_old=coeffs, - # Assume no time-dependent params. - coeffs_callback=lambda x, dcs, allow_pereverzev=False: coeffs, - dynamic_config_slice_t_plus_dt=dynamic_config_slice, - theta_imp=1.0, - ) - # FiPy solution - mesh = fipy.Grid1D( - nx=dim, - dx=dr, - ) - assert init_x.ndim == 1 - x_fipy = fipy.CellVariable( - mesh=mesh, - value=init_x, - ) - if left_grad: - x_fipy.faceGrad.constrain(left_face_grad_constraint, where=mesh.facesLeft) - else: - x_fipy.constrain(left_face_constraint, where=mesh.facesLeft) - if right_grad: - x_fipy.faceGrad.constrain( - right_face_grad_constraint, where=mesh.facesRight - ) - else: - x_fipy.constrain(right_face_constraint, where=mesh.facesRight) - - transient_coeff = fipy.CellVariable(mesh=mesh, value=tc_cell) - transient = fipy.TransientTerm(coeff=transient_coeff, var=x_fipy) - diffusion_coeff = fipy.FaceVariable(mesh=mesh, value=d_face) - diffusion = fipy.DiffusionTerm(coeff=diffusion_coeff, var=x_fipy) - eq = transient == diffusion - eq.solve(dt=dt) - - # Check that the two solutions match - np.testing.assert_allclose(final_x_torax.value, x_fipy.value) - - @parameterized.parameters([ - dict(seed=202303161, left_grad=False, right_grad=False, dim=4), - dict(seed=202303162, left_grad=False, right_grad=True, dim=3), - dict(seed=202303163, left_grad=True, right_grad=False, dim=6), - dict(seed=202303164, left_grad=True, right_grad=True, dim=5), - dict( - seed=202303165, - left_grad=True, - right_grad=True, - dim=2, - ), - # Use d_face_mask to make sure we handle the zero diffusion corner - # case the same as FiPy. Zero diffusion requires some mild numerical - # hacks to avoid divide by zero. - dict( - seed=202304071, - left_grad=False, - right_grad=False, - dim=4, - d_face_mask=[1.0, 0.0, 0.0, 0.0, 1.0], - ), - dict( - seed=202304072, - left_grad=False, - right_grad=True, - dim=3, - d_face_mask=[0.0, 1.0, 1.0, 0.0], - ), - dict( - seed=202304073, - left_grad=True, - right_grad=False, - dim=6, - d_face_mask=[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - ), - dict( - seed=202304074, - left_grad=True, - right_grad=True, - dim=5, - d_face_mask=[1.0, 0.0, 1.0, 0.0, 1.0, 0.0], - ), - ]) - def test_transient_diffusion_convection( - self, - seed: int, - left_grad: bool, - right_grad: bool, - dim: int, - d_face_mask: Optional[Sequence[float]] = None, - ): - """Test that implicit method with transient, diffusion and convection terms matches a FiPy equivalent. - - Args: - seed: Numpy RNG seed - left_grad: if True, use a gradient constraint on leftmost face, else use a - value constraint. - right_grad: if True, use a gradient constraint on rightmost face, else use - a value constraint. - dim: Size ofthe CellVariable - d_face_mask: Mask applied to `d_face` to test with zero diffusion - """ - - # Define the problem - rng = np.random.RandomState(seed) - init_x = rng.randn(dim) - eps = 1e-8 - dr = np.abs(rng.randn()) + eps - dt = np.abs(rng.randn()) + eps - tc_cell = np.abs(rng.randn(dim)) + eps - d_face = np.abs(rng.randn(dim + 1)) - if d_face_mask is not None: - d_face = d_face * np.array(d_face_mask) - v_face = rng.randn(dim + 1) - left_face_constraint = rng.randn() if not left_grad else None - left_face_grad_constraint = rng.randn() if left_grad else None - right_face_constraint = rng.randn() if not right_grad else None - right_face_grad_constraint = rng.randn() if right_grad else None - - # Torax solution - convert = lambda x: None if x is None else jnp.array(x) - init_x_torax = fvm.CellVariable( - value=jnp.array(init_x), - dr=jnp.array(dr), - left_face_constraint=convert(left_face_constraint), - left_face_grad_constraint=convert(left_face_grad_constraint), - right_face_constraint=convert(right_face_constraint), - right_face_grad_constraint=convert(right_face_grad_constraint), - ) - coeffs = fvm.Block1DCoeffs( - transient_out_cell=(tc_cell,), - transient_in_cell=(jnp.ones_like(tc_cell),), - d_face=(jnp.array(d_face),), - v_face=(jnp.array(v_face),), - ) - config = config_lib.Config(nr=dim) - dynamic_config_slice = config_slice.build_dynamic_config_slice(config) - (final_x_torax,), _ = implicit_solve_block.implicit_solve_block( - x_old=(init_x_torax,), - # Use the original x as the initial "guess" for x_new. Used when - # computing the coefficients for time t + dt. - x_new_vec_guess=init_x_torax.value, - x_new_update_fns=tuple([lambda cv: cv]), # no-op - dt=dt, - coeffs_old=coeffs, - # Assume no time-dependent params. - coeffs_callback=lambda x, dcs, allow_pereverzev=False: coeffs, - dynamic_config_slice_t_plus_dt=dynamic_config_slice, - theta_imp=1.0, - # Use FiPy's approach to convection boundary conditions - convection_dirichlet_mode="semi-implicit", - convection_neumann_mode="semi-implicit", - ) - # FiPy solution - mesh = fipy.Grid1D( - nx=dim, - dx=dr, - ) - assert init_x.ndim == 1 - x_fipy = fipy.CellVariable( - mesh=mesh, - value=init_x, - ) - if left_grad: - x_fipy.faceGrad.constrain(left_face_grad_constraint, where=mesh.facesLeft) - else: - x_fipy.constrain(left_face_constraint, where=mesh.facesLeft) - if right_grad: - x_fipy.faceGrad.constrain( - right_face_grad_constraint, where=mesh.facesRight - ) - else: - x_fipy.constrain(right_face_constraint, where=mesh.facesRight) - - transient_coeff = fipy.CellVariable(mesh=mesh, value=tc_cell) - transient = fipy.TransientTerm(coeff=transient_coeff, var=x_fipy) - diffusion_coeff = fipy.FaceVariable(mesh=mesh, value=d_face) - diffusion = fipy.DiffusionTerm(coeff=diffusion_coeff, var=x_fipy) - convection_coeff = fipy.FaceVariable( - mesh=mesh, value=np.expand_dims(v_face, 0) - ) - convection = fipy.ConvectionTerm(coeff=convection_coeff, var=x_fipy) - eq = transient + convection == diffusion - eq.solve(dt=dt) - - # Check that the two solutions match - np.testing.assert_allclose(final_x_torax.value, x_fipy.value) - - -if __name__ == "__main__": - absltest.main() diff --git a/torax/tests/geometry.py b/torax/tests/geometry.py index 6d7a5a7b..d46d7442 100644 --- a/torax/tests/geometry.py +++ b/torax/tests/geometry.py @@ -76,8 +76,6 @@ def foo(geo: geometry.Geometry): def face_to_cell(nr, face): - """Reference implementation from pyntegrated model.""" - cell = np.zeros(nr) cell[:] = 0.5 * (face[1:] + face[:-1]) return cell diff --git a/torax/tests/physics.py b/torax/tests/physics.py index a735b9dc..a9372ef2 100644 --- a/torax/tests/physics.py +++ b/torax/tests/physics.py @@ -24,19 +24,19 @@ from torax import initial_states from torax import physics from torax.sources import source_profiles -from torax.tests.test_lib import pint_ref +from torax.tests.test_lib import torax_refs -class PhysicsTest(pint_ref.ReferenceValueTest): +class PhysicsTest(torax_refs.ReferenceValueTest): """Unit tests for the `torax.physics` module.""" @parameterized.parameters([ - dict(references_getter=pint_ref.circular_references), - dict(references_getter=pint_ref.chease_references_Ip_from_chease), - dict(references_getter=pint_ref.chease_references_Ip_from_config), + dict(references_getter=torax_refs.circular_references), + dict(references_getter=torax_refs.chease_references_Ip_from_chease), + dict(references_getter=torax_refs.chease_references_Ip_from_config), ]) def test_calc_q_from_psi( - self, references_getter: Callable[[], pint_ref.References] + self, references_getter: Callable[[], torax_refs.References] ): """Compare `calc_q_from_psi` function to a reference implementation.""" references = references_getter() @@ -57,7 +57,7 @@ def test_calc_q_from_psi( # Make ground truth def calc_q_from_psi(config, geo): - """Reference implementation from pyntegrated model.""" + """Reference implementation from PINT.""" consts = constants.CONSTANTS iota = np.zeros(config.nr + 1) # on face grid q = np.zeros(config.nr + 1) # on face grid @@ -69,15 +69,13 @@ def calc_q_from_psi(config, geo): / (2 * np.pi * geo.B0 * geo.r_face[1:]) ) q[1:] = 1 / iota[1:] - # Change from pyntegrated model: we don't read jtot from `geo` + # Change from PINT: we don't read jtot from `geo` q[0] = ( 2 * geo.B0 / (consts.mu0 * jtot[0] * config.Rmaj) ) # use on-axis definition of q (Wesson 2004, Eq 3.48) q *= config.q_correction_factor def face_to_cell(config, face): - """Reference implementation from pyntegrated model.""" - cell = np.zeros(config.nr) cell[:] = 0.5 * (face[1:] + face[:-1]) return cell @@ -91,12 +89,12 @@ def face_to_cell(config, face): np.testing.assert_allclose(q_cell_jax, q_cell_np) @parameterized.parameters([ - dict(references_getter=pint_ref.circular_references), - dict(references_getter=pint_ref.chease_references_Ip_from_chease), - dict(references_getter=pint_ref.chease_references_Ip_from_config), + dict(references_getter=torax_refs.circular_references), + dict(references_getter=torax_refs.chease_references_Ip_from_chease), + dict(references_getter=torax_refs.chease_references_Ip_from_config), ]) def test_initial_psi( - self, references_getter: Callable[[], pint_ref.References] + self, references_getter: Callable[[], torax_refs.References] ): """Compare `initial_psi` function to a reference implementation.""" references = references_getter() @@ -120,12 +118,12 @@ def test_initial_psi( np.testing.assert_allclose(psi, references.psi.value) @parameterized.parameters([ - dict(references_getter=pint_ref.circular_references), - dict(references_getter=pint_ref.chease_references_Ip_from_chease), - dict(references_getter=pint_ref.chease_references_Ip_from_config), + dict(references_getter=torax_refs.circular_references), + dict(references_getter=torax_refs.chease_references_Ip_from_chease), + dict(references_getter=torax_refs.chease_references_Ip_from_config), ]) def test_calc_jtot_from_psi( - self, references_getter: Callable[[], pint_ref.References] + self, references_getter: Callable[[], torax_refs.References] ): """Compare `calc_jtot_from_psi` to a reference value.""" references = references_getter() @@ -139,12 +137,12 @@ def test_calc_jtot_from_psi( np.testing.assert_allclose(j, references.jtot) @parameterized.parameters([ - dict(references_getter=pint_ref.circular_references), - dict(references_getter=pint_ref.chease_references_Ip_from_chease), - dict(references_getter=pint_ref.chease_references_Ip_from_config), + dict(references_getter=torax_refs.circular_references), + dict(references_getter=torax_refs.chease_references_Ip_from_chease), + dict(references_getter=torax_refs.chease_references_Ip_from_config), ]) def test_calc_s_from_psi( - self, references_getter: Callable[[], pint_ref.References] + self, references_getter: Callable[[], torax_refs.References] ): """Compare `calc_s_from_psi` to a reference value.""" references = references_getter() diff --git a/torax/tests/sim.py b/torax/tests/sim.py index 1844e662..db7b0bf4 100644 --- a/torax/tests/sim.py +++ b/torax/tests/sim.py @@ -12,11 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Integration tests for Tokamak heat transport in JAX. +"""TORAX integration tests. These are full integration tests that run the simulation and compare to a -PINT reference: -https://gitlab.com/qualikiz-group/pyntegrated_model/-/tree/main/config_tests +previously executed TORAX reference: """ from typing import Optional, Sequence @@ -39,12 +38,14 @@ class SimTest(sim_test_case.SimTestCase): - """Integration tests for torax.sim.""" + """Integration tests for torax.sim. + + The numbering is legacy from when a subset of these tests were compared to + PINT runs. This numbering is kept to maintain backwards compatibility for now. + """ @parameterized.named_parameters( - # Where relevant we keep test names the same as in the PINT repo since - # the names are used to look up the reference files. - # See py files for test descriptions. + # Tests explicit solver ( 'test1', 'test1.py', @@ -52,7 +53,7 @@ class SimTest(sim_test_case.SimTestCase): _ALL_PROFILES, 0, ), - # Run test2 from the PINT repo, using Crank-Nicolson. + # Tests implicit solver with theta=0.5 (Crank-Nicholson) ( 'test2_cn', 'test2_cn.py', @@ -60,6 +61,7 @@ class SimTest(sim_test_case.SimTestCase): ('temp_ion', 'temp_el'), 2e-1, ), + # Tests implicit solver with theta=1.0 (backwards Euler) ( 'test2', 'test2.py', @@ -67,7 +69,7 @@ class SimTest(sim_test_case.SimTestCase): _ALL_PROFILES, 0, ), - # Make sure that the optimizer gets the same result as the linear solver + # Tests that optimizer gets the same result as the linear solver # when coefficients are frozen. ( 'test2_optimizer', @@ -76,7 +78,7 @@ class SimTest(sim_test_case.SimTestCase): _ALL_PROFILES, 1e-5, ), - # Make sure that Newton-Raphson gets the same result as the linear solver + # Tests that Newton-Raphson gets the same result as the linear solver # when the coefficient matrix is frozen ( 'test2_newton_raphson', @@ -85,6 +87,7 @@ class SimTest(sim_test_case.SimTestCase): _ALL_PROFILES, 1e-6, ), + # Test ion-electron heat exchange at low density ( 'test3', 'test3.py', @@ -92,7 +95,7 @@ class SimTest(sim_test_case.SimTestCase): _ALL_PROFILES, 0, ), - # test3_ref exercises sim.ArrayTimeStepCalculator + # Tests sim.ArrayTimeStepCalculator ( 'test3_ref', 'test3.py', @@ -101,6 +104,7 @@ class SimTest(sim_test_case.SimTestCase): 0, True, ), + # Tests ion-electron heat exchange at high density ( 'test4', 'test4.py', @@ -108,6 +112,7 @@ class SimTest(sim_test_case.SimTestCase): _ALL_PROFILES, 0, ), + # Tests pedestal internal boundary condition ( 'test5', 'test5.py', @@ -115,6 +120,7 @@ class SimTest(sim_test_case.SimTestCase): _ALL_PROFILES, 0, ), + # Tests CGM model heat transport only ( 'test6', 'test6.py', @@ -122,15 +128,7 @@ class SimTest(sim_test_case.SimTestCase): _ALL_PROFILES, 0, ), - # Test that we are able to reproduce FiPy's behavior in a case where - # FiPy is unstable - ( - 'test6_no_pedestal', - 'test6_no_pedestal.py', - 'test6_no_pedestal', - _ALL_PROFILES, - 1e-10, - ), + # Tests QLKNN model, heat transport only ( 'test7', 'test7.py', @@ -140,6 +138,7 @@ class SimTest(sim_test_case.SimTestCase): 1e-11, False, ), + # Tests fixed_dt timestep ( 'test7_fixed_dt', 'test7_fixed_dt.py', @@ -149,6 +148,7 @@ class SimTest(sim_test_case.SimTestCase): 1e-11, False, ), + # Tests current diffusion ( 'test8', 'test8.py', @@ -156,6 +156,7 @@ class SimTest(sim_test_case.SimTestCase): _ALL_PROFILES, 0, ), + # Tests combined current diffusion + heat transport with QLKNN ( 'test9', 'test9.py', @@ -163,7 +164,7 @@ class SimTest(sim_test_case.SimTestCase): _ALL_PROFILES, 0, ), - # Make sure that the optimizer gets the same result as the linear solver + # Tests that optimizer gets the same result as the linear solver # when using linear initial guess and 0 iterations. # Making sure to use a test involving Pereverzev-Corrigan for this, # since we do want it in the linear initial guess. @@ -174,7 +175,7 @@ class SimTest(sim_test_case.SimTestCase): _ALL_PROFILES, 0, ), - # Make sure that Newton-Raphson gets the same result as the linear solver + # Tests that Newton-Raphson gets the same result as the linear solver # when using linear initial guess and 0 iterations # Making sure to use a test involving Pereverzev-Corrigan for this, # since we do want it in the linear initial guess. @@ -185,6 +186,7 @@ class SimTest(sim_test_case.SimTestCase): _ALL_PROFILES, 0, ), + # Tests bootstrap current with heat+current-diffusion. CGM model ( 'test10', 'test10.py', @@ -192,6 +194,7 @@ class SimTest(sim_test_case.SimTestCase): _ALL_PROFILES, 0, ), + # Tests bootstrap current with heat+current-diffusion. QLKNN model ( 'test11', 'test11.py', @@ -199,6 +202,7 @@ class SimTest(sim_test_case.SimTestCase): _ALL_PROFILES, 0, ), + # Tests heat+current-diffusion+particle transport with constant transport ( 'test12', 'test12.py', @@ -206,6 +210,7 @@ class SimTest(sim_test_case.SimTestCase): _ALL_PROFILES, 0, ), + # Tests particle sources with constant transport. No NBI source ( 'test13', 'test13.py', @@ -213,6 +218,7 @@ class SimTest(sim_test_case.SimTestCase): _ALL_PROFILES, 0, ), + # Tests particle sources with CGM transport. No NBI source ( 'test14', 'test14.py', @@ -220,6 +226,7 @@ class SimTest(sim_test_case.SimTestCase): _ALL_PROFILES, 0, ), + # Tests all particle sources with CGM transport ( 'test15', 'test15.py', @@ -227,6 +234,7 @@ class SimTest(sim_test_case.SimTestCase): _ALL_PROFILES, 0, ), + # Tests density transport with QLKNN. De scaled from chi_e model ( 'test16', 'test16.py', @@ -235,6 +243,7 @@ class SimTest(sim_test_case.SimTestCase): 1e-3, 5e-4, ), + # Tests density transport with QLKNN. Deff+Veff model ( 'test17', 'test17.py', @@ -243,6 +252,7 @@ class SimTest(sim_test_case.SimTestCase): 1e-5, 2e-6, ), + # Tests fusion power. CGM transport, heat+particle+psi transport ( 'test18', 'test18.py', @@ -250,6 +260,7 @@ class SimTest(sim_test_case.SimTestCase): _ALL_PROFILES, 0, ), + # Tests fusion power. QLKNN transport, heat+particle+psi transport. ( 'test19', 'test19.py', @@ -258,6 +269,7 @@ class SimTest(sim_test_case.SimTestCase): 7e-5, 5e-4, ), + # Tests explicit solver. Ti only. CHEASE geometry. ( 'test20', 'test20.py', @@ -265,6 +277,7 @@ class SimTest(sim_test_case.SimTestCase): _ALL_PROFILES, 0, ), + # Tests implicit solver. Heat transport only. CHEASE geometry. ( 'test21', 'test21.py', @@ -272,6 +285,7 @@ class SimTest(sim_test_case.SimTestCase): _ALL_PROFILES, 0, ), + # Tests ion-electron heat exchange test at low density. CHEASE geometry. ( 'test22', 'test22.py', @@ -279,6 +293,7 @@ class SimTest(sim_test_case.SimTestCase): _ALL_PROFILES, 0, ), + # Tests Ohmic electron heat source. CHEASE geometry. ( 'test22_pohm', 'test22_pohm.py', @@ -286,6 +301,7 @@ class SimTest(sim_test_case.SimTestCase): _ALL_PROFILES, 0, ), + # Tests ion-electron heat exchange test at high density. CHEASE geometry. ( 'test23', 'test23.py', @@ -293,6 +309,7 @@ class SimTest(sim_test_case.SimTestCase): _ALL_PROFILES, 0, ), + # Tests pedestal internal boundary condition. CHEASE geometry. ( 'test24', 'test24.py', @@ -300,6 +317,7 @@ class SimTest(sim_test_case.SimTestCase): _ALL_PROFILES, 0, ), + # Tests CGM transport model. Heat transport only. CHEASE geometry. ( 'test25', 'test25.py', @@ -307,6 +325,7 @@ class SimTest(sim_test_case.SimTestCase): _ALL_PROFILES, 0, ), + # Tests QLKNN transport model. Heat transport only. CHEASE geometry. ( 'test26', 'test26.py', @@ -315,6 +334,7 @@ class SimTest(sim_test_case.SimTestCase): 1e-10, 1e-10, ), + # Tests current diffusion. CHEASE geometry. Ip from parameters. ( 'test27', 'test27.py', @@ -322,6 +342,7 @@ class SimTest(sim_test_case.SimTestCase): _ALL_PROFILES, 0, ), + # Tests current diffusion. CHEASE geometry. Ip from CHEASE. ( 'test28', 'test28.py', @@ -329,6 +350,7 @@ class SimTest(sim_test_case.SimTestCase): _ALL_PROFILES, 0, ), + # Tests combined heat+current-diffusion. CHEASE geometry. QLKNN. ( 'test29', 'test29.py', @@ -336,6 +358,7 @@ class SimTest(sim_test_case.SimTestCase): _ALL_PROFILES, 0, ), + # Tests time-dependent pedestal, Ptot, Ip. CHEASE geometry. QLKNN. ( 'test29_timedependent', 'test29_timedependent.py', @@ -343,6 +366,7 @@ class SimTest(sim_test_case.SimTestCase): _ALL_PROFILES, 0, ), + # Tests bootstrap current. CHEASE geometry. Heat+current-diffusion. CGM. ( 'test30', 'test30.py', @@ -350,6 +374,8 @@ class SimTest(sim_test_case.SimTestCase): _ALL_PROFILES, 0, ), + # Tests combined heat, particle, current-diffusion. + # CHEASE geometry. Constant transport ( 'test31', 'test31.py', @@ -357,6 +383,8 @@ class SimTest(sim_test_case.SimTestCase): _ALL_PROFILES, 0, ), + # Tests combined heat, particle, current-diffusion and pedestal. + # CHEASE geometry. Constant transport ( 'test32', 'test32.py', @@ -364,6 +392,7 @@ class SimTest(sim_test_case.SimTestCase): _ALL_PROFILES, 0, ), + # Tests particle sources. CHEASE geometry. No NBI. ( 'test33', 'test33.py', @@ -371,6 +400,7 @@ class SimTest(sim_test_case.SimTestCase): _ALL_PROFILES, 0, ), + # Tests particle sources. CHEASE geometry. No NBI. CGM + pedestal. ( 'test34', 'test34.py', @@ -378,6 +408,7 @@ class SimTest(sim_test_case.SimTestCase): _ALL_PROFILES, 0, ), + # Tests all particle sources. CHEASE geometry. CGM + pedestal. ( 'test35', 'test35.py', @@ -385,6 +416,7 @@ class SimTest(sim_test_case.SimTestCase): _ALL_PROFILES, 0, ), + # Tests particle transport with QLKNN. De scaled from chie. ( 'test36', 'test36.py', @@ -393,6 +425,7 @@ class SimTest(sim_test_case.SimTestCase): 1e-3, 6e-5, ), + # Tests particle transport with QLKNN. Deff+Veff model. ( 'test37', 'test37.py', @@ -401,6 +434,7 @@ class SimTest(sim_test_case.SimTestCase): 1e-4, 2e-6, ), + # Tests Crank-Nicholson with particle transport and QLKNN. Deff+Veff ( 'test37_theta05', 'test37_theta05.py', @@ -408,6 +442,8 @@ class SimTest(sim_test_case.SimTestCase): _ALL_PROFILES, 0, ), + # Tests fusion power. CHEASE geometry. Current, heat, particle transport. + # CGM transport model. ( 'test38', 'test38.py', @@ -415,6 +451,8 @@ class SimTest(sim_test_case.SimTestCase): _ALL_PROFILES, 0, ), + # Tests Pereverzev-Corrigan method for density. CHEASE geometry. QLKNN. + # De scaled from chie. ( 'test39', 'test39.py', @@ -423,6 +461,7 @@ class SimTest(sim_test_case.SimTestCase): 7e-5, 5e-5, ), + # Tests full integration for ITER-baseline-like config. ( 'test40', 'test40.py', @@ -431,6 +470,7 @@ class SimTest(sim_test_case.SimTestCase): 7e-5, 5e-5, ), + # Tests full integration for ITER-baseline-like config. Linear solver. ( 'test41', 'test41.py', @@ -439,6 +479,7 @@ class SimTest(sim_test_case.SimTestCase): 7e-5, 5e-5, ), + # Tests full integration for ITER-hybrid-like config. Linear solver. ( 'test42', 'test42.py', @@ -447,6 +488,8 @@ class SimTest(sim_test_case.SimTestCase): 7e-5, 5e-5, ), + # Tests full integration for ITER-hybrid-like config. + # Predictor-corrector solver. ( 'test42_predictor_corrector', 'test42_predictor_corrector.py', @@ -455,14 +498,16 @@ class SimTest(sim_test_case.SimTestCase): 7e-5, 5e-5, ), + # Tests TORAX regression of test42 ( 'test42_torax', 'test42.py', 'test42_torax', _ALL_PROFILES, - 1e-12, - 1e-12, + 1e-11, + 1e-11, ), + # Tests Newton-Raphson nonlinear solver for ITER-hybrid-like-config ( 'test42_nl_Hmode', 'test42_nl_Hmode.py', @@ -472,7 +517,7 @@ class SimTest(sim_test_case.SimTestCase): 1e-6, ), ) - def test_pyntegrated( + def test_torax_sim( self, config_name: str, ref_name: str, @@ -481,11 +526,11 @@ def test_pyntegrated( atol: Optional[float] = None, use_ref_time: bool = False, ): - """Integration test comparing to reference output from PINT or TORAX.""" - # The @parameterized decorator removes the `test_pyntegrated` method, + """Integration test comparing to reference output from TORAX.""" + # The @parameterized decorator removes the `test_torax_sim` method, # so we separate the actual functionality into a helper method that will # not be removed. - self._test_pyntegrated( + self._test_torax_sim( config_name, ref_name, profiles, @@ -499,7 +544,7 @@ def test_fail(self): # Run test3 but pass in the reference result from test2 with self.assertRaises(AssertionError): - self._test_pyntegrated( + self._test_torax_sim( 'test3.py', 'test2', ('temp_ion', 'temp_el'), @@ -571,8 +616,8 @@ def test_no_op(self): self.assertEqual(history_length, t.shape[0]) self.assertGreater(t[-1], config.t_final) - for pint_profile in _ALL_PROFILES: - profile_history = state_history[pint_profile] + for torax_profile in _ALL_PROFILES: + profile_history = state_history[torax_profile] # This is needed for CellVariable but not face variables if hasattr(profile_history, 'value'): profile_history = profile_history.value @@ -587,7 +632,7 @@ def test_no_op(self): msg = ( 'Profile changed over time despite all equations being ' 'disabled.\n' - f'Profile name: {pint_profile}\n' + f'Profile name: {torax_profile}\n' f'Initial value: {first_profile}\n' f'Failing time index: {i}\n' f'Failing value: {profile_history[i]}\n' diff --git a/torax/tests/sim_no_compile.py b/torax/tests/sim_no_compile.py index 52f67d83..794c88d3 100644 --- a/torax/tests/sim_no_compile.py +++ b/torax/tests/sim_no_compile.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Tests that Torax can be run with compilation disabled.""" +"""Tests that TORAX can be run with compilation disabled.""" from typing import Optional, Sequence @@ -50,7 +50,7 @@ class SimTest(sim_test_case.SimTestCase): False, ), ) - def test_pyntegrated( + def test_torax_sim( self, config_name: str, ref_name: str, @@ -62,7 +62,7 @@ def test_pyntegrated( """No-compilation version of integration tests.""" assert not jax_utils.env_bool('TORAX_COMPILATION_ENABLED', True) - self._test_pyntegrated( + self._test_torax_sim( config_name, ref_name, profiles, diff --git a/torax/tests/state.py b/torax/tests/state.py index df04f688..5752d601 100644 --- a/torax/tests/state.py +++ b/torax/tests/state.py @@ -26,10 +26,10 @@ from torax import geometry from torax import initial_states from torax import state as state_module -from torax.tests.test_lib import pint_ref +from torax.tests.test_lib import torax_refs -class StateTest(pint_ref.ReferenceValueTest): +class StateTest(torax_refs.ReferenceValueTest): """Unit tests for the `torax.state` module.""" def setUp(self): @@ -64,13 +64,13 @@ def make_history(config, geo): self._make_history = make_history @parameterized.parameters([ - dict(references_getter=pint_ref.circular_references), - dict(references_getter=pint_ref.chease_references_Ip_from_chease), - dict(references_getter=pint_ref.chease_references_Ip_from_config), + dict(references_getter=torax_refs.circular_references), + dict(references_getter=torax_refs.chease_references_Ip_from_chease), + dict(references_getter=torax_refs.chease_references_Ip_from_config), ]) def test_sanity_check( self, - references_getter: Callable[[], pint_ref.References], + references_getter: Callable[[], torax_refs.References], ): """Make sure State.sanity_check can be called.""" references = references_getter() @@ -81,13 +81,13 @@ def test_sanity_check( basic_state.sanity_check() @parameterized.parameters([ - dict(references_getter=pint_ref.circular_references), - dict(references_getter=pint_ref.chease_references_Ip_from_chease), - dict(references_getter=pint_ref.chease_references_Ip_from_config), + dict(references_getter=torax_refs.circular_references), + dict(references_getter=torax_refs.chease_references_Ip_from_chease), + dict(references_getter=torax_refs.chease_references_Ip_from_config), ]) def test_index( self, - references_getter: Callable[[], pint_ref.References], + references_getter: Callable[[], torax_refs.References], ): """Test State.index.""" references = references_getter() @@ -97,13 +97,13 @@ def test_index( self.assertEqual(i, history.index(i).temp_ion.value[0]) @parameterized.parameters([ - dict(references_getter=pint_ref.circular_references), - dict(references_getter=pint_ref.chease_references_Ip_from_chease), - dict(references_getter=pint_ref.chease_references_Ip_from_config), + dict(references_getter=torax_refs.circular_references), + dict(references_getter=torax_refs.chease_references_Ip_from_chease), + dict(references_getter=torax_refs.chease_references_Ip_from_config), ]) def test_project( self, - references_getter: Callable[[], pint_ref.References], + references_getter: Callable[[], torax_refs.References], ): """Test State.project.""" references = references_getter() diff --git a/torax/tests/test_lib/sim_test_case.py b/torax/tests/test_lib/sim_test_case.py index 771a313a..a132665c 100644 --- a/torax/tests/test_lib/sim_test_case.py +++ b/torax/tests/test_lib/sim_test_case.py @@ -229,7 +229,7 @@ def _check_profiles_vs_expected( raise AssertionError(final_msg) - def _test_pyntegrated( + def _test_torax_sim( self, config_name: str, ref_name: str, @@ -238,7 +238,7 @@ def _test_pyntegrated( atol: Optional[float] = None, use_ref_time: bool = False, ): - """Integration test comparing to reference output from PINT or TORAX. + """Integration test comparing to TORAX reference output. Args: config_name: Name of py config to load. (Leave off dir path, include @@ -249,9 +249,6 @@ def _test_pyntegrated( rtol: Optional float, to override the class level rtol. atol: Optional float, to override the class level atol. use_ref_time: If True, locks to time steps calculated by reference. - - Raises: - SkipTest: in the case of a known discrepancy with FiPy """ if rtol is None: diff --git a/torax/tests/test_lib/pint_ref.py b/torax/tests/test_lib/torax_refs.py similarity index 83% rename from torax/tests/test_lib/pint_ref.py rename to torax/tests/test_lib/torax_refs.py index 35765a95..736b8863 100644 --- a/torax/tests/test_lib/pint_ref.py +++ b/torax/tests/test_lib/torax_refs.py @@ -12,13 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Shared setup code for unit tests using reference values from PINT.""" +"""Shared setup code for unit tests using reference values.""" import os from absl.testing import absltest from absl.testing import parameterized import chex -import fipy from jax import numpy as jnp import numpy as np import torax @@ -44,12 +43,6 @@ class References: s: np.ndarray -# TODO(b/323504363): Eventually, we will want to get reference values and h5 -# files from Torax itself. We will want the ability to test against Torax for -# regression testing. This will include removing PINT-related flags for -# comparison. - - def circular_references() -> References: """Reference values for circular geometry.""" # Hard-code the parameters relevant to the tests, so the reference values @@ -75,7 +68,7 @@ def circular_references() -> References: kappa=1.72, hires_fac=4, ) - # ground truth values copied from an example PINT execution using + # ground truth values copied from example executions using # array.astype(str),which allows fully lossless reloading psi = fvm.CellVariable( value=jnp.array( @@ -109,7 +102,7 @@ def circular_references() -> References: ), right_face_grad_constraint=jnp.array(53.182574789531735), dr=geo.dr_norm, - ) # TODO revisit these tests after general geometry is done + ) psi_face_grad = np.array([ '0.0', '10.227178846458628', @@ -379,7 +372,7 @@ def chease_references_Ip_from_config() -> References: # pylint: disable=invalid geometry_file='ITER_hybrid_citrin_equil_cheasedata.mat2cols', Ip_from_parameters=True, ) - # ground truth values copied from an example PINT execution using + # ground truth values copied from an example executions using # array.astype(str),which allows fully lossless reloading psi = fvm.CellVariable( value=jnp.array( @@ -413,7 +406,7 @@ def chease_references_Ip_from_config() -> References: # pylint: disable=invalid ), right_face_grad_constraint=jnp.array(64.25482269382654), dr=geo.dr_norm, - ) # TODO revisit these tests after general geometry is done + ) psi_face_grad = np.array([ '0.0', '7.329120928506605', @@ -507,15 +500,13 @@ def chease_references_Ip_from_config() -> References: # pylint: disable=invalid ) -# TODO(b/323504363): Might be able to remove this test class completely since all -# the references are constants. class ReferenceValueTest(parameterized.TestCase): - """Unit using reference values from PINT.""" + """Unit using reference values from previous executions.""" def setUp(self): super().setUp() - # Some reference values from pyntegrated model are used in more than one - # test. These are loaded here. + # Some pre-calculated reference values are used in more than one test. + # These are loaded here. self.circular_references = circular_references() # pylint: disable=invalid-name self.chease_references_with_Ip_from_chease = ( @@ -526,74 +517,5 @@ def setUp(self): ) # pylint: enable=invalid-name - -def convert_cell_var_torax_to_fipy( - torax_var: torax.fvm.CellVariable, -) -> fipy.CellVariable: - """Convert a Torax CellVariable to a FiPy CellVariable. - - Args: - torax_var: The Torax variable to convert. - - Returns: - fipy_var: The FiPy equivalent of that variable. - """ - - mesh = fipy.Grid1D(nx=torax_var.value.shape[0], dx=torax_var.dr) - fipy_var = fipy.CellVariable(mesh=mesh, value=torax_var.value) - if torax_var.left_face_constraint is not None: - fipy_var.constrain(torax_var.left_face_constraint, mesh.facesLeft) - if torax_var.left_face_grad_constraint is not None: - fipy_var.faceGrad.constrain( - torax_var.left_face_grad_constraint, mesh.facesLeft - ) - if torax_var.right_face_constraint is not None: - fipy_var.constrain(torax_var.right_face_constraint, mesh.facesRight) - if torax_var.right_face_grad_constraint is not None: - fipy_var.faceGrad.constrain( - torax_var.right_face_grad_constraint, mesh.facesRight - ) - return fipy_var - - -class Profiles: - """A class analogous to State from Torax. - - This class is used to support some reference code from pyntegrated_model - used to provide ground truth for tests. - - Attributes: - Ti: Analogous to `temp_ion` from Torax's `State`. - ne: Analogous to `ne` from Torax's `State`. - ni: Analogous to `ni` from Torax's `State`. - """ - - # pyntegrated_model doesn't follow Google style - # pylint:disable=invalid-name - - def __init__( - self, Ti: fipy.CellVariable, ne: fipy.CellVariable, ni: fipy.CellVariable - ): - self.Ti = Ti - self.ne = ne - self.ni = ni - - -def state_to_profiles(state: torax.state.State): - """Converts a Torax `State` to a `Profiles` for use with pyntegrated model. - - Args: - state: The Torax State - - Returns: - profiles: The Profiles for use with pyntegrated model. - """ - return Profiles( - Ti=convert_cell_var_torax_to_fipy(state.temp_ion), - ne=convert_cell_var_torax_to_fipy(state.ne), - ni=convert_cell_var_torax_to_fipy(state.ni), - ) - - if __name__ == '__main__': absltest.main()