Bugfix: Cutlass FP8 FusedMoE bad scaling factors #27255
Conversation
When running Cutlass FusedMoE FP8, the scaling factors that are passed are None. This PR passes the correct scaling factors and enables the relevant test.
Signed-off-by: Amir Klein <[email protected]>
Code Review
This pull request correctly addresses a bug in the Cutlass FP8 FusedMoE implementation by passing the necessary scaling factors. The changes are logical, and enabling the previously skipped test test_flashinfer_cutlass_moe_fp8_no_graph validates the fix. However, I've identified a critical risk of a division-by-zero error in the calculation of the a2_gscale factor, which should be addressed to ensure numerical stability.
a1_scale=layer.w13_input_scale,
a1_gscale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
a2_gscale=1.0 / layer.w2_input_scale,
The calculation 1.0 / layer.w2_input_scale introduces a risk of a division-by-zero error if layer.w2_input_scale is zero. Although scales are typically positive, adding a small epsilon to the denominator is a crucial safeguard for numerical stability.
Suggested change:
- a2_gscale=1.0 / layer.w2_input_scale,
+ a2_gscale=1.0 / (layer.w2_input_scale + 1e-6),
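For illustration, a minimal standalone sketch of the kind of guard this comment asks for, assuming the scale is a per-tensor torch.Tensor; the helper name, the epsilon value, and the example scale are illustrative and not part of the PR:

import torch

def safe_reciprocal(scale: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Clamp the denominator away from zero before inverting so the
    # resulting global scale factor can never become inf/nan.
    return 1.0 / torch.clamp(scale, min=eps)

w2_input_scale = torch.tensor([0.0125])  # hypothetical per-tensor input scale
a2_gscale = safe_reciprocal(w2_input_scale)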
a1_scale=td.a1_scale,
a1_gscale=td.a1_scale,
a2_scale=td.a2_scale,
a2_gscale=1.0 / td.a2_scale,
To prevent potential division-by-zero errors and for consistency with the recommended fix in the main logic, it's safer to add a small epsilon to the denominator here. While td.a2_scale is currently 1.0 in this test, this change improves the robustness of the test suite against future modifications.
Suggested change:
- a2_gscale=1.0 / td.a2_scale,
+ a2_gscale=1.0 / (td.a2_scale + 1e-6),
@tlrmchlsmth Looks like CI failed due to something unrelated to my PR.
LGTM, please validate @wenscarl
    a1, topk_weights, topk_ids, apply_router_weight_on_input
)
- if not self.use_dp:
+ if not self.use_dp and quant_config.quant_dtype != torch.float8_e4m3fn:
I suggest using a more descriptive condition, for example:
if not self.use_dp and quant_config.quant_dtype == nvfp4:
    return a1, None, ...
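To make the contrast concrete, a rough sketch of the two conditions side by side; NVFP4_MARKER is a hypothetical placeholder, since this thread does not show how the FP4 path is actually encoded in quant_config.quant_dtype:

import torch

NVFP4_MARKER = "nvfp4"  # placeholder sentinel, not the real constant

def early_return_current(use_dp: bool, quant_dtype) -> bool:
    # Condition as written in the diff: take the early-return path for
    # anything that is not FP8.
    return not use_dp and quant_dtype != torch.float8_e4m3fn

def early_return_suggested(use_dp: bool, quant_dtype) -> bool:
    # Reviewer's suggestion: name the NVFP4 path explicitly instead of
    # excluding FP8, which keeps the intent readable as dtypes are added.
    return not use_dp and quant_dtype == NVFP4_MARKER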
per_act_token_quant: bool = False,
per_out_ch_quant: bool = False,
block_shape: list[int] | None = None,
a1_gscale: torch.Tensor | None = None,
The additional 4 scales were not present before #27223, and there was no issue without them. Is it possible to deduce them from the others?
As far as I could tell, no
return fp8_w8a8_moe_quant_config(
    w1_scale=layer.w13_weight_scale,
    g1_alphas=(layer.w13_weight_scale * layer.w13_input_scale).squeeze(),
The g1_alphas can be computed from w1_scale and a1_scale here, right? Same in test_flashinfer.py.
I'd rather not move this calculation inside the function, as I'm not sure whether other paths will require different factors in the future. If it's important, I can move it to be calculated in fp8_w8a8_moe_quant_config.
@wenscarl wdyt?
Sounds reasonable. Let's leave it as it is.
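For reference, a minimal sketch of what folding the alpha computation into the config helper could look like; the function name, signature, and returned structure are assumptions, not the actual fp8_w8a8_moe_quant_config API:

import torch

def build_fp8_moe_quant_config(w1_scale: torch.Tensor,
                               a1_scale: torch.Tensor,
                               **extra) -> dict:
    # Derive g1_alphas inside the helper, mirroring the caller-side
    # expression from the diff:
    #   (layer.w13_weight_scale * layer.w13_input_scale).squeeze()
    g1_alphas = (w1_scale * a1_scale).squeeze()
    return {"w1_scale": w1_scale, "a1_scale": a1_scale,
            "g1_alphas": g1_alphas, **extra}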
Purpose
When running Cutlass FusedMoE FP8, the scaling factors that are passed are None. This PR passes the correct scaling factors and enables the relevant test.
Test Plan
Enabled the previously disabled test_flashinfer_cutlass_moe_fp8_no_graph.
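For reference, the re-enabled test can be run locally with pytest, e.g. pytest -k test_flashinfer_cutlass_moe_fp8_no_graph from the repository root; the exact test file path is not shown in this thread, so selecting by name is the safest invocation.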