Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update forward_pass_logit_checker for new dimension #1175

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

A9isha
Copy link
Collaborator

@A9isha A9isha commented Jan 17, 2025

Description

There is an added dimension for jax.experimental.multihost_utils.process_allgather(full_train_logits), so updating forward_pass_logit_checker for new dimension

FIXES: b/389749060

Tests

Locally reran forward_pass_logit_checker for mistral

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed.

Copy link
Collaborator

@gobbleturk gobbleturk left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Awesome thanks Anisha!

Don't we have to check for jax version, or does this change support both old and new? Will the tests targeting jax stable break with this change?
E.g. in the past sometimes we have had code like

if jax.version >= X
    do_new_thing
else:
    do_old_thing

Is that needed here?

@A9isha A9isha force-pushed the anisha-mistral-fixit branch from 5ec38ee to 1899c69 Compare January 17, 2025 17:39
@A9isha
Copy link
Collaborator Author

A9isha commented Jan 17, 2025

Thanks Matt. I thought further comparing the tests for Mistral and Llama2/Gemma2 and realized Jax actually handles dimension of value 1, so this updated way is a cleaner solution

Awesome thanks Anisha!

Don't we have to check for jax version, or does this change support both old and new? Will the tests targeting jax stable break with this change? E.g. in the past sometimes we have had code like

if jax.version >= X
    do_new_thing
else:
    do_old_thing

Is that needed here?

@A9isha A9isha force-pushed the anisha-mistral-fixit branch 2 times, most recently from 863ca64 to f7d7e76 Compare January 17, 2025 18:52
@A9isha
Copy link
Collaborator Author

A9isha commented Jan 17, 2025

Jax actually handles dimension of value 1, so this updated way is a cleaner solution

@AlipMagical Could you please give a bit more context?

@AI-Hypercomputer AI-Hypercomputer deleted a comment from AlipMagical Jan 17, 2025
@AI-Hypercomputer AI-Hypercomputer deleted a comment from AlipMagical Jan 17, 2025
@AI-Hypercomputer AI-Hypercomputer deleted a comment from AlipMagical Jan 17, 2025
@AI-Hypercomputer AI-Hypercomputer deleted a comment from AlipMagical Jan 17, 2025
@AI-Hypercomputer AI-Hypercomputer deleted a comment from AlipMagical Jan 17, 2025
@AI-Hypercomputer AI-Hypercomputer deleted a comment from AlipMagical Jan 17, 2025
@A9isha A9isha force-pushed the anisha-mistral-fixit branch from e6c4d5f to 6a6b270 Compare January 17, 2025 22:11
copybara-service bot pushed a commit that referenced this pull request Jan 17, 2025
--
f7d7e76 by A9isha <[email protected]>:

update forward_pass_logit_checker for new dimension

COPYBARA_INTEGRATE_REVIEW=#1175 from AI-Hypercomputer:anisha-mistral-fixit f7d7e76
PiperOrigin-RevId: 716734551
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants