
MFU drops significantly when using megablox with more experts #1256

Open
rodrigo-f-nogueira opened this issue Feb 9, 2025 · 4 comments
@rodrigo-f-nogueira

I'm testing Mixtral-8x7B without attention so I can isolate the effects of the MoE layer.

When num_experts=8 and num_experts_per_token=2, MFU on a v5p-64 is 50.4%, which is good.

However, I wanted to test an architecture that is more similar to DeepSeek's, which uses more experts.

So I increased the number of experts from 8 to 56 (7x increase), increased the number of experts per token from 2 to 14 (7x increase), and decreased the moe_intermediate_size from 14336 to 2048 (7x decrease). This keeps the same total and active parameter counts as Mixtral.
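As a quick sanity check on the parameter-count claim, here is a back-of-the-envelope calculation (it assumes Mixtral's hidden size of 4096 and counts only the expert gate/up/down projections; attention and router weights are ignored):

```python
# Rough check that total/active expert parameters match between the two configs.
# Assumes Mixtral-style experts with gate/up/down projections (3 matrices each)
# and hidden size 4096; attention and router weights are ignored.
hidden = 4096

def expert_params(num_experts, top_k, mlp_dim):
    per_expert = 3 * hidden * mlp_dim   # gate + up + down projections
    total = num_experts * per_expert    # all experts
    active = top_k * per_expert         # experts touched per token
    return total, active

print(expert_params(8, 2, 14336))   # Mixtral-8x7B baseline
print(expert_params(56, 14, 2048))  # 56-expert variant: identical totals
```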

The problem is that in this new architecture with more experts, the MFU drops to 28%!

(BTW, I tried multiple configs with different tile_sizes, TPU sizes, and batch sizes, all leading to 25-28% MFU.)

Any help is much appreciated.

(cc'ing @sharadmv @RissyRan @lenscloth who might be interested in this problem)

@RissyRan
Collaborator

Thanks for reaching out! It seems you have already tuned the general tile size a bit (here), but I'd like to mention that the best size can differ quite a lot depending on the TPU type, topology (slice size), and model config. So, for a fixed model config, you could write a script to search for the best tile_sizes. When using the FSDP sharding strategy (the default setting), large batch sizes will definitely help performance. The next step could be to turn on this profiling option and see which operation (or extra communication) slows down the test. See more details here about JAX profiling.
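For the profiling step, something like this minimal sketch works (the jitted matmul below just stands in for the real MaxText train step; inspect the resulting trace in TensorBoard/Perfetto to see which ops dominate):

```python
import jax
import jax.numpy as jnp

# Minimal sketch: wrap a few steps in a profiler trace and look at the trace
# in TensorBoard (or Perfetto) to find the dominant operations.
# The jitted matmul is a placeholder for the actual MaxText train step.
@jax.jit
def step(x, w):
    return jnp.tanh(x @ w)

x = jnp.ones((1024, 1024))
w = jnp.ones((1024, 1024))

with jax.profiler.trace("/tmp/jax-trace"):
    for _ in range(5):
        x = step(x, w)
    x.block_until_ready()  # make sure async dispatch finishes inside the trace
```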

@rodrigo-f-nogueira
Author

Hi @RissyRan, sorry for taking so long...

jnp.take is the slowest operation when using 56 experts (with 14 active per token and mlp_dim=2048):
https://github.com/AI-Hypercomputer/maxtext/blob/main/MaxText/layers/linears.py#L404
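For intuition, here is a simplified sketch of why that gather gets so expensive (this is not the actual MaxText dispatch code; it only mimics the shape of the jnp.take): each token is gathered once per active expert, so going from 2 to 14 active experts makes the gathered tensor 7x larger.

```python
import jax.numpy as jnp

# Simplified sketch of token dispatch (not the actual MaxText code):
# every token is gathered once per active expert, so the jnp.take output has
# seq_len * num_experts_per_token rows.
seq_len, hidden = 4096, 4096
tokens = jnp.ones((seq_len, hidden))

def dispatch(tokens, top_k):
    # In the real kernel tokens are sorted by expert; here we only replicate
    # them to show how large the gather becomes.
    idx = jnp.repeat(jnp.arange(tokens.shape[0]), top_k)
    return jnp.take(tokens, idx, axis=0)

print(dispatch(tokens, 2).shape)    # (8192, 4096): Mixtral, 2 active experts
print(dispatch(tokens, 14).shape)   # (57344, 4096): 7x more data moved
```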

Here are the top 10 operations, in case you are curious:
[Screenshot: top 10 operations profile for the 56-expert config]

As a reference, these are the top 10 operations when using 8 experts (with 2 active per token and mlp_dim=14336), which gives higher MFU (30% instead of 20% on a v4-64):

[Screenshot: top 10 operations profile for the 8-expert config]

@RissyRan
Collaborator

Thanks for the info! Yes, ideally we should see pallas_call among the top operations. Our team is working on a DeepSeek-like model config and has onboarded some functional features recently. We are also working on optimizing performance over the next few weeks. I will reply back once we have some benchmarks, if that is ok.

@rodrigo-f-nogueira
Author

Awesome, thank you very much for your great work!
