@@ -200,7 +200,7 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
          int* packed_recv_count,
          int* cumulative_local_expert_recv_stats,
          int64_t* dispatch_wait_recv_cost_stats,
-         const float* x_global_scales,
+         const float* x_sf_scale,
          void* rdma_recv_x, int* rdma_recv_count, void* rdma_x,
          const void* x, const int64_t* topk_idx,
          int* atomic_counter_per_expert, int* atomic_finish_counter_per_expert,
@@ -275,8 +275,8 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
         float SFScaleVal = 1.0f;
         if constexpr (kUseNVFP4) {
             // Get scaling value;
-            EP_DEVICE_ASSERT(x_global_scales != nullptr);
-            SFScaleVal = *(static_cast<const float*>(x_global_scales));
+            EP_DEVICE_ASSERT(x_sf_scale != nullptr);
+            SFScaleVal = *(static_cast<const float*>(x_sf_scale));
         }

         // FP8 or NVFP4 cast
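Note on `x_sf_scale`: it is a single device-side float that every dispatching thread reads as `SFScaleVal`. Below is a minimal host-side sketch of the two-level NVFP4 scaling convention this usually corresponds to; the `448 * 6 / amax` formula and the helper names are illustrative assumptions, not taken from this change.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>

// Hypothetical helper: derive the global NVFP4 scale (what x_sf_scale would point to).
// Assumed convention: 448 is the E4M3 max magnitude and 6 is the E2M1 (FP4) max, so
// per-block scale factors of block_amax / 6 * global_scale stay inside the E4M3 range.
inline float compute_global_sf_scale(const float* x, size_t n) {
    float amax = 0.0f;
    for (size_t i = 0; i < n; ++i)
        amax = std::max(amax, std::fabs(x[i]));
    return amax > 0.0f ? (448.0f * 6.0f) / amax : 1.0f;
}

// Per-16-element block: the (pre-rounding) scale factor a kernel like the one above
// would encode, given sf_scale_val = compute_global_sf_scale(...).
inline float block_scale_factor(const float* block, int block_size, float sf_scale_val) {
    float block_amax = 0.0f;
    for (int i = 0; i < block_size; ++i)
        block_amax = std::max(block_amax, std::fabs(block[i]));
    return sf_scale_val * block_amax / 6.0f;
}
```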
@@ -517,21 +517,28 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
                     recv_x_scales[token_idx * token_stride + pack_idx * pack_stride + elem_idx] = scale;
                 }
             } else if constexpr (kUseNVFP4) {
-                // The physical layout is (l, rm, rk, 32, 4, 4).
+                // The physical layout is (l, rm, rk, 32, 4, 4)
                 const auto src_scales = reinterpret_cast<uint8_t*>(reinterpret_cast<uint8_t*>(src_data) + hidden_bytes);
                 const auto num_elems_per_pack = static_cast<int>(sizeof(packed_t) / sizeof(scale_t));
                 const auto token_idx = recv_token_begin_idx + i;
-                const auto token_stride = num_scales * sizeof(scale_t);
-                const auto pack_stride = num_elems_per_pack;
-                const auto rm = token_idx / 128;
-                const auto rm_res = token_idx % 128;
+
+                const auto padded_k = (kHidden + (kNumPerChannels * num_elems_per_pack) - 1) / (kNumPerChannels * num_elems_per_pack) * (kNumPerChannels * num_elems_per_pack);
+                const auto dim0_stride = 128 * padded_k / kNumPerChannels;
+                const auto dim1_stride = 128 * num_elems_per_pack;
+                const auto dim2_stride = 4 * num_elems_per_pack;
+                const auto dim3_stride = num_elems_per_pack;
+
+                const auto dim0_offset = token_idx / 128;
+                const auto dim2_offset = (token_idx % 128) % 32;
+                const auto dim3_offset = (token_idx % 128) / 32;
+
                 #pragma unroll
                 for (int j = lane_id; j < num_scales; j += 32) {
-                    const auto pack_idx = j / num_elems_per_pack;
-                    const auto elem_idx = j % num_elems_per_pack;
+                    const auto dim1_offset = j / num_elems_per_pack;
+                    const auto dim4_offset = j % num_elems_per_pack;
                     auto scale = ld_nc_global(src_scales + j);
-                    // recv_x_scales[token_idx * token_stride + pack_idx * pack_stride + elem_idx] = scale;
-                    recv_x_scales[rm * token_stride * 128 + pack_idx * pack_stride * 128 + rm_res * pack_stride + elem_idx] = scale;
+                    const auto offset = dim0_offset * dim0_stride + dim1_offset * dim1_stride + dim2_offset * dim2_stride + dim3_offset * dim3_stride + dim4_offset;
+                    recv_x_scales[offset] = scale;
                 }
             }
         }
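The new indexing writes straight into the swizzled scale-factor layout `(l, rm, rk, 32, 4, 4)`: 128-token tiles along `rm`, packs of 4 scale columns along `rk`, and the 128 rows of a tile split as `(row % 32, row / 32)`. Below is a standalone sketch of the same offset computation, assuming `kNumPerChannels = 16` FP4 elements per scale factor and 4 scale factors per pack, as the surrounding code implies.

```cpp
#include <cstdint>

// Standalone sketch of the (rm, rk, 32, 4, 4) scale-factor offset used above,
// for one expert slice of recv_x_scales. Assumptions: kNumPerChannels = 16
// (FP4 elements covered by one scale factor) and num_elems_per_pack = 4.
constexpr int64_t kNumPerChannels = 16;
constexpr int64_t kElemsPerPack = 4;

inline int64_t sf_offset(int64_t token_idx, int64_t scale_col, int64_t hidden) {
    // Hidden size rounded up so the scale columns pack evenly into groups of 4.
    const int64_t group = kNumPerChannels * kElemsPerPack;           // 64 elements
    const int64_t padded_k = (hidden + group - 1) / group * group;

    const int64_t dim0_stride = 128 * padded_k / kNumPerChannels;    // one 128-token tile
    const int64_t dim1_stride = 128 * kElemsPerPack;                 // one pack of 4 columns
    const int64_t dim2_stride = 4 * kElemsPerPack;
    const int64_t dim3_stride = kElemsPerPack;

    const int64_t dim0 = token_idx / 128;                            // which 128-token tile
    const int64_t dim2 = (token_idx % 128) % 32;                     // row % 32
    const int64_t dim3 = (token_idx % 128) / 32;                     // row / 32
    const int64_t dim1 = scale_col / kElemsPerPack;                  // which pack of columns
    const int64_t dim4 = scale_col % kElemsPerPack;                  // column within the pack

    return dim0 * dim0_stride + dim1 * dim1_stride +
           dim2 * dim2_stride + dim3 * dim3_stride + dim4;
}

// Example: hidden = 7168 gives 448 scale columns per token; token 130, column 5
// lands in tile 1, pack 1, row split (2, 0), element 1 within the pack.
```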
@@ -543,14 +550,14 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
               int* packed_recv_count,
               int* cumulative_local_expert_recv_stats,
               int64_t* dispatch_wait_recv_cost_stats,
-              const float* x_global_scales,
+              const float* x_sf_scale,
               void* rdma_recv_x, int* rdma_recv_count, void* rdma_x,
               const void* x, const int64_t* topk_idx,
               int* next_clean, int num_next_clean_int,
               int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
               int num_topk, int num_experts, int rank, int num_ranks,
               bool use_fp8, bool round_scale, bool use_ue8m0,
-              bool use_nvfp4, bool use_ue8m0_for_nvfp4_x_scale,
+              bool use_nvfp4, bool use_ue8m0_for_sf,
               void* workspace, int num_device_sms,
               cudaStream_t stream, int phases) {
     constexpr int kNumMaxTopK = 9;
@@ -578,17 +585,17 @@ if (use_fp8 and not use_ue8m0) \
     dispatch_func = dispatch<true, false, false, false, hidden>; \
 if (use_fp8 and use_ue8m0) \
     dispatch_func = dispatch<true, true, false, false, hidden>; \
-if (use_nvfp4 and not use_ue8m0_for_nvfp4_x_scale) \
+if (use_nvfp4 and not use_ue8m0_for_sf) \
     dispatch_func = dispatch<false, false, true, false, hidden>; \
-if (use_nvfp4 and use_ue8m0_for_nvfp4_x_scale) \
+if (use_nvfp4 and use_ue8m0_for_sf) \
     dispatch_func = dispatch<false, false, true, true, hidden>; \
 LAUNCH_KERNEL(&cfg, dispatch_func, \
               packed_recv_x, packed_recv_x_scales, \
               packed_recv_src_info, packed_recv_layout_range, \
               packed_recv_count, \
               cumulative_local_expert_recv_stats, \
               dispatch_wait_recv_cost_stats, \
-              x_global_scales, \
+              x_sf_scale, \
               rdma_recv_x, rdma_recv_count, rdma_x, \
               x, topk_idx, \
               atomic_counter_per_expert, atomic_finish_counter_per_expert, \