@@ -1087,12 +1087,14 @@ void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int
 #endif
 }
 
-std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
+std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
 Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
                              const std::optional<torch::Tensor>& cumulative_local_expert_recv_stats,
                              const std::optional<torch::Tensor>& dispatch_wait_recv_cost_stats,
+                             const std::optional<torch::Tensor>& x_sf_scale,
                              int num_max_dispatch_tokens_per_rank, int num_experts,
                              bool use_fp8, bool round_scale, bool use_ue8m0,
+                             bool use_nvfp4, bool use_ue8m0_for_nvfp4_sf,
                              bool async, bool return_recv_hook) {
 #ifndef DISABLE_NVSHMEM
     EP_HOST_ASSERT(low_latency_mode);
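
For reference, a caller-side sketch of the extended interface. The parameter and return names come from this hunk; everything else (the variable values, the `buffer` object) is illustrative, and the new third tuple element is only populated on the nvfp4 path:

// Hedged caller-side sketch; names not in the signature above are hypothetical.
auto [recv_x, recv_x_scales, recv_x_sf_scale,
      recv_count, recv_src_info, recv_layout_range,
      event, recv_hook] =
    buffer.low_latency_dispatch(x, topk_idx,
                                std::nullopt,         // cumulative_local_expert_recv_stats
                                std::nullopt,         // dispatch_wait_recv_cost_stats
                                x_sf_scale,           // new: per-tensor scale used by nvfp4
                                num_max_dispatch_tokens_per_rank, num_experts,
                                false, false, false,  // use_fp8, round_scale, use_ue8m0
                                true, false,          // new: use_nvfp4, use_ue8m0_for_nvfp4_sf
                                false, false);        // async, return_recv_hook
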
@@ -1136,15 +1138,18 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
     stream_wait(launch_stream, compute_stream);
 
     // Allocate packed tensors
-    auto packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden},
-                                      x.options().dtype(use_fp8 ? torch::kFloat8_e4m3fn : torch::kBFloat16));
+    constexpr int NUM_ELEMS_PER_PACK = 8;
+    auto packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, use_nvfp4 ? hidden / NUM_ELEMS_PER_PACK : hidden},
+                                      x.options().dtype(use_nvfp4 ? torch::kInt32 : (use_fp8 ? torch::kFloat8_e4m3fn : torch::kBFloat16)));
     auto packed_recv_src_info = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank}, torch::dtype(torch::kInt32).device(torch::kCUDA));
     auto packed_recv_layout_range = torch::empty({num_local_experts, num_ranks}, torch::dtype(torch::kInt64).device(torch::kCUDA));
     auto packed_recv_count = torch::empty({num_local_experts}, torch::dtype(torch::kInt32).device(torch::kCUDA));
 
     // Allocate column-majored scales
     auto packed_recv_x_scales = std::optional<torch::Tensor>();
     void* packed_recv_x_scales_ptr = nullptr;
+    auto packed_recv_x_sf_scale = std::optional<torch::Tensor>();
+    void* packed_recv_x_sf_scale_ptr = nullptr;
     EP_HOST_ASSERT((num_ranks * num_max_dispatch_tokens_per_rank) % 4 == 0 and "TMA requires the number of tokens to be multiple of 4");
 
     if (use_fp8) {
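
Why the payload width changes on the nvfp4 path: each 32-bit word packs eight 4-bit values, so the last dimension shrinks from `hidden` elements to `hidden / 8` words. A minimal arithmetic sketch (the hidden size is an example, not from the PR):

// NVFP4 payload packing: 32 bits per word / 4 bits per value = 8 values per word.
constexpr int kHidden = 7168;          // example hidden size only
constexpr int kNumElemsPerPack = 8;    // mirrors NUM_ELEMS_PER_PACK above
static_assert(kHidden % kNumElemsPerPack == 0, "hidden must pack evenly");
constexpr int kPackedWidth = kHidden / kNumElemsPerPack;  // 896 int32 words per token
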
@@ -1160,23 +1165,34 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
         }
         packed_recv_x_scales = torch::transpose(packed_recv_x_scales.value(), 1, 2);
         packed_recv_x_scales_ptr = packed_recv_x_scales->data_ptr();
+    } else if (use_nvfp4) {
+        constexpr int SF_VEC_SIZE = 16;
+        constexpr int NUM_SF_ELEMS_PER_PACK = 4;
+        packed_recv_x_scales = torch::empty({num_local_experts, hidden / (SF_VEC_SIZE * NUM_SF_ELEMS_PER_PACK), num_ranks * num_max_dispatch_tokens_per_rank},
+                                            torch::dtype(torch::kInt).device(torch::kCUDA));
+        packed_recv_x_scales = torch::transpose(packed_recv_x_scales.value(), 1, 2);
+        packed_recv_x_scales_ptr = packed_recv_x_scales->data_ptr();
+        packed_recv_x_sf_scale = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank}, torch::dtype(torch::kFloat32).device(torch::kCUDA));
+        packed_recv_x_sf_scale_ptr = packed_recv_x_sf_scale->data_ptr();
     }
 
     // Kernel launch
     auto next_clean_meta = next_buffer.clean_meta();
     auto launcher = [=](int phases) {
-        internode_ll::dispatch(packed_recv_x.data_ptr(), packed_recv_x_scales_ptr,
+        internode_ll::dispatch(packed_recv_x.data_ptr(), packed_recv_x_scales_ptr, packed_recv_x_sf_scale_ptr,
                                packed_recv_src_info.data_ptr<int>(), packed_recv_layout_range.data_ptr<int64_t>(),
                                packed_recv_count.data_ptr<int>(),
                                cumulative_local_expert_recv_stats.has_value() ? cumulative_local_expert_recv_stats->data_ptr<int>() : nullptr,
                                dispatch_wait_recv_cost_stats.has_value() ? dispatch_wait_recv_cost_stats->data_ptr<int64_t>() : nullptr,
+                               x_sf_scale.has_value() ? x_sf_scale->data_ptr<float>() : nullptr,
                                buffer.dispatch_rdma_recv_data_buffer, buffer.dispatch_rdma_recv_count_buffer,
                                buffer.dispatch_rdma_send_buffer,
                                x.data_ptr(), topk_idx.data_ptr<int64_t>(),
                                next_clean_meta.first, next_clean_meta.second,
                                num_tokens, hidden, num_max_dispatch_tokens_per_rank,
                                num_topk, num_experts, rank, num_ranks,
                                use_fp8, round_scale, use_ue8m0,
+                               use_nvfp4, use_ue8m0_for_nvfp4_sf,
                                workspace, num_device_sms,
                                launch_stream, phases);
     };
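
The nvfp4 scale allocation above mirrors the fp8 path: the tensor is allocated as [experts, scale-words, tokens] and then transposed, so it reads logically as [experts, tokens, scale-words] while staying physically column-major in the token dimension (hence the "column-majored scales" comment). A small sketch of the resulting strides, with example sizes only (7168 hidden / (16 * 4) = 112 scale words, 4096 received tokens):

// Assumed illustration of the layout trick above; all sizes are examples.
auto scales = torch::empty({2, 112, 4096}, torch::dtype(torch::kInt).device(torch::kCUDA));
scales = torch::transpose(scales, 1, 2);  // logical shape: [2, 4096, 112]
// Physical strides are now [112 * 4096, 1, 4096]: walking over tokens for a
// fixed scale word is contiguous, which suits TMA-style bulk consumers.
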
@@ -1198,7 +1214,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
         recv_hook = [=]() { launcher(LOW_LATENCY_RECV_PHASE); };
 
     // Return values
-    return {packed_recv_x, packed_recv_x_scales, packed_recv_count, packed_recv_src_info, packed_recv_layout_range, event, recv_hook};
+    return {packed_recv_x, packed_recv_x_scales, packed_recv_x_sf_scale, packed_recv_count, packed_recv_src_info, packed_recv_layout_range, event, recv_hook};
 #else
     EP_HOST_ASSERT(false and "NVSHMEM is disabled during compilation");
     return {};