-
Notifications
You must be signed in to change notification settings - Fork 310
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Update forward_pass_logit_checker for new dimension #1175
base: main
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Awesome thanks Anisha!
Don't we have to check for jax version, or does this change support both old and new? Will the tests targeting jax stable break with this change?
E.g. in the past sometimes we have had code like
if jax.version >= X
do_new_thing
else:
do_old_thing
Is that needed here?
5ec38ee
to
1899c69
Compare
Thanks Matt. I thought further comparing the tests for Mistral and Llama2/Gemma2 and realized Jax actually handles dimension of value 1, so this updated way is a cleaner solution
|
863ca64
to
f7d7e76
Compare
@AlipMagical Could you please give a bit more context? |
e6c4d5f
to
6a6b270
Compare
-- f7d7e76 by A9isha <[email protected]>: update forward_pass_logit_checker for new dimension COPYBARA_INTEGRATE_REVIEW=#1175 from AI-Hypercomputer:anisha-mistral-fixit f7d7e76 PiperOrigin-RevId: 716734551
Description
There is an added dimension for
jax.experimental.multihost_utils.process_allgather(full_train_logits)
, so updatingforward_pass_logit_checker
for new dimensionFIXES: b/389749060
Tests
Locally reran forward_pass_logit_checker for mistral
Checklist
Before submitting this PR, please make sure (put X in square brackets):