Commit 5cd59de

support NVFP4 data format in low latency dispatch

1 parent f0d34aa

7 files changed: +369 -39 lines changed

csrc/deep_ep.cpp

Lines changed: 21 additions & 5 deletions
@@ -1087,12 +1087,14 @@ void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts)
 #endif
 }
 
-std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
+std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
 Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
                              const std::optional<torch::Tensor>& cumulative_local_expert_recv_stats,
                              const std::optional<torch::Tensor>& dispatch_wait_recv_cost_stats,
+                             const std::optional<torch::Tensor>& x_sf_scale,
                              int num_max_dispatch_tokens_per_rank, int num_experts,
                              bool use_fp8, bool round_scale, bool use_ue8m0,
+                             bool use_nvfp4, bool use_ue8m0_for_nvfp4_sf,
                              bool async, bool return_recv_hook) {
 #ifndef DISABLE_NVSHMEM
     EP_HOST_ASSERT(low_latency_mode);
@@ -1137,15 +1139,18 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
     stream_wait(launch_stream, compute_stream);
 
     // Allocate packed tensors
-    auto packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden},
-                                      x.options().dtype(use_fp8 ? torch::kFloat8_e4m3fn : torch::kBFloat16));
+    constexpr int NUM_ELEMS_PER_PACK = 8;
+    auto packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, use_nvfp4 ? hidden / NUM_ELEMS_PER_PACK : hidden},
+                                      x.options().dtype(use_nvfp4 ? torch::kInt32 : (use_fp8 ? torch::kFloat8_e4m3fn : torch::kBFloat16)));
     auto packed_recv_src_info = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank}, torch::dtype(torch::kInt32).device(torch::kCUDA));
     auto packed_recv_layout_range = torch::empty({num_local_experts, num_ranks}, torch::dtype(torch::kInt64).device(torch::kCUDA));
     auto packed_recv_count = torch::empty({num_local_experts}, torch::dtype(torch::kInt32).device(torch::kCUDA));
 
     // Allocate column-majored scales
     auto packed_recv_x_scales = std::optional<torch::Tensor>();
     void* packed_recv_x_scales_ptr = nullptr;
+    auto packed_recv_x_sf_scale = std::optional<torch::Tensor>();
+    void* packed_recv_x_sf_scale_ptr = nullptr;
     EP_HOST_ASSERT((num_ranks * num_max_dispatch_tokens_per_rank) % 4 == 0 and "TMA requires the number of tokens to be multiple of 4");
 
     if (use_fp8) {
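
A quick sanity check on what this allocation implies: with use_nvfp4 the payload is stored as packed 32-bit words, eight 4-bit values per word, so the last dimension shrinks to hidden / NUM_ELEMS_PER_PACK and each token's payload is half the size of the FP8 path. The standalone sketch below just works through that arithmetic; the hidden size is a made-up example, not a value from this commit.

// Standalone sketch of the per-token payload sizes implied by the
// packed_recv_x allocation above. `hidden` here is hypothetical.
#include <cstdio>

int main() {
    constexpr int NUM_ELEMS_PER_PACK = 8;  // eight 4-bit NVFP4 values per int32
    const int hidden = 7168;               // example only, not from the commit

    const int bf16_bytes  = hidden * 2;                       // kBFloat16 path
    const int fp8_bytes   = hidden * 1;                       // kFloat8_e4m3fn path
    const int nvfp4_bytes = hidden / NUM_ELEMS_PER_PACK * 4;  // kInt32 path

    std::printf("bf16:  %d bytes/token\n", bf16_bytes);   // 14336
    std::printf("fp8:   %d bytes/token\n", fp8_bytes);    // 7168
    std::printf("nvfp4: %d bytes/token\n", nvfp4_bytes);  // 3584
    return 0;
}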
@@ -1161,23 +1166,34 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
         }
         packed_recv_x_scales = torch::transpose(packed_recv_x_scales.value(), 1, 2);
         packed_recv_x_scales_ptr = packed_recv_x_scales->data_ptr();
+    } else if (use_nvfp4) {
+        constexpr int SF_VEC_SIZE = 16;
+        constexpr int NUM_SF_ELEMS_PER_PACK = 4;
+        packed_recv_x_scales = torch::empty({num_local_experts, hidden / (SF_VEC_SIZE * NUM_SF_ELEMS_PER_PACK), num_ranks * num_max_dispatch_tokens_per_rank},
+                                            torch::dtype(torch::kInt).device(torch::kCUDA));
+        packed_recv_x_scales = torch::transpose(packed_recv_x_scales.value(), 1, 2);
+        packed_recv_x_scales_ptr = packed_recv_x_scales->data_ptr();
+        packed_recv_x_sf_scale = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank}, torch::dtype(torch::kFloat32).device(torch::kCUDA));
+        packed_recv_x_sf_scale_ptr = packed_recv_x_sf_scale->data_ptr();
     }
 
     // Kernel launch
     auto next_clean_meta = next_buffer.clean_meta();
     auto launcher = [=](int phases) {
-        internode_ll::dispatch(packed_recv_x.data_ptr(), packed_recv_x_scales_ptr,
+        internode_ll::dispatch(packed_recv_x.data_ptr(), packed_recv_x_scales_ptr, packed_recv_x_sf_scale_ptr,
                                packed_recv_src_info.data_ptr<int>(), packed_recv_layout_range.data_ptr<int64_t>(),
                                packed_recv_count.data_ptr<int>(),
                                cumulative_local_expert_recv_stats.has_value() ? cumulative_local_expert_recv_stats->data_ptr<int>() : nullptr,
                                dispatch_wait_recv_cost_stats.has_value() ? dispatch_wait_recv_cost_stats->data_ptr<int64_t>() : nullptr,
+                               x_sf_scale.has_value() ? x_sf_scale->data_ptr<float>() : nullptr,
                                buffer.dispatch_rdma_recv_data_buffer, buffer.dispatch_rdma_recv_count_buffer,
                                buffer.dispatch_rdma_send_buffer,
                                x.data_ptr(), topk_idx.data_ptr<int64_t>(),
                                next_clean_meta.first, next_clean_meta.second,
                                num_tokens, hidden, num_max_dispatch_tokens_per_rank,
                                num_topk, num_experts, rank, num_ranks,
                                use_fp8, round_scale, use_ue8m0,
+                               use_nvfp4, use_ue8m0_for_nvfp4_sf,
                                workspace, num_device_sms,
                                launch_stream, phases);
     };
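
The new else-if branch mirrors the FP8 scale handling: scales are allocated with the token dimension last, then transposed so the scale matrix is column-major for the kernel. Judging by the constant names, there is one scale factor per SF_VEC_SIZE = 16 elements and four 8-bit factors packed per int32, so a token needs hidden / 64 scale words; on top of that, each received token gets one float in packed_recv_x_sf_scale. The sketch below works through the footprint with the same made-up example values as before.

// Standalone sketch of the NVFP4 scale-factor footprint per received token.
// `hidden` and `num_recv_tokens` are hypothetical example values.
#include <cstdio>

int main() {
    constexpr int SF_VEC_SIZE = 16;           // one scale factor per 16 elements
    constexpr int NUM_SF_ELEMS_PER_PACK = 4;  // four 8-bit factors per int32
    const int hidden = 7168;                  // example only
    const int num_recv_tokens = 1024;         // num_ranks * num_max_dispatch_tokens_per_rank

    const int sf_words_per_token = hidden / (SF_VEC_SIZE * NUM_SF_ELEMS_PER_PACK);
    std::printf("scale words per token:  %d\n", sf_words_per_token);  // 112
    std::printf("scale bytes per expert: %d\n",
                sf_words_per_token * num_recv_tokens * 4);            // 458752
    std::printf("sf_scale floats/expert: %d\n", num_recv_tokens);     // one per token
    return 0;
}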
@@ -1199,7 +1215,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
     recv_hook = [=]() { launcher(LOW_LATENCY_RECV_PHASE); };
 
     // Return values
-    return {packed_recv_x, packed_recv_x_scales, packed_recv_count, packed_recv_src_info, packed_recv_layout_range, event, recv_hook};
+    return {packed_recv_x, packed_recv_x_scales, packed_recv_x_sf_scale, packed_recv_count, packed_recv_src_info, packed_recv_layout_range, event, recv_hook};
 #else
     EP_HOST_ASSERT(false and "NVSHMEM is disabled during compilation");
     return {};
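The kernel side is not part of the files shown here, so the exact meaning of use_ue8m0_for_nvfp4_sf is not visible in this diff; the name suggests the NVFP4 scale factors can be stored as UE8M0 (an 8-bit, exponent-only encoding, i.e. power-of-two scales) instead of an 8-bit float format. Below is a minimal sketch of one plausible UE8M0 encoding, assuming the conventional bias of 127 and round-up-to-power-of-two; the kernel's actual rounding behavior is an open question from this commit alone.

// Sketch only: a plausible UE8M0 encoding (8 exponent bits, no mantissa,
// bias 127). Not the kernel's actual conversion code.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

static uint8_t to_ue8m0(float scale) {
    // Round up to the next power of two so quantized values never overflow,
    // then store the biased exponent, clamped to the 8-bit range.
    const int exponent = static_cast<int>(std::ceil(std::log2(scale)));
    return static_cast<uint8_t>(std::min(std::max(exponent + 127, 0), 255));
}

int main() {
    std::printf("%u\n", to_ue8m0(1.0f));   // 127 -> represents 2^0
    std::printf("%u\n", to_ue8m0(0.75f));  // 127 -> rounds up to 2^0
    std::printf("%u\n", to_ue8m0(6.0f));   // 130 -> rounds up to 2^3
    return 0;
}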
csrc/deep_ep.hpp

Lines changed: 3 additions & 1 deletion
@@ -143,12 +143,14 @@ struct Buffer {
 
     void clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts);
 
-    std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
+    std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
     low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
                          const std::optional<torch::Tensor>& cumulative_local_expert_recv_stats,
                          const std::optional<torch::Tensor>& dispatch_wait_recv_cost_stats,
+                         const std::optional<torch::Tensor>& x_sf_scale,
                          int num_max_dispatch_tokens_per_rank, int num_experts,
                          bool use_fp8, bool round_scale, bool use_ue8m0,
+                         bool use_nvfp4, bool use_ue8m0_for_nvfp4_sf,
                          bool async, bool return_recv_hook);
 
     std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
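
For callers, the visible contract change is one extra optional input (x_sf_scale) and one extra optional output, inserted as the third tuple element. The fragment below is a hypothetical call site, not code from this commit: it assumes a constructed Buffer and already-prepared inputs, and the argument values are placeholders.

// Hypothetical caller fragment (assumes `buffer`, `x`, `topk_idx`,
// `x_sf_scale`, and the int parameters already exist in scope).
auto [recv_x, recv_x_scales, recv_x_sf_scale,
      recv_count, recv_src_info, recv_layout_range,
      event, recv_hook] =
    buffer.low_latency_dispatch(
        x, topk_idx,
        std::nullopt,       // cumulative_local_expert_recv_stats
        std::nullopt,       // dispatch_wait_recv_cost_stats
        x_sf_scale,         // new in this commit
        num_max_dispatch_tokens_per_rank, num_experts,
        /*use_fp8=*/false, /*round_scale=*/false, /*use_ue8m0=*/false,
        /*use_nvfp4=*/true, /*use_ue8m0_for_nvfp4_sf=*/false,
        /*async=*/false, /*return_recv_hook=*/false);
// recv_x:          [num_local_experts, num_recv_tokens, hidden / 8] int32
// recv_x_scales:   packed block scale factors, column-major after transpose
// recv_x_sf_scale: [num_local_experts, num_recv_tokens] float, one per token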

csrc/kernels/api.cuh

Lines changed: 3 additions & 1 deletion
@@ -139,17 +139,19 @@ void clean_low_latency_buffer(int* clean_0, int num_clean_int_0,
                               int* clean_1, int num_clean_int_1,
                               cudaStream_t stream);
 
-void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
+void dispatch(void* packed_recv_x, void* packed_recv_x_scales, void* packed_recv_x_sf_scale,
               int* packed_recv_src_info, int64_t* packed_recv_layout_range,
               int* packed_recv_count,
               int* cumulative_local_expert_recv_stats,
               int64_t* dispatch_wait_recv_cost_stats,
+              const float* x_sf_scale,
              void* rdma_recv_x, int* rdma_recv_count, void* rdma_x,
              const void* x, const int64_t* topk_idx,
              int* next_clean, int num_next_clean_int,
              int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
              int num_topk, int num_experts, int rank, int num_ranks,
              bool use_fp8, bool round_scale, bool use_ue8m0,
+             bool use_nvfp4, bool use_ue8m0_for_nvfp4_sf,
              void* workspace, int num_device_sms,
              cudaStream_t stream, int phases);

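The payload pointer stays a void* across this API: on the host the NVFP4 payload is typed as int32 purely for allocation, with each 32-bit word carrying eight 4-bit encodings. The sketch below shows generic nibble packing to illustrate why NUM_ELEMS_PER_PACK is 8; the in-kernel conversion and its nibble order are not shown in this commit, so the layout here (least-significant nibble first) is an assumption.

// Illustration only: packing eight 4-bit codes (e.g. NVFP4 e2m1 encodings)
// into one 32-bit word. Generic sketch, not the kernel's conversion code.
#include <cstdint>
#include <cstdio>

uint32_t pack8_fp4(const uint8_t codes[8]) {
    uint32_t word = 0;
    for (int i = 0; i < 8; ++i)
        word |= static_cast<uint32_t>(codes[i] & 0xF) << (4 * i);  // LSB-first
    return word;
}

int main() {
    const uint8_t codes[8] = {0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8};
    std::printf("packed: 0x%08x\n", pack8_fp4(codes));  // 0x87654321
    return 0;
}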