@@ -1087,12 +1087,14 @@ void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int
 #endif
 }
 
-std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
+std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
 Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
                              const std::optional<torch::Tensor>& cumulative_local_expert_recv_stats,
                              const std::optional<torch::Tensor>& dispatch_wait_recv_cost_stats,
+                             const std::optional<torch::Tensor>& x_sf_scale,
                              int num_max_dispatch_tokens_per_rank, int num_experts,
                              bool use_fp8, bool round_scale, bool use_ue8m0,
+                             bool use_nvfp4, bool use_ue8m0_for_nvfp4_sf,
                              bool async, bool return_recv_hook) {
 #ifndef DISABLE_NVSHMEM
     EP_HOST_ASSERT(low_latency_mode);
@@ -1137,15 +1139,18 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
         stream_wait(launch_stream, compute_stream);
 
     // Allocate packed tensors
-    auto packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden},
-                                      x.options().dtype(use_fp8 ? torch::kFloat8_e4m3fn : torch::kBFloat16));
+    constexpr int NUM_ELEMS_PER_PACK = 8;
+    auto packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, use_nvfp4 ? hidden / NUM_ELEMS_PER_PACK : hidden},
+                                      x.options().dtype(use_nvfp4 ? torch::kInt32 : (use_fp8 ? torch::kFloat8_e4m3fn : torch::kBFloat16)));
     auto packed_recv_src_info = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank}, torch::dtype(torch::kInt32).device(torch::kCUDA));
     auto packed_recv_layout_range = torch::empty({num_local_experts, num_ranks}, torch::dtype(torch::kInt64).device(torch::kCUDA));
     auto packed_recv_count = torch::empty({num_local_experts}, torch::dtype(torch::kInt32).device(torch::kCUDA));
 
     // Allocate column-majored scales
     auto packed_recv_x_scales = std::optional<torch::Tensor>();
     void* packed_recv_x_scales_ptr = nullptr;
+    auto packed_recv_x_sf_scale = std::optional<torch::Tensor>();
+    void* packed_recv_x_sf_scale_ptr = nullptr;
     EP_HOST_ASSERT((num_ranks * num_max_dispatch_tokens_per_rank) % 4 == 0 and "TMA requires the number of tokens to be multiple of 4");
 
     if (use_fp8) {
@@ -1161,23 +1166,34 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
         }
         packed_recv_x_scales = torch::transpose(packed_recv_x_scales.value(), 1, 2);
         packed_recv_x_scales_ptr = packed_recv_x_scales->data_ptr();
+    } else if (use_nvfp4) {
+        constexpr int SF_VEC_SIZE = 16;
+        constexpr int NUM_SF_ELEMS_PER_PACK = 4;
+        packed_recv_x_scales = torch::empty({num_local_experts, hidden / (SF_VEC_SIZE * NUM_SF_ELEMS_PER_PACK), num_ranks * num_max_dispatch_tokens_per_rank},
+                                            torch::dtype(torch::kInt).device(torch::kCUDA));
+        packed_recv_x_scales = torch::transpose(packed_recv_x_scales.value(), 1, 2);
+        packed_recv_x_scales_ptr = packed_recv_x_scales->data_ptr();
+        packed_recv_x_sf_scale = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank}, torch::dtype(torch::kFloat32).device(torch::kCUDA));
+        packed_recv_x_sf_scale_ptr = packed_recv_x_sf_scale->data_ptr();
     }
 
     // Kernel launch
     auto next_clean_meta = next_buffer.clean_meta();
     auto launcher = [=](int phases) {
-        internode_ll::dispatch(packed_recv_x.data_ptr(), packed_recv_x_scales_ptr,
+        internode_ll::dispatch(packed_recv_x.data_ptr(), packed_recv_x_scales_ptr, packed_recv_x_sf_scale_ptr,
                                packed_recv_src_info.data_ptr<int>(), packed_recv_layout_range.data_ptr<int64_t>(),
                                packed_recv_count.data_ptr<int>(),
                                cumulative_local_expert_recv_stats.has_value() ? cumulative_local_expert_recv_stats->data_ptr<int>() : nullptr,
                                dispatch_wait_recv_cost_stats.has_value() ? dispatch_wait_recv_cost_stats->data_ptr<int64_t>() : nullptr,
+                               x_sf_scale.has_value() ? x_sf_scale->data_ptr<float>() : nullptr,
                                buffer.dispatch_rdma_recv_data_buffer, buffer.dispatch_rdma_recv_count_buffer,
                                buffer.dispatch_rdma_send_buffer,
                                x.data_ptr(), topk_idx.data_ptr<int64_t>(),
                                next_clean_meta.first, next_clean_meta.second,
                                num_tokens, hidden, num_max_dispatch_tokens_per_rank,
                                num_topk, num_experts, rank, num_ranks,
                                use_fp8, round_scale, use_ue8m0,
+                               use_nvfp4, use_ue8m0_for_nvfp4_sf,
                                workspace, num_device_sms,
                                launch_stream, phases);
     };
@@ -1199,7 +1215,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
         recv_hook = [=]() { launcher(LOW_LATENCY_RECV_PHASE); };
 
     // Return values
-    return {packed_recv_x, packed_recv_x_scales, packed_recv_count, packed_recv_src_info, packed_recv_layout_range, event, recv_hook};
+    return {packed_recv_x, packed_recv_x_scales, packed_recv_x_sf_scale, packed_recv_count, packed_recv_src_info, packed_recv_layout_range, event, recv_hook};
 #else
     EP_HOST_ASSERT(false and "NVSHMEM is disabled during compilation");
     return {};
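
Note: in the NVFP4 branch above, packed_recv_x packs eight 4-bit elements into each int32 (NUM_ELEMS_PER_PACK = 8), packed_recv_x_scales packs four scale factors into each int32 with one scale factor per 16 activations (SF_VEC_SIZE = 16, NUM_SF_ELEMS_PER_PACK = 4), and packed_recv_x_sf_scale holds one float32 per received token slot. A minimal sketch of the resulting shape arithmetic, assuming a purely illustrative hidden size of 7168 (not taken from the diff):

// Shape arithmetic for the NVFP4 receive buffers, mirroring the constants used in the diff.
// The hidden size below is a hypothetical example value.
#include <cassert>

int main() {
    constexpr int hidden = 7168;              // illustrative only
    constexpr int NUM_ELEMS_PER_PACK = 8;     // eight 4-bit elements per packed int32
    constexpr int SF_VEC_SIZE = 16;           // one scale factor per 16 elements
    constexpr int NUM_SF_ELEMS_PER_PACK = 4;  // four scale factors per packed int32

    // Last dimension of packed_recv_x when use_nvfp4 is set: 7168 / 8 = 896 int32 words per token.
    constexpr int packed_hidden = hidden / NUM_ELEMS_PER_PACK;
    // Middle dimension of packed_recv_x_scales before the transpose: 7168 / (16 * 4) = 112 int32 words per token.
    constexpr int packed_scales = hidden / (SF_VEC_SIZE * NUM_SF_ELEMS_PER_PACK);

    static_assert(hidden % NUM_ELEMS_PER_PACK == 0 and hidden % (SF_VEC_SIZE * NUM_SF_ELEMS_PER_PACK) == 0,
                  "hidden must be divisible by both packing granularities");
    assert(packed_hidden == 896 and packed_scales == 112);
    return 0;
}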