Several tests failing with AttributeError: module 'jax.api_util' has no attribute 'debug_info' #4585

Open
GaetanLepage opened this issue Feb 28, 2025 · 0 comments

Context: bumping flax to 0.10.4 on nixpkgs: NixOS/nixpkgs#385676

System information

  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): NixOS unstable
  • Flax, jax, jaxlib versions (obtain with pip show flax jax jaxlib): flax 0.10.4, jax 0.5.0, jaxlib 0.5.0
  • Python version: 3.12.9
  • GPU/TPU model and memory: None
  • CUDA version (if applicable):

Problem you have encountered:

Several tests now fail with AttributeError: module 'jax.api_util' has no attribute 'debug_info':

=========================== short test summary info ============================
FAILED tests/core/core_meta_test.py::MetaTest::test_scan_over_layers - AttributeError: module 'jax.api_util' has no attribute 'debug_info'
FAILED tests/core/core_scope_test.py::ScopeTest::test_lazy_init - AttributeError: module 'jax.api_util' has no attribute 'debug_info'
FAILED tests/core/core_scope_test.py::ScopeTest::test_lazy_init_fails_on_data_dependence - AttributeError: module 'jax.api_util' has no attribute 'debug_info'
FAILED tests/core/design/core_big_resnets_test.py::BigResnetTest::test_big_resnet - AttributeError: module 'jax.api_util' has no attribute 'debug_info'
FAILED tests/linen/linen_module_test.py::ModuleTest::test_lazy_init - AttributeError: module 'jax.api_util' has no attribute 'debug_info'
FAILED tests/linen/linen_module_test.py::ModuleTest::test_lazy_init_fails_on_data_dependence - AttributeError: module 'jax.api_util' has no attribute 'debug_info'
FAILED tests/linen/linen_meta_test.py::LinenMetaTest::test_pjit_scan_over_layers - AttributeError: module 'jax.api_util' has no attribute 'debug_info'
FAILED tests/linen/linen_recurrent_test.py::RNNTest::test_basic_seq_lengths - AttributeError: module 'jax.api_util' has no attribute 'debug_info'
FAILED tests/linen/linen_recurrent_test.py::RNNTest::test_numerical_equivalence - AttributeError: module 'jax.api_util' has no attribute 'debug_info'
FAILED tests/linen/linen_recurrent_test.py::RNNTest::test_numerical_equivalence_single_batch - AttributeError: module 'jax.api_util' has no attribute 'debug_info'
FAILED tests/linen/linen_recurrent_test.py::RNNTest::test_numerical_equivalence_single_batch_nn_scan - AttributeError: module 'jax.api_util' has no attribute 'debug_info'
FAILED tests/linen/linen_recurrent_test.py::RNNTest::test_numerical_equivalence_with_mask - AttributeError: module 'jax.api_util' has no attribute 'debug_info'
FAILED tests/linen/linen_recurrent_test.py::RNNTest::test_reverse - AttributeError: module 'jax.api_util' has no attribute 'debug_info'
FAILED tests/linen/linen_recurrent_test.py::RNNTest::test_reverse_but_keep_order - AttributeError: module 'jax.api_util' has no attribute 'debug_info'
FAILED tests/linen/linen_recurrent_test.py::RNNTest::test_rnn_basic_forward - AttributeError: module 'jax.api_util' has no attribute 'debug_info'
FAILED tests/linen/linen_recurrent_test.py::RNNTest::test_rnn_multiple_batch_dims - AttributeError: module 'jax.api_util' has no attribute 'debug_info'
FAILED tests/linen/linen_transforms_test.py::TransformTest::test_multi_method_class_transform - AttributeError: module 'jax.api_util' has no attribute 'debug_info'
FAILED tests/linen/linen_recurrent_test.py::RNNTest::test_rnn_time_major - AttributeError: module 'jax.api_util' has no attribute 'debug_info'
FAILED tests/linen/linen_recurrent_test.py::RNNTest::test_rnn_unroll - AttributeError: module 'jax.api_util' has no attribute 'debug_info'
FAILED tests/linen/linen_recurrent_test.py::RNNTest::test_rnn_with_spatial_dimensions - AttributeError: module 'jax.api_util' has no attribute 'debug_info'
FAILED tests/linen/linen_recurrent_test.py::BidirectionalTest::test_bidirectional - AttributeError: module 'jax.api_util' has no attribute 'debug_info'
FAILED tests/linen/linen_recurrent_test.py::BidirectionalTest::test_custom_merge_fn - AttributeError: module 'jax.api_util' has no attribute 'debug_info'
FAILED tests/linen/linen_recurrent_test.py::BidirectionalTest::test_return_carry - AttributeError: module 'jax.api_util' has no attribute 'debug_info'
FAILED tests/linen/linen_recurrent_test.py::BidirectionalTest::test_shared_cell - AttributeError: module 'jax.api_util' has no attribute 'debug_info'
FAILED tests/linen/linen_transforms_test.py::TransformTest::test_remat_scan - AttributeError: module 'jax.api_util' has no attribute 'debug_info'
FAILED tests/core/design/core_scan_test.py::ScanTest::test_scan_shared_params - AttributeError: module 'jax.api_util' has no attribute 'debug_info'
FAILED tests/core/design/core_scan_test.py::ScanTest::test_scan_unshared_params - AttributeError: module 'jax.api_util' has no attribute 'debug_info'
FAILED tests/linen/linen_transforms_test.py::TransformTest::test_same_key - AttributeError: module 'jax.api_util' has no attribute 'debug_info'
FAILED tests/linen/summary_test.py::SummaryTest::test_lifted_transform - AttributeError: module 'jax.api_util' has no attribute 'debug_info'
FAILED tests/linen/summary_test.py::SummaryTest::test_lifted_transform_no_rename - AttributeError: module 'jax.api_util' has no attribute 'debug_info'
FAILED tests/linen/linen_transforms_test.py::TransformTest::test_scan - AttributeError: module 'jax.api_util' has no attribute 'debug_info'
FAILED tests/linen/linen_transforms_test.py::TransformTest::test_scan_compact_count - AttributeError: module 'jax.api_util' has no attribute 'debug_info'
FAILED tests/linen/linen_transforms_test.py::TransformTest::test_scan_decorated - AttributeError: module 'jax.api_util' has no attribute 'debug_info'
FAILED tests/linen/linen_transforms_test.py::TransformTest::test_jit_scan_retracing_retracing scan - AttributeError: module 'jax.api_util' has no attribute 'debug_info'
FAILED tests/linen/linen_transforms_test.py::TransformTest::test_scan_negative_axes - AttributeError: module 'jax.api_util' has no attribute 'debug_info'
FAILED tests/linen/linen_transforms_test.py::TransformTest::test_scan_of_setup_parameter - AttributeError: module 'jax.api_util' has no attribute 'debug_info'
FAILED tests/linen/linen_transforms_test.py::TransformTest::test_toplevel_submodule_adoption_pytree_transform - AttributeError: module 'jax.api_util' has no attribute 'debug_info'
FAILED tests/linen/partitioning_test.py::PartitioningTest::test_scan_with_axes - AttributeError: module 'jax.api_util' has no attribute 'debug_info'
FAILED tests/nnx/nn/recurrent_test.py::TestRNN::test_rnn_equivalence_with_flax_linen - AttributeError: module 'jax.api_util' has no attribute 'debug_info'
================= 39 failed, 1997 passed, 5 skipped in 48.74s ==================

What you expected to happen:

All tests pass.

Logs, error messages, etc:

...
>     debug_info = jax.api_util.debug_info("flax scan", broadcast_body,
                                           (in_tree,), {})
E     AttributeError: module 'jax.api_util' has no attribute 'debug_info'
E     --------------------
E     For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

flax/core/axes_scan.py:159: AttributeError
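
The traceback suggests that the installed jax does not expose the `debug_info` attribute that flax/core/axes_scan.py looks up on `jax.api_util`. A minimal sketch for confirming this in a given environment, independently of the test suite, is below. The helper name `module_has_attr` is hypothetical (it is not part of flax or jax) and uses only the standard library:

```python
import importlib
from typing import Optional


def module_has_attr(module_name: str, attr: str) -> Optional[bool]:
    """Report whether an importable module exposes a given attribute.

    Returns True or False when the module imports, and None when the
    module itself cannot be imported (e.g. the package is not installed).
    """
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        return None
    return hasattr(module, attr)


if __name__ == "__main__":
    # Hypothetical check for the attribute named in the error message.
    print(module_has_attr("jax.api_util", "debug_info"))
```

A result of `False` in the environment described above would match the AttributeError reported by the failing tests.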

Steps to reproduce:

$ pytest