diff --git a/MaxText/tests/forward_pass_logit_checker.py b/MaxText/tests/forward_pass_logit_checker.py
index 5557dbc93..a8035a9e2 100644
--- a/MaxText/tests/forward_pass_logit_checker.py
+++ b/MaxText/tests/forward_pass_logit_checker.py
@@ -119,10 +119,10 @@ def main(config, test_args):
   max_logging.log(f"{full_train_logits[0, 2, :]=}")
   token_size = int(test_args.token_size) if test_args.token_size else golden_logits.shape[0]
   max_logging.log(
-      f"Max Numerical Difference {np.max(np.subtract(full_train_logits[0, :token_size, :], golden_logits[:token_size, :]))}"
+      f"Max Numerical Difference {np.max(np.subtract(full_train_logits[0, 0, :token_size, :], golden_logits[:token_size, :]))}"
   )
 
-  model_probabilities = jax.nn.softmax(full_train_logits[0, :token_size, :], axis=-1)
+  model_probabilities = jax.nn.softmax(full_train_logits[0, 0, :token_size, :], axis=-1)
   golden_probabilities = jax.nn.softmax(golden_logits[:token_size, :], axis=-1)
 
   max_logging.log(f"{golden_probabilities[1]=}")
@@ -139,7 +139,7 @@ def main(config, test_args):
   else:
     max_logging.log("Checking Numerical Differences between train logits and golden logits")
     assert jax.numpy.allclose(
-        full_train_logits[0, :token_size, :],
+        full_train_logits[0, 0, :token_size, :],
         golden_logits[:token_size, :],
         rtol=float(test_args.rtol),
         atol=float(test_args.atol),