
Conversation

@amirkl94 (Contributor) commented on Oct 21, 2025

Purpose

When running the cutlass FusedMoE FP8 path, the scaling factors that are passed are None. This PR passes the correct scaling factors and enables the relevant test.

Test Plan

Enabled the previously disabled test test_flashinfer_cutlass_moe_fp8_no_graph.
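
For context, a minimal runnable sketch of the scale wiring this PR restores, using stand-in tensors in place of the layer attributes; the keyword names mirror the diff hunks quoted in the review below, while the enclosing call, the variable names, and the values are assumptions for illustration only:

```python
import torch

# Stand-ins for the layer's static FP8 activation scales (made-up values);
# in the diff these come from layer.w13_input_scale and layer.w2_input_scale.
w13_input_scale = torch.tensor(0.02, dtype=torch.float32)
w2_input_scale = torch.tensor(0.05, dtype=torch.float32)

# Previously these entries ended up as None; the fix forwards the real scales.
# a2_gscale is the reciprocal of the second activation scale, as in the diff.
scale_kwargs = dict(
    a1_scale=w13_input_scale,
    a1_gscale=w13_input_scale,
    a2_scale=w2_input_scale,
    a2_gscale=1.0 / w2_input_scale,
)
print(scale_kwargs["a2_gscale"])  # tensor(20.)
```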


Signed-off-by: Amir Klein <[email protected]>
@gemini-code-assist bot left a comment

Code Review

This pull request correctly addresses a bug in the Cutlass FP8 FusedMoE implementation by passing the necessary scaling factors. The changes are logical and enabling the previously skipped test test_flashinfer_cutlass_moe_fp8_no_graph validates the fix. However, I've identified a critical risk of a division-by-zero error in the calculation of the a2_gscale factor, which should be addressed to ensure numerical stability.

a1_scale=layer.w13_input_scale,
a1_gscale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
a2_gscale=1.0 / layer.w2_input_scale,
gemini-code-assist bot (critical):

The calculation 1.0 / layer.w2_input_scale introduces a risk of a division-by-zero error if layer.w2_input_scale is zero. Although scales are typically positive, adding a small epsilon to the denominator is a crucial safeguard for numerical stability.

Suggested change
a2_gscale=1.0 / layer.w2_input_scale,
a2_gscale=1.0 / (layer.w2_input_scale + 1e-6),
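
To illustrate the suggested guard, a small self-contained sketch; the helper name, the epsilon value, and the example scales are illustrative and not part of the PR:

```python
import torch

def guarded_reciprocal(scale: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Reciprocal with an epsilon added to the denominator, mirroring the suggestion above.
    return 1.0 / (scale + eps)

w2_input_scale = torch.tensor(0.0125, dtype=torch.float32)
print(guarded_reciprocal(w2_input_scale))     # ~79.99 instead of exactly 80
print(guarded_reciprocal(torch.tensor(0.0)))  # 1e6 instead of inf
```

Note that with the guard, a2_gscale is no longer the exact reciprocal of a2_scale; for typical positive scales the perturbation is negligible.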

a1_scale=td.a1_scale,
a1_gscale=td.a1_scale,
a2_scale=td.a2_scale,
a2_gscale=1.0 / td.a2_scale,
gemini-code-assist bot (high):

To prevent potential division-by-zero errors and for consistency with the recommended fix in the main logic, it's safer to add a small epsilon to the denominator here. While td.a2_scale is currently 1.0 in this test, this change improves the robustness of the test suite against future modifications.

Suggested change
a2_gscale=1.0 / td.a2_scale,
a2_gscale=1.0 / (td.a2_scale + 1e-6),

@mgoin added the bug (Something isn't working) and ready (ONLY add when PR is ready to merge/full CI is needed) labels on Oct 24, 2025.
@amirkl94 requested a review from tlrmchlsmth on November 2, 2025, 09:38.
@amirkl94 (Contributor, Author) commented on Nov 2, 2025

@tlrmchlsmth Looks like CI failed due to something unrelated to my PR

Signed-off-by: Amir Klein <[email protected]>
@amirkl94 (Contributor, Author) commented on Nov 2, 2025

@wenscarl This reverts a change you made in #27223; can you please take a look?

@mgoin changed the title from "Bugfix: Cutlass FP8 FusedMoE" to "Bugfix: Cutlass FP8 FusedMoE bad scaling factors" on Nov 4, 2025.
@mgoin (Member) left a comment:

LGTM, please validate @wenscarl

a1, topk_weights, topk_ids, apply_router_weight_on_input
)
if not self.use_dp:
if not self.use_dp and quant_config.quant_dtype != torch.float8_e4m3fn:
Contributor:

I suggest using a more descriptive condition, for example:

 if not self.use_dp and quant_config.quant_dtype == nvfp4:
     return a1, None, ...,  

per_act_token_quant: bool = False,
per_out_ch_quant: bool = False,
block_shape: list[int] | None = None,
a1_gscale: torch.Tensor | None = None,
Contributor:

The additional 4 scales were not present before #27223, and there was no issue without them. Is it possible to deduce them from the others?

amirkl94 (Contributor, Author):

As far as I could tell, no

Signed-off-by: Amir Klein <[email protected]>
@amirkl94 requested a review from wenscarl on November 4, 2025, 15:31.

return fp8_w8a8_moe_quant_config(
w1_scale=layer.w13_weight_scale,
g1_alphas=(layer.w13_weight_scale * layer.w13_input_scale).squeeze(),
@wenscarl (Contributor) commented on Nov 4, 2025:

The g1_alphas can be computed from w1_scale and a1_scale here, right? Same in test_flashinfer.py.

amirkl94 (Contributor, Author):

I'd rather not move this calculation inside the function, as I'm not sure whether other paths will require different factors in the future. If it's important, I can move it so it is calculated in fp8_w8a8_moe_quant_config.
@wenscarl wdyt?

Contributor:

Sounds reasonable. Let's leave it as it is.
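
For reference, a self-contained illustration of the relationship discussed in this thread, with made-up shapes and values; it only demonstrates the algebra in the quoted diff line (g1_alphas derived from the weight and activation scales), not vLLM's actual config code:

```python
import torch

num_experts = 8
w13_weight_scale = torch.full((num_experts, 1), 0.5)  # per-expert weight scale (made up)
w13_input_scale = torch.tensor(0.02)                   # static activation scale (made up)

# As in the diff above: g1_alphas combines the weight and activation scales.
g1_alphas = (w13_weight_scale * w13_input_scale).squeeze()
print(g1_alphas.shape)  # torch.Size([8]) -> one alpha per expert
print(g1_alphas[0])     # tensor(0.0100) = 0.5 * 0.02
```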

@mgoin merged commit 6b7a811 into vllm-project:main on Nov 5, 2025.
56 checks passed