diff --git a/MaxText/tests/forward_pass_logit_checker.py b/MaxText/tests/forward_pass_logit_checker.py index 5557dbc93..cd23748c1 100644 --- a/MaxText/tests/forward_pass_logit_checker.py +++ b/MaxText/tests/forward_pass_logit_checker.py @@ -118,11 +118,12 @@ def main(config, test_args): max_logging.log(f"{golden_logits[2]=}") max_logging.log(f"{full_train_logits[0, 2, :]=}") token_size = int(test_args.token_size) if test_args.token_size else golden_logits.shape[0] + # The ellipsis is currently used to support both jax nightly versions newer than 1/9/2025 and the stable tests. This can be simplified later. max_logging.log( - f"Max Numerical Difference {np.max(np.subtract(full_train_logits[0, :token_size, :], golden_logits[:token_size, :]))}" + f"Max Numerical Difference {np.max(np.subtract(full_train_logits[..., 0, :token_size, :], golden_logits[:token_size, :]))}" ) - model_probabilities = jax.nn.softmax(full_train_logits[0, :token_size, :], axis=-1) + model_probabilities = jax.nn.softmax(full_train_logits[..., 0, :token_size, :], axis=-1) golden_probabilities = jax.nn.softmax(golden_logits[:token_size, :], axis=-1) max_logging.log(f"{golden_probabilities[1]=}") @@ -139,7 +140,7 @@ def main(config, test_args): else: max_logging.log("Checking Numerical Differences between train logits and golden logits") assert jax.numpy.allclose( - full_train_logits[0, :token_size, :], + full_train_logits[..., 0, :token_size, :], golden_logits[:token_size, :], rtol=float(test_args.rtol), atol=float(test_args.atol),