@@ -1122,8 +1122,10 @@ std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Te
 Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
                              const std::optional<torch::Tensor>& cumulative_local_expert_recv_stats,
                              const std::optional<torch::Tensor>& dispatch_wait_recv_cost_stats,
+                             const std::optional<torch::Tensor>& x_global_scale,
                              int num_max_dispatch_tokens_per_rank, int num_experts,
                              bool use_fp8, bool round_scale, bool use_ue8m0,
+                             bool use_nvfp4, bool use_ue8m0_for_sf,
                              bool async, bool return_recv_hook) {
 #ifndef DISABLE_NVSHMEM
     EP_HOST_ASSERT(low_latency_mode);
@@ -1168,8 +1170,8 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
     stream_wait(launch_stream, compute_stream);
 
     // Allocate packed tensors
-    auto packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden},
-                                      x.options().dtype(use_fp8 ? torch::kFloat8_e4m3fn : torch::kBFloat16));
+    auto packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, use_nvfp4 ? hidden / 2 : hidden},
+                                      x.options().dtype(use_nvfp4 ? torch::kUInt8 : (use_fp8 ? torch::kFloat8_e4m3fn : torch::kBFloat16)));
     auto packed_recv_src_info = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank}, torch::dtype(torch::kInt32).device(torch::kCUDA));
     auto packed_recv_layout_range = torch::empty({num_local_experts, num_ranks}, torch::dtype(torch::kInt64).device(torch::kCUDA));
     auto packed_recv_count = torch::empty({num_local_experts}, torch::dtype(torch::kInt32).device(torch::kCUDA));
@@ -1179,6 +1181,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
     void* packed_recv_x_scales_ptr = nullptr;
     EP_HOST_ASSERT((num_ranks * num_max_dispatch_tokens_per_rank) % 4 == 0 and "TMA requires the number of tokens to be multiple of 4");
 
+    EP_HOST_ASSERT(not (use_fp8 and use_nvfp4));
     if (use_fp8) {
         // TODO: support unaligned cases
         EP_HOST_ASSERT(hidden % 512 == 0);
@@ -1192,6 +1195,35 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
         }
         packed_recv_x_scales = torch::transpose(packed_recv_x_scales.value(), 1, 2);
         packed_recv_x_scales_ptr = packed_recv_x_scales->data_ptr();
+    } else if (use_nvfp4) {
+        constexpr int kNumPerChannels = 16;
+        constexpr int NUM_SF_ELEMS_PER_PACK = 4;
+        constexpr int mTileSize_dim_0 = 32;
+        constexpr int mTileSize_dim_1 = 4;
+        constexpr int mTileSize = mTileSize_dim_0 * mTileSize_dim_1;
+
+        EP_HOST_ASSERT(hidden % kNumPerChannels == 0);
+        auto l = num_local_experts;
+        auto m = num_ranks * num_max_dispatch_tokens_per_rank;
+        auto rm = (m + 127) / 128;
+        auto rk = (hidden + (kNumPerChannels * NUM_SF_ELEMS_PER_PACK) - 1) / (kNumPerChannels * NUM_SF_ELEMS_PER_PACK);
+        // The physical layout is (l, rm, rk, 32, 4, 4).
+        if (use_ue8m0_for_sf) {
+            packed_recv_x_scales = torch::empty({l, rm, rk, 32, 4, 4},
+                                                torch::dtype(torch::kInt).device(torch::kCUDA));
+        } else {
+            packed_recv_x_scales = torch::empty({l, rm, rk, 32, 4, 4},
+                                                torch::dtype(torch::kFloat8_e4m3fn).device(torch::kCUDA));
+        }
+        // After the permute, the logical shape is (32, 4, rm, 4, rk, l).
+        packed_recv_x_scales = packed_recv_x_scales.value().permute({3, 4, 1, 5, 2, 0});
+
+        // The physical layout is (l, m, hidden / 2).
+        // After the permute, the logical shape is (m, hidden / 2, l).
+        packed_recv_x = packed_recv_x.permute({1, 2, 0});
+
+        packed_recv_x_scales_ptr = packed_recv_x_scales->data_ptr();
+        EP_HOST_ASSERT(packed_recv_x_scales_ptr != nullptr);
     }
 
     // Kernel launch
@@ -1202,13 +1234,15 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
                                packed_recv_count.data_ptr<int>(),
                                cumulative_local_expert_recv_stats.has_value() ? cumulative_local_expert_recv_stats->data_ptr<int>() : nullptr,
                                dispatch_wait_recv_cost_stats.has_value() ? dispatch_wait_recv_cost_stats->data_ptr<int64_t>() : nullptr,
+                               x_global_scale.has_value() ? x_global_scale->data_ptr<float>() : nullptr,
                                buffer.dispatch_rdma_recv_data_buffer, buffer.dispatch_rdma_recv_count_buffer,
                                buffer.dispatch_rdma_send_buffer,
                                x.data_ptr(), topk_idx.data_ptr<int64_t>(),
                                next_clean_meta.first, next_clean_meta.second,
                                num_tokens, hidden, num_max_dispatch_tokens_per_rank,
                                num_topk, num_experts, rank, num_ranks,
                                use_fp8, round_scale, use_ue8m0,
+                               use_nvfp4, use_ue8m0_for_sf,
                                workspace, num_device_sms,
                                launch_stream, phases);
     };
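
To make the shape arithmetic in the new `use_nvfp4` branch concrete, here is a minimal standalone sketch (not part of the patch; the hidden size, token count, and expert count are assumed purely for illustration). It mirrors the rounding above: the token dimension is tiled into 128-row blocks (32 x 4), the channel dimension into packs of 16 * 4 = 64 elements, and each group of 16 NVFP4 values shares one scale factor:

```cpp
#include <cstdio>

int main() {
    constexpr int kNumPerChannels = 16;      // NVFP4 values covered by one scale factor
    constexpr int NUM_SF_ELEMS_PER_PACK = 4; // scale factors packed per group

    const int hidden = 7168; // assumed hidden size
    const int l = 1;         // assumed num_local_experts
    const int m = 512;       // assumed num_ranks * num_max_dispatch_tokens_per_rank

    // Same rounding as the dispatch code above: rm counts 128-token tiles,
    // rk counts 64-channel packs.
    const int rm = (m + 127) / 128;
    const int rk = (hidden + kNumPerChannels * NUM_SF_ELEMS_PER_PACK - 1)
                 / (kNumPerChannels * NUM_SF_ELEMS_PER_PACK);

    // Physical scale layout (l, rm, rk, 32, 4, 4): each 128-token tile splits
    // into 32 x 4 rows, and the trailing 4 packed scale groups span 64 channels.
    std::printf("scales (physical): (%d, %d, %d, 32, 4, 4)\n", l, rm, rk);
    // permute({3, 4, 1, 5, 2, 0}) exposes the logical view (32, 4, rm, 4, rk, l).
    std::printf("scales (logical):  (32, 4, %d, 4, %d, %d)\n", rm, rk, l);

    // Payload: NVFP4 packs two 4-bit values per byte, hence hidden / 2 bytes per token.
    std::printf("packed_recv_x: (%d, %d, %d) uint8\n", l, m, hidden / 2);
    return 0;
}
```

For hidden = 7168 and m = 512 this yields rm = 4 and rk = 112, i.e. a (1, 4, 112, 32, 4, 4) scale tensor.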
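For reference, a hypothetical call site under the extended signature might look as follows; `buffer`, `x`, `topk_idx`, and `x_global_scale` are assumed to exist with the usual DeepEP semantics, and the result is left as `auto` since only the argument order is of interest here:

```cpp
// Hypothetical usage sketch, not taken from the patch: the two new arguments
// slot in after dispatch_wait_recv_cost_stats and after use_ue8m0, respectively.
auto result = buffer.low_latency_dispatch(
    x, topk_idx,
    /*cumulative_local_expert_recv_stats=*/std::nullopt,
    /*dispatch_wait_recv_cost_stats=*/std::nullopt,
    /*x_global_scale=*/x_global_scale,              // new: global scale for the NVFP4 path
    num_max_dispatch_tokens_per_rank, num_experts,
    /*use_fp8=*/false, /*round_scale=*/false, /*use_ue8m0=*/false,
    /*use_nvfp4=*/true, /*use_ue8m0_for_sf=*/false, // new: request NVFP4 output
    /*async=*/false, /*return_recv_hook=*/false);
```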