base: hybrid-ep
Support nvfp4 low latency mode dispatch #341
Conversation
@shifangx - can you explain how to build this from source?
Hello, @ishandhanani, thank you for your attention to our work.
May I ask about the quantization method of the FP4 model? It seems that you use 16 elements per group instead of 128 to reduce the accuracy loss, but I still wonder about the quantization method and how its performance compares with the original FP8 model. Also, may I ask about the computation type of the following GEMM? It seems that the activation and weight are FP4; on Blackwell, the result might be FP32-accumulated, so how does it work with an 8-bit scale? Is there any possibility it works on Hopper (FP4 dequantized to FP8 might require some scale transform)? If you could share the FP4 GEMM application on Hopper and Blackwell, it would be a great help. Thanks.
Hi @DoubleClark, if you are interested in FP4 training, perhaps these blogs can provide some help.
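For readers unfamiliar with the format: NVFP4 quantizes values to FP4 (E2M1) in groups of 16 elements, with a per-group FP8 (E4M3) scale plus a tensor-level FP32 global scale, which is what makes the small 16-element groups affordable. A minimal pure-PyTorch sketch of the idea (not this PR's kernel; the function name, grid-snapping loop, and clamps are illustrative):

```python
import torch

FP4_MAX = 6.0    # largest magnitude representable in E2M1
FP8_MAX = 448.0  # largest magnitude representable in E4M3

def nvfp4_quantize_sketch(x: torch.Tensor, group_size: int = 16):
    """Illustrative NVFP4-style quantization: one FP32 global scale for
    the tensor, one FP8 (E4M3) scale per 16-element group, FP4 payload.
    Real kernels pack two 4-bit codes per byte; here we keep floats."""
    groups = x.float().view(-1, group_size)
    # Global scale maps the largest possible per-group scale into E4M3 range.
    global_scale = (FP8_MAX * FP4_MAX) / x.float().abs().max().clamp(min=1e-12)
    # Per-group scale, rounded by the cast to FP8 (E4M3).
    amax = groups.abs().amax(dim=1, keepdim=True)
    scales_fp8 = (amax * global_scale / FP4_MAX).to(torch.float8_e4m3fn)
    # Snap each element to the nearest point on the E2M1 grid.
    scaled = groups * global_scale / scales_fp8.float().clamp(min=1e-12)
    grid = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])
    idx = (scaled.abs().unsqueeze(-1) - grid).abs().argmin(dim=-1)
    q_fp4 = grid[idx] * scaled.sign()
    return q_fp4.view_as(x), scales_fp8, global_scale

# Dequantize: x ~= q_fp4 * scales_fp8 / global_scale (broadcast per group).
```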
Change x_sf_scale to x_global_scales. Change use_ue8m0_for_sf to use_ue8m0_for_nvfp4_x_scale. Set the x_scale dtype to torch::kFloat8_e4m3fn if use_ue8m0_for_nvfp4_x_scale == False and torch::kUInt8 if use_ue8m0_for_nvfp4_x_scale == True.
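In other words, the scale tensor's dtype now reflects which scale encoding is in use. A hedged Python-side sketch of that selection (the helper name is illustrative; the dtypes follow the commit message above):

```python
import torch

def nvfp4_x_scale_dtype(use_ue8m0_for_nvfp4_x_scale: bool) -> torch.dtype:
    # UE8M0 scales are raw 8-bit exponents, so they travel as uint8;
    # otherwise the per-group scales are FP8 (E4M3) values.
    return torch.uint8 if use_ue8m0_for_nvfp4_x_scale else torch.float8_e4m3fn
```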
const auto dim1_offset = j / num_elems_per_pack;
const auto dim4_offset = j % num_elems_per_pack;
auto scale = ld_nc_global(src_scales + j);
const auto offset = dim0_offset * dim0_stride + dim1_offset * dim1_stride + dim2_offset * dim2_stride + dim3_offset * dim3_stride + dim4_offset;
qq: looks like the physical layout is 6D, thus curious why we only have 5 dim here
Thanks for your kind review.
recv_x_scales[offset] = scale;
recv_x_scales is only for one expert, so its layout is 5D.
oh i see, looks reasonable
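For context, the kernel above is flattening a packed scale index: the contiguous scale index j splits into a packed-group index (dim1) and a within-pack index (dim4) of the 5D per-expert layout. A small sketch of the same arithmetic (dimension names and strides mirror the snippet; concrete values are illustrative):

```python
def scale_offset(dim0_offset: int, dim2_offset: int, dim3_offset: int,
                 j: int, num_elems_per_pack: int,
                 dim0_stride: int, dim1_stride: int,
                 dim2_stride: int, dim3_stride: int) -> int:
    # Split the contiguous scale index j into pack / within-pack parts,
    # exactly as in the CUDA snippet above.
    dim1_offset = j // num_elems_per_pack
    dim4_offset = j % num_elems_per_pack
    return (dim0_offset * dim0_stride + dim1_offset * dim1_stride +
            dim2_offset * dim2_stride + dim3_offset * dim3_stride +
            dim4_offset)
```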
For the FP4 quantization, this PR refers to cvt_warp_fp16_to_fp4. For the scale layout and shape, this PR refers to test_quantize_to_fp4_grouped.
The accuracy issue is fixed now.
@shifangx Anything blocking this merge?
This PR supports nvfp4 low latency mode dispatch.

We use the following message package format while dispatching tokens:

