MFU drops significantly when using megablox with more experts #1256
Comments
Thanks for reaching out! It seems you have already tuned the general tile size a bit (here), but note that the best size can vary a lot with TPU type, topology (slice size), and model config. With a fixed model config, you could write a script that searches for the best tile_sizes. When using the FSDP sharding strategy (the default settings), larger batch sizes will definitely help performance. As a next step, you could turn on the profiling option and see which operation (or extra communication) slows the run down. See here for more details about JAX profiling.
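A minimal sketch of capturing such a trace with `jax.profiler.trace` (the `train_step`, `state`, and `batch` names below are placeholders for your own training loop, not MaxText internals):

```python
import jax

# Wrap a few steps in a profiler trace so the slowest ops (and any extra
# communication) show up in TensorBoard/Perfetto under /tmp/jax-trace.
with jax.profiler.trace("/tmp/jax-trace"):
    for _ in range(5):
        state, loss = train_step(state, batch)  # your own step function
    loss.block_until_ready()  # make sure the traced work actually finishes
```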
Hi @RissyRan, sorry for taking so long... jnp.take is the slowest operation when using 56 experts (14 active per token, mlp_dim=2048). Here are the top 10 operations, in case you are curious: For reference, these are the top 10 operations when using 8 experts (2 active per token, mlp_dim=14336), which gives higher MFU (30% instead of 20% on a v4-64):
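For context, here is a rough sketch of where a gather like `jnp.take` typically comes from in a top-k MoE layer. This is purely illustrative (not the MaxText/megablox implementation); all names are made up. With 14 active experts per token, the permutation gathers 7x more rows than with top-2, which is consistent with `jnp.take` climbing to the top of the profile:

```python
import jax.numpy as jnp
from jax import lax

def route_tokens(tokens, router_logits, num_experts_per_token):
    # tokens: [num_tokens, hidden], router_logits: [num_tokens, num_experts]
    weights, expert_ids = lax.top_k(router_logits, num_experts_per_token)
    # Group (token, expert) pairs by expert so each expert processes a
    # contiguous block; the permutation is a gather (jnp.take) over tokens.
    flat_expert_ids = expert_ids.reshape(-1)          # [num_tokens * k]
    sort_order = jnp.argsort(flat_expert_ids)         # group by expert id
    token_ids = jnp.repeat(jnp.arange(tokens.shape[0]), num_experts_per_token)
    permuted_tokens = jnp.take(tokens, token_ids[sort_order], axis=0)
    return permuted_tokens, weights, sort_order
```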
Thanks for the info! Yes, ideally we should see pallas_call among the top operations. Our team is working on a DeepSeek-like model config and has onboarded some functional features recently. We are also working on optimizing performance over the next few weeks. I will reply back once we have some benchmarks, if that is ok.
Awesome, thank you very much for your great work!
I'm testing Mixtral-8x7B without attention so I can isolate the effects of the MoE layer.
When num_experts=8 and num_experts_per_token=2, MFU on a v5p-64 is 50.4%, which is good.
However, I wanted to test an architecture that is more similar to DeepSeek's, which uses more experts.
So I increased the number of experts from 8 to 56 (7x increase), the number of experts per token from 2 to 14 (7x increase), and decreased moe_intermediate_size from 14336 to 2048 (7x decrease). This keeps the same total and active parameter counts as Mixtral (a quick check is sketched below).
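A quick back-of-the-envelope check of that claim, with the numbers copied from above (plain Python, just arithmetic):

```python
# Expert parameter volume scales with num_experts * mlp_dim,
# and the active volume per token with experts_per_token * mlp_dim.
mixtral = dict(num_experts=8,  experts_per_token=2,  mlp_dim=14336)
variant = dict(num_experts=56, experts_per_token=14, mlp_dim=2048)

for cfg in (mixtral, variant):
    total  = cfg["num_experts"] * cfg["mlp_dim"]          # proportional to total expert params
    active = cfg["experts_per_token"] * cfg["mlp_dim"]    # proportional to active params per token
    print(total, active)   # both configs print: 114688 28672
```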
The problem is that in this new architecture with more experts, the MFU drops to 28%!
(BTW, I tried multiple configs with different tile_sizes, TPU sizes and batch sizes, all leading to 25-28% MFU)
Any help is much appreciated.
(cc'ing @sharadmv @RissyRan @lenscloth who might be interested in this problem)