
Fix MoE related tests #1064

Merged: 1 commit merged into main from fix_8x22_ckpt on Nov 26, 2024
Conversation

@RissyRan (Collaborator) commented Nov 26, 2024

Description

  • Add weight normalization and update the checkpoint to the unscanned version for faster decoding (a brief sketch of the routing-weight normalization follows below)
  • Disable checkpoint-related tests for 8x22b until b/380148614 is fixed (otherwise, with the wrong checkpoint, they always fail in XLML)

If the change fixes a bug or a GitHub issue, please include a link, e.g.:
FIXES: b/379315309
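
For context, here is a minimal sketch of the kind of top-k routing-weight normalization referred to above, in JAX. The function name and shapes are illustrative assumptions, not MaxText's actual implementation:

import jax
import jax.numpy as jnp

def route_tokens(router_logits, num_experts_per_tok=2):
    # Illustrative sketch only: select the top-k experts per token, then
    # renormalize the selected routing weights so they sum to 1
    # (the "weights normalization" mentioned in the description).
    top_weights, expert_indices = jax.lax.top_k(router_logits, num_experts_per_tok)
    top_weights = jax.nn.softmax(top_weights.astype(jnp.float32), axis=-1)
    return top_weights, expert_indices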

Tests

  • Tested locally:
      • 8x7b decoding: link
      • 8x7b forward_pass_logit_checker: link, link

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed.

@RissyRan force-pushed the fix_8x22_ckpt branch 4 times, most recently from 06e75bb to ac86c96, on November 26, 2024 08:16

# Test whether the forward pass logits match the golden logits - megablox implementation
python3 MaxText/tests/forward_pass_logit_checker.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} load_parameters_path=${SCANNED_CHECKPOINT} run_name=megablox_forward_pass_test per_device_batch_size=8 model_name=mixtral-8x7b tokenizer_path=assets/tokenizer.mistral-v1 ici_fsdp_parallelism=64 max_prefill_predict_length=11 max_target_length=11 dataset_type=synthetic dtype=bfloat16 weight_dtype=bfloat16 --atol=4 --rtol=1 --token_size=4
# TODO(ranran): investigate the root cause of the excessive tolerance
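
As a rough illustration of what the --atol/--rtol/--token_size flags above control, the comparison is conceptually similar to the hypothetical sketch below (the real logic lives in MaxText/tests/forward_pass_logit_checker.py):

import numpy as np

def logits_match(model_logits, golden_logits, atol=4.0, rtol=1.0, token_size=4):
    # Hypothetical sketch: compare the first `token_size` token positions of
    # the model's logits against stored golden logits, within the given
    # absolute/relative tolerances.
    return np.allclose(model_logits[:token_size], golden_logits[:token_size],
                       atol=atol, rtol=rtol)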
Collaborator commented:

Is b/380148614 the root cause of the excessive tolerance?

@RissyRan (Collaborator, Author) replied:

Actually, I am not sure. Using the current ckpt, the matmul implementation matches the logits well. Something seems off, as we were able to match the logits a while back.

@copybara-service bot merged commit 135575f into main on Nov 26, 2024
19 checks passed
@copybara-service bot deleted the fix_8x22_ckpt branch on November 26, 2024 21:35