Skip to content

Commit

Permalink
update forward_pass_logit_checker for new dimension
Browse files Browse the repository at this point in the history
  • Loading branch information
A9isha committed Jan 17, 2025
1 parent 6d8a42b commit 6a6b270
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions MaxText/tests/forward_pass_logit_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,11 +118,12 @@ def main(config, test_args):
max_logging.log(f"{golden_logits[2]=}")
max_logging.log(f"{full_train_logits[0, 2, :]=}")
token_size = int(test_args.token_size) if test_args.token_size else golden_logits.shape[0]
# The ellipsis is used to currently support jax nightly versions newer than 1/9/2025 and stable tests. This can be simplified later
max_logging.log(
f"Max Numerical Difference {np.max(np.subtract(full_train_logits[0, :token_size, :], golden_logits[:token_size, :]))}"
f"Max Numerical Difference {np.max(np.subtract(full_train_logits[..., 0, :token_size, :], golden_logits[:token_size, :]))}"
)

model_probabilities = jax.nn.softmax(full_train_logits[0, :token_size, :], axis=-1)
model_probabilities = jax.nn.softmax(full_train_logits[..., 0, :token_size, :], axis=-1)
golden_probabilities = jax.nn.softmax(golden_logits[:token_size, :], axis=-1)

max_logging.log(f"{golden_probabilities[1]=}")
Expand All @@ -139,7 +140,7 @@ def main(config, test_args):
else:
max_logging.log("Checking Numerical Differences between train logits and golden logits")
assert jax.numpy.allclose(
full_train_logits[0, :token_size, :],
full_train_logits[..., 0, :token_size, :],
golden_logits[:token_size, :],
rtol=float(test_args.rtol),
atol=float(test_args.atol),
Expand Down

0 comments on commit 6a6b270

Please sign in to comment.