Skip to content

Commit

Permalink
Update equinox to >=0.11.3.
Browse files Browse the repository at this point in the history
This was to fix breaking changes in recent versions of JAX (0.4.32 and 0.4.33) which are supported.

PiperOrigin-RevId: 682333073
  • Loading branch information
Nush395 authored and Torax team committed Oct 4, 2024
1 parent 1e6c284 commit 94c21db
Show file tree
Hide file tree
Showing 7 changed files with 10 additions and 14 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ dependencies = [
"numpy>=1.24.1",
"setuptools;python_version>='3.10'",
"chex>=0.1.85",
"equinox @ git+https://github.com/patrick-kidger/equinox@1e601672d38d2c4d483535070a3572d8e8508a20",
"equinox>=0.11.3",
"PyYAML>=6.0.1",
"xarray>=2023.12.0",
"netcdf4>=1.6.5,<1.7.1",
Expand Down
2 changes: 1 addition & 1 deletion torax/config/tests/runtime_params_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ def test_wext_in_dynamic_runtime_params_cannot_be_negative(self):
)
np.testing.assert_allclose(jext.wext, 0.0)
# But negative values will cause an error.
with self.assertRaises(jax.lib.xla_client.XlaRuntimeError):
with self.assertRaises(RuntimeError):
dcs_provider(t=1.0,)

@parameterized.parameters(
Expand Down
3 changes: 1 addition & 2 deletions torax/sources/tests/external_current_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
"""Tests for external_current_source."""

from absl.testing import absltest
import jax
from torax import geometry
from torax.config import runtime_params as general_runtime_params
from torax.config import runtime_params_slice
Expand Down Expand Up @@ -79,7 +78,7 @@ def test_invalid_source_types_raise_errors(self):
source = source_builder()
for unsupported_mode in self._unsupported_modes:
with self.subTest(unsupported_mode.name):
with self.assertRaises(jax.lib.xla_client.XlaRuntimeError):
with self.assertRaises(RuntimeError):
source_builder.runtime_params.mode = unsupported_mode
dynamic_slice = runtime_params_slice.DynamicRuntimeParamsSliceProvider(
runtime_params,
Expand Down
3 changes: 1 addition & 2 deletions torax/sources/tests/qei_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

import dataclasses
from absl.testing import absltest
import jax
from torax import core_profile_setters
from torax import geometry
from torax.config import runtime_params as general_runtime_params
Expand Down Expand Up @@ -106,7 +105,7 @@ def test_invalid_source_types_raise_errors(self):
)
for unsupported_mode in self._unsupported_modes:
with self.subTest(unsupported_mode.name):
with self.assertRaises(jax.lib.xla_client.XlaRuntimeError):
with self.assertRaises(RuntimeError):
dynamic_slice = runtime_params_slice.DynamicRuntimeParamsSliceProvider(
runtime_params,
sources={
Expand Down
3 changes: 1 addition & 2 deletions torax/sources/tests/source.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import dataclasses
from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax import numpy as jnp
import numpy as np
from torax import core_profile_setters
Expand Down Expand Up @@ -195,7 +194,7 @@ def test_unsupported_modes_raise_errors(self):
source_models=source_models,
)
# But calling requesting ZERO shouldn't work.
with self.assertRaises(jax.lib.xla_client.XlaRuntimeError):
with self.assertRaises(RuntimeError):
source.get_value(
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
dynamic_source_runtime_params=dynamic_runtime_params_slice.sources[
Expand Down
4 changes: 2 additions & 2 deletions torax/sources/tests/test_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def test_invalid_source_types_raise_errors(self):
)
)
with self.subTest(unsupported_mode.name):
with self.assertRaises(jax.lib.xla_client.XlaRuntimeError):
with self.assertRaises(RuntimeError):
source.get_value(
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
dynamic_source_runtime_params=dynamic_runtime_params_slice.sources[
Expand Down Expand Up @@ -248,7 +248,7 @@ def test_invalid_source_types_raise_errors(self):
)
)
with self.subTest(unsupported_mode.name):
with self.assertRaises(jax.lib.xla_client.XlaRuntimeError):
with self.assertRaises(RuntimeError):
source.get_value(
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
dynamic_source_runtime_params=dynamic_runtime_params_slice.sources[
Expand Down
7 changes: 3 additions & 4 deletions torax/tests/jax_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax import numpy as jnp
from torax import jax_utils

Expand All @@ -29,8 +28,8 @@ def _should_error(self):
x = jnp.array(0)
cond = x == 0

with self.assertRaises(jax.lib.xla_extension.XlaRuntimeError):
x = jax_utils.error_if(x, cond, msg="")
with self.assertRaises(RuntimeError):
jax_utils.error_if(x, cond, msg="")

def _should_not_error(self):
"""Call error_if, expecting it to be disabled.
Expand All @@ -41,7 +40,7 @@ def _should_not_error(self):
x = jnp.array(0)
cond = x == 0

x = jax_utils.error_if(x, cond, msg="")
jax_utils.error_if(x, cond, msg="")

def test_enable_errors(self):
"""Test that jax_utils.enable_errors enables / disables errors."""
Expand Down

0 comments on commit 94c21db

Please sign in to comment.