Bugfix: Cutlass FP8 FusedMoE bad scaling factors #27255
```diff
@@ -463,16 +463,24 @@ def fp8_w8a8_moe_quant_config(
     per_act_token_quant: bool = False,
     per_out_ch_quant: bool = False,
     block_shape: list[int] | None = None,
+    a1_gscale: torch.Tensor | None = None,
```
Contributor: The additional 4 scales were not present before #27223, and there was no issue without them. Is it possible to deduce them from the others?
Contributor (Author): As far as I could tell, no.
```diff
+    a2_gscale: torch.Tensor | None = None,
+    g1_alphas: torch.Tensor | None = None,
+    g2_alphas: torch.Tensor | None = None,
 ) -> FusedMoEQuantConfig:
     """
     Construct a quant config for fp8 activations and fp8 weights.
     """
     return FusedMoEQuantConfig.make(
         torch.float8_e4m3fn,
         w1_scale=w1_scale,
+        g1_alphas=g1_alphas,
         w2_scale=w2_scale,
+        g2_alphas=g2_alphas,
         a1_scale=a1_scale,
+        a1_gscale=a1_gscale,
         a2_scale=a2_scale,
+        a2_gscale=a2_gscale,
         per_act_token_quant=per_act_token_quant,
         per_out_ch_quant=per_out_ch_quant,
         block_shape=block_shape,
```
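For context, a condensed sketch of how a caller supplies the four new arguments, mirroring the layer code changed later in this PR. The import path, expert count, tensor shapes, and dummy scale values are assumptions for illustration only:

```python
import torch

# Assumed import path for the helper shown in the diff above; adjust if the
# function lives elsewhere in your vLLM version.
from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config

num_experts = 8                               # illustrative
w13_weight_scale = torch.rand(num_experts)    # per-expert weight scales (w1/w3)
w2_weight_scale = torch.rand(num_experts)     # per-expert weight scales (w2)
w13_input_scale = torch.tensor([0.5])         # static activation scale, first GEMM
w2_input_scale = torch.tensor([0.25])         # static activation scale, second GEMM

quant_config = fp8_w8a8_moe_quant_config(
    w1_scale=w13_weight_scale,
    w2_scale=w2_weight_scale,
    a1_scale=w13_input_scale,
    a2_scale=w2_input_scale,
    # New in this PR: factors for the CUTLASS grouped-GEMM path, precomputed
    # at the call site the same way the layer diff below does it.
    g1_alphas=(w13_weight_scale * w13_input_scale).squeeze(),
    g2_alphas=(w2_weight_scale * w2_input_scale).squeeze(),
    a1_gscale=w13_input_scale,
    a2_gscale=1.0 / w2_input_scale,
    per_act_token_quant=False,
)
```

Real scale tensors come from the quantized layer and may need a different layout than these dummies; the point is only which arguments are now expected and how they relate to the existing scales.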
```diff
@@ -570,9 +570,13 @@ def get_fused_moe_quant_config(
         return fp8_w8a8_moe_quant_config(
             w1_scale=layer.w13_weight_scale,
+            g1_alphas=(layer.w13_weight_scale * layer.w13_input_scale).squeeze(),
```
Contributor: The g1_alphas can be computed from w1_scale and a1_scale here, right? Same in ...

Contributor (Author): I'd rather not move this calculation inside the function, as I'm not sure whether other paths will require different factors in the future. If it's important, then sure, I can move it to be calculated in ...

Contributor: Sounds reasonable. Let's leave it as it is.
```diff
             w2_scale=layer.w2_weight_scale,
+            g2_alphas=(layer.w2_weight_scale * layer.w2_input_scale).squeeze(),
             a1_scale=layer.w13_input_scale,
+            a1_gscale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
+            a2_gscale=1.0 / layer.w2_input_scale,
```
Contributor: The calculation `a2_gscale=1.0 / layer.w2_input_scale` can divide by zero; the suggested change adds a small epsilon to the denominator.
```diff
             per_act_token_quant=False,
         )
```
```diff
@@ -1159,8 +1163,8 @@ def __init__(
         moe: FusedMoEConfig,
         layer: torch.nn.Module,
     ) -> None:
-        from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import (  # noqa: E501
-            detect_nvfp4_moe_support,
+        from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import (
+            detect_nvfp4_moe_support,  # noqa: E501
         )

         super().__init__(moe)
```
To prevent potential division-by-zero errors, and for consistency with the recommended fix in the main logic, it's safer to add a small epsilon to the denominator here. While `td.a2_scale` is currently 1.0 in this test, this change improves the robustness of the test suite against future modifications.
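A minimal sketch of the guarded reciprocal suggested in this comment and in the earlier comment on the main logic; the choice of epsilon (`torch.finfo(...).eps`) and the helper name are assumptions, not code from this PR:

```python
import torch


def safe_reciprocal(scale: torch.Tensor) -> torch.Tensor:
    """Reciprocal with a small epsilon in the denominator, to avoid division
    by zero in expressions like a2_gscale = 1.0 / w2_input_scale."""
    eps = torch.finfo(scale.dtype).eps  # assumed epsilon choice
    return 1.0 / (scale + eps)


# With a scale of 1.0 (as td.a2_scale currently is in the test), the guarded
# value is effectively unchanged.
print(safe_reciprocal(torch.tensor([1.0])))  # ~tensor([1.0000])
```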