diff --git a/pyproject.toml b/pyproject.toml index a33970e6..57e2e254 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,7 @@ dependencies = [ "numpy>=1.24.1", "setuptools;python_version>='3.10'", "chex>=0.1.85", - "equinox @ git+https://github.com/patrick-kidger/equinox@1e601672d38d2c4d483535070a3572d8e8508a20", + "equinox>=0.11.3", "PyYAML>=6.0.1", "xarray>=2023.12.0", "netcdf4>=1.6.5,<1.7.1", diff --git a/torax/config/tests/runtime_params_slice.py b/torax/config/tests/runtime_params_slice.py index 34e8af64..844fd5d6 100644 --- a/torax/config/tests/runtime_params_slice.py +++ b/torax/config/tests/runtime_params_slice.py @@ -305,7 +305,7 @@ def test_wext_in_dynamic_runtime_params_cannot_be_negative(self): ) np.testing.assert_allclose(jext.wext, 0.0) # But negative values will cause an error. - with self.assertRaises(jax.lib.xla_client.XlaRuntimeError): + with self.assertRaises(RuntimeError): dcs_provider(t=1.0,) @parameterized.parameters( diff --git a/torax/sources/tests/external_current_source.py b/torax/sources/tests/external_current_source.py index afc54ab9..a357049d 100644 --- a/torax/sources/tests/external_current_source.py +++ b/torax/sources/tests/external_current_source.py @@ -15,7 +15,6 @@ """Tests for external_current_source.""" from absl.testing import absltest -import jax from torax import geometry from torax.config import runtime_params as general_runtime_params from torax.config import runtime_params_slice @@ -79,7 +78,7 @@ def test_invalid_source_types_raise_errors(self): source = source_builder() for unsupported_mode in self._unsupported_modes: with self.subTest(unsupported_mode.name): - with self.assertRaises(jax.lib.xla_client.XlaRuntimeError): + with self.assertRaises(RuntimeError): source_builder.runtime_params.mode = unsupported_mode dynamic_slice = runtime_params_slice.DynamicRuntimeParamsSliceProvider( runtime_params, diff --git a/torax/sources/tests/qei_source.py b/torax/sources/tests/qei_source.py index ef253863..da2b379d 100644 --- a/torax/sources/tests/qei_source.py +++ b/torax/sources/tests/qei_source.py @@ -16,7 +16,6 @@ import dataclasses from absl.testing import absltest -import jax from torax import core_profile_setters from torax import geometry from torax.config import runtime_params as general_runtime_params @@ -106,7 +105,7 @@ def test_invalid_source_types_raise_errors(self): ) for unsupported_mode in self._unsupported_modes: with self.subTest(unsupported_mode.name): - with self.assertRaises(jax.lib.xla_client.XlaRuntimeError): + with self.assertRaises(RuntimeError): dynamic_slice = runtime_params_slice.DynamicRuntimeParamsSliceProvider( runtime_params, sources={ diff --git a/torax/sources/tests/source.py b/torax/sources/tests/source.py index 3181480b..0771ccf0 100644 --- a/torax/sources/tests/source.py +++ b/torax/sources/tests/source.py @@ -17,7 +17,6 @@ import dataclasses from absl.testing import absltest from absl.testing import parameterized -import jax from jax import numpy as jnp import numpy as np from torax import core_profile_setters @@ -195,7 +194,7 @@ def test_unsupported_modes_raise_errors(self): source_models=source_models, ) # But calling requesting ZERO shouldn't work. - with self.assertRaises(jax.lib.xla_client.XlaRuntimeError): + with self.assertRaises(RuntimeError): source.get_value( dynamic_runtime_params_slice=dynamic_runtime_params_slice, dynamic_source_runtime_params=dynamic_runtime_params_slice.sources[ diff --git a/torax/sources/tests/test_lib.py b/torax/sources/tests/test_lib.py index d8b8829d..a2e0344b 100644 --- a/torax/sources/tests/test_lib.py +++ b/torax/sources/tests/test_lib.py @@ -158,7 +158,7 @@ def test_invalid_source_types_raise_errors(self): ) ) with self.subTest(unsupported_mode.name): - with self.assertRaises(jax.lib.xla_client.XlaRuntimeError): + with self.assertRaises(RuntimeError): source.get_value( dynamic_runtime_params_slice=dynamic_runtime_params_slice, dynamic_source_runtime_params=dynamic_runtime_params_slice.sources[ @@ -248,7 +248,7 @@ def test_invalid_source_types_raise_errors(self): ) ) with self.subTest(unsupported_mode.name): - with self.assertRaises(jax.lib.xla_client.XlaRuntimeError): + with self.assertRaises(RuntimeError): source.get_value( dynamic_runtime_params_slice=dynamic_runtime_params_slice, dynamic_source_runtime_params=dynamic_runtime_params_slice.sources[ diff --git a/torax/tests/jax_utils.py b/torax/tests/jax_utils.py index 53ad9c1e..d9fa039d 100644 --- a/torax/tests/jax_utils.py +++ b/torax/tests/jax_utils.py @@ -16,7 +16,6 @@ from absl.testing import absltest from absl.testing import parameterized -import jax from jax import numpy as jnp from torax import jax_utils @@ -29,8 +28,8 @@ def _should_error(self): x = jnp.array(0) cond = x == 0 - with self.assertRaises(jax.lib.xla_extension.XlaRuntimeError): - x = jax_utils.error_if(x, cond, msg="") + with self.assertRaises(RuntimeError): + jax_utils.error_if(x, cond, msg="") def _should_not_error(self): """Call error_if, expecting it to be disabled. @@ -41,7 +40,7 @@ def _should_not_error(self): x = jnp.array(0) cond = x == 0 - x = jax_utils.error_if(x, cond, msg="") + jax_utils.error_if(x, cond, msg="") def test_enable_errors(self): """Test that jax_utils.enable_errors enables / disables errors."""