Skip to content

Commit

Permalink
update forward_pass_logit_checker for new dimension
Browse files Browse the repository at this point in the history
  • Loading branch information
A9isha committed Jan 17, 2025
1 parent 3ad02ba commit 5ec38ee
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions MaxText/tests/forward_pass_logit_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,10 +119,10 @@ def main(config, test_args):
max_logging.log(f"{full_train_logits[0, 2, :]=}")
token_size = int(test_args.token_size) if test_args.token_size else golden_logits.shape[0]
max_logging.log(
f"Max Numerical Difference {np.max(np.subtract(full_train_logits[0, :token_size, :], golden_logits[:token_size, :]))}"
f"Max Numerical Difference {np.max(np.subtract(full_train_logits[0, 0, :token_size, :], golden_logits[:token_size, :]))}"
)

model_probabilities = jax.nn.softmax(full_train_logits[0, :token_size, :], axis=-1)
model_probabilities = jax.nn.softmax(full_train_logits[0, 0, :token_size, :], axis=-1)
golden_probabilities = jax.nn.softmax(golden_logits[:token_size, :], axis=-1)

max_logging.log(f"{golden_probabilities[1]=}")
Expand All @@ -139,7 +139,7 @@ def main(config, test_args):
else:
max_logging.log("Checking Numerical Differences between train logits and golden logits")
assert jax.numpy.allclose(
full_train_logits[0, :token_size, :],
full_train_logits[0, 0, :token_size, :],
golden_logits[:token_size, :],
rtol=float(test_args.rtol),
atol=float(test_args.atol),
Expand Down

0 comments on commit 5ec38ee

Please sign in to comment.