
Commit ef73fd9

shifangx and fzyzcjy authored
Support nvfp4 low latency mode dispatch (#341)
* support NVFP4 data format in low latency dispatch
* add support for fp32_vec_to_e2m1 when __CUDA_ARCH__ is less than 1000
* change threshold for diff
* add debug message
* change physical layout to be (l, m/128, k/4, 32, 4, 4)
* use a global scale for the entire dispatch instead of a per-token scale
* change test case
* change some names and dtypes: rename x_sf_scale to x_global_scales; rename use_ue8m0_for_sf to use_ue8m0_for_nvfp4_x_scale; set the x_scale dtype to torch::kFloat8_e4m3fn when use_ue8m0_for_nvfp4_x_scale==False and to torch::kUInt8 when use_ue8m0_for_nvfp4_x_scale==True
* support padding m
* calibrate the nvfp4 scale layout with grouped GEMM
* fix wrong accuracy
* modify doc
* modify the nvfp4 convert helper function in the test
* fix alignment issue
* modify test
* add msb_first flag in test helper function
* change test file name
* change test run time

---------

Co-authored-by: fzyzcjy <[email protected]>
1 parent 3f601f7 commit ef73fd9

8 files changed (+745, -45 lines)

csrc/deep_ep.cpp

Lines changed: 36 additions & 2 deletions
```diff
@@ -1122,8 +1122,10 @@ std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Te
 Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
                              const std::optional<torch::Tensor>& cumulative_local_expert_recv_stats,
                              const std::optional<torch::Tensor>& dispatch_wait_recv_cost_stats,
+                             const std::optional<torch::Tensor>& x_global_scale,
                              int num_max_dispatch_tokens_per_rank, int num_experts,
                              bool use_fp8, bool round_scale, bool use_ue8m0,
+                             bool use_nvfp4, bool use_ue8m0_for_sf,
                              bool async, bool return_recv_hook) {
 #ifndef DISABLE_NVSHMEM
     EP_HOST_ASSERT(low_latency_mode);
@@ -1168,8 +1170,8 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
     stream_wait(launch_stream, compute_stream);
 
     // Allocate packed tensors
-    auto packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden},
-                                      x.options().dtype(use_fp8 ? torch::kFloat8_e4m3fn : torch::kBFloat16));
+    auto packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, use_nvfp4 ? hidden / 2 : hidden},
+                                      x.options().dtype(use_nvfp4 ? torch::kUInt8 : (use_fp8 ? torch::kFloat8_e4m3fn : torch::kBFloat16)));
     auto packed_recv_src_info = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank}, torch::dtype(torch::kInt32).device(torch::kCUDA));
     auto packed_recv_layout_range = torch::empty({num_local_experts, num_ranks}, torch::dtype(torch::kInt64).device(torch::kCUDA));
     auto packed_recv_count = torch::empty({num_local_experts}, torch::dtype(torch::kInt32).device(torch::kCUDA));
@@ -1179,6 +1181,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
     void* packed_recv_x_scales_ptr = nullptr;
     EP_HOST_ASSERT((num_ranks * num_max_dispatch_tokens_per_rank) % 4 == 0 and "TMA requires the number of tokens to be multiple of 4");
 
+    EP_HOST_ASSERT(not (use_fp8 and use_nvfp4));
     if (use_fp8) {
         // TODO: support unaligned cases
         EP_HOST_ASSERT(hidden % 512 == 0);
@@ -1192,6 +1195,35 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
         }
         packed_recv_x_scales = torch::transpose(packed_recv_x_scales.value(), 1, 2);
         packed_recv_x_scales_ptr = packed_recv_x_scales->data_ptr();
+    } else if (use_nvfp4) {
+        constexpr int kNumPerChannels = 16;
+        constexpr int NUM_SF_ELEMS_PER_PACK = 4;
+        constexpr int mTileSize_dim_0 = 32;
+        constexpr int mTileSize_dim_1 = 4;
+        constexpr int mTileSize = mTileSize_dim_0 * mTileSize_dim_1;
+
+        assert(hidden % kNumPerChannels == 0);
+        auto l = num_local_experts;
+        auto m = num_ranks * num_max_dispatch_tokens_per_rank;
+        auto rm = (m + 127) / 128;
+        auto rk = (hidden + (kNumPerChannels * NUM_SF_ELEMS_PER_PACK) - 1) / (kNumPerChannels * NUM_SF_ELEMS_PER_PACK);
+        // The physical layout is (l, rm, rk, 32, 4, 4).
+        if (use_ue8m0_for_sf) {
+            packed_recv_x_scales = torch::empty({l, rm, rk, 32, 4, 4},
+                                                torch::dtype(torch::kInt).device(torch::kCUDA));
+        } else {
+            packed_recv_x_scales = torch::empty({l, rm, rk, 32, 4, 4},
+                                                torch::dtype(torch::kFloat8_e4m3fn).device(torch::kCUDA));
+        }
+        // After permute, the logical shape is (32, 4, rm, 4, rk, l).
+        packed_recv_x_scales = packed_recv_x_scales.value().permute({3, 4, 1, 5, 2, 0});
+
+        // The physical layout is (l, m, k // 2).
+        // After permute, the logical shape is (m, k // 2, l).
+        packed_recv_x = packed_recv_x.permute({1, 2, 0});
+
+        packed_recv_x_scales_ptr = packed_recv_x_scales->data_ptr();
+        EP_HOST_ASSERT(packed_recv_x_scales_ptr != nullptr);
     }
 
     // Kernel launch
@@ -1202,13 +1234,15 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
         packed_recv_count.data_ptr<int>(),
         cumulative_local_expert_recv_stats.has_value() ? cumulative_local_expert_recv_stats->data_ptr<int>() : nullptr,
         dispatch_wait_recv_cost_stats.has_value() ? dispatch_wait_recv_cost_stats->data_ptr<int64_t>() : nullptr,
+        x_global_scale.has_value() ? x_global_scale->data_ptr<float>() : nullptr,
         buffer.dispatch_rdma_recv_data_buffer, buffer.dispatch_rdma_recv_count_buffer,
         buffer.dispatch_rdma_send_buffer,
         x.data_ptr(), topk_idx.data_ptr<int64_t>(),
         next_clean_meta.first, next_clean_meta.second,
         num_tokens, hidden, num_max_dispatch_tokens_per_rank,
         num_topk, num_experts, rank, num_ranks,
         use_fp8, round_scale, use_ue8m0,
+        use_nvfp4, use_ue8m0_for_sf,
         workspace, num_device_sms,
         launch_stream, phases);
 };
```
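
For orientation, the scale-factor shape allocated in the use_nvfp4 branch above comes from simple integer tiling: 128-row tiles on the token axis, and groups of 4 packed scale factors (one per 16 hidden elements) on the channel axis. Below is a minimal standalone sketch of that arithmetic; the concrete sizes (l = 2, m = 1024, hidden = 7168) are illustrative values, not fixed by this commit.

```cpp
#include <cstdio>

int main() {
    // Example values only; DeepEP derives these at dispatch time.
    constexpr int kNumPerChannels = 16;    // NVFP4 elements covered by one scale factor
    constexpr int kNumSfElemsPerPack = 4;  // scale factors packed into one 32-bit word
    const int l = 2;                       // num_local_experts
    const int m = 1024;                    // num_ranks * num_max_dispatch_tokens_per_rank
    const int hidden = 7168;               // k

    const int rm = (m + 127) / 128;        // 128-row tiles: 8
    const int rk = (hidden + kNumPerChannels * kNumSfElemsPerPack - 1)
                 / (kNumPerChannels * kNumSfElemsPerPack);  // 64-element groups: 112
    // Physical scale layout (l, rm, rk, 32, 4, 4): each rm tile covers 32 * 4 = 128 rows,
    // and each rk tile holds 4 packed scales, i.e. one scale per 16 hidden elements.
    std::printf("packed_recv_x_scales: (%d, %d, %d, 32, 4, 4)\n", l, rm, rk);
    // The NVFP4 payload packs two 4-bit values per byte, hence hidden / 2 bytes per token.
    std::printf("packed_recv_x: (%d, %d, %d)\n", l, m, hidden / 2);
    return 0;
}
```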

csrc/deep_ep.hpp

Lines changed: 2 additions & 0 deletions
```diff
@@ -168,8 +168,10 @@ struct Buffer {
     low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
                          const std::optional<torch::Tensor>& cumulative_local_expert_recv_stats,
                          const std::optional<torch::Tensor>& dispatch_wait_recv_cost_stats,
+                         const std::optional<torch::Tensor>& x_global_scale,
                          int num_max_dispatch_tokens_per_rank, int num_experts,
                          bool use_fp8, bool round_scale, bool use_ue8m0,
+                         bool use_nvfp4, bool use_ue8m0_for_sf,
                          bool async, bool return_recv_hook);
 
     std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
```
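
The declaration now threads an optional x_global_scale through to the kernels; per the commit message, this is a single scale for the entire dispatch rather than per-token scales. A hedged libtorch sketch of constructing such a tensor is below; the (448 * 6) / amax formula is a common NVFP4 global-scale convention assumed here for illustration, not something this diff specifies.

```cpp
#include <torch/torch.h>
#include <optional>

int main() {
    // Hypothetical activation tensor (tokens, hidden) in bf16 on the GPU.
    auto x = torch::randn({128, 7168}, torch::dtype(torch::kBFloat16).device(torch::kCUDA));

    // One scalar scale for the whole dispatch. Assumed convention:
    // e4m3_max * e2m1_max / amax = (448 * 6) / amax.
    auto amax = x.abs().max().to(torch::kFloat32);
    std::optional<torch::Tensor> x_global_scale = (448.0f * 6.0f) / amax;

    // x_global_scale would be passed as the new fourth argument of
    // Buffer::low_latency_dispatch; omitting it keeps the old behavior.
    return 0;
}
```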

csrc/kernels/api.cuh

Lines changed: 2 additions & 0 deletions
```diff
@@ -145,12 +145,14 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
               int* packed_recv_count,
               int* cumulative_local_expert_recv_stats,
               int64_t* dispatch_wait_recv_cost_stats,
+              const float* x_global_scale,
               void* rdma_recv_x, int* rdma_recv_count, void* rdma_x,
               const void* x, const int64_t* topk_idx,
               int* next_clean, int num_next_clean_int,
               int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
               int num_topk, int num_experts, int rank, int num_ranks,
               bool use_fp8, bool round_scale, bool use_ue8m0,
+              bool use_nvfp4, bool use_ue8m0_for_sf,
               void* workspace, int num_device_sms,
               cudaStream_t stream, int phases);
 
```
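
Relatedly, the commit message notes a software fp32_vec_to_e2m1 fallback for __CUDA_ARCH__ below 1000, where the hardware FP4 conversion instruction is unavailable. As a rough host-side illustration only: e2m1 can represent the magnitudes {0, 0.5, 1, 1.5, 2, 3, 4, 6}, and a fallback must round each scaled float onto that grid. The sketch below does nearest-value rounding; the kernel's actual rounding mode, tie handling, and vectorized packing are not reproduced here.

```cpp
#include <cmath>
#include <cstdint>
#include <cstdio>

// Round one float to a 4-bit e2m1 code (1 sign, 2 exponent, 1 mantissa bit).
// Illustrative only; not the DeepEP kernel implementation.
static uint8_t fp32_to_e2m1(float x) {
    static const float kGrid[8] = {0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f};
    const uint8_t sign = x < 0.0f ? 0x8 : 0x0;
    const float mag = std::fabs(x);
    uint8_t best = 0;
    float best_err = std::fabs(mag - kGrid[0]);
    for (uint8_t code = 1; code < 8; ++code) {
        const float err = std::fabs(mag - kGrid[code]);
        if (err < best_err) { best = code; best_err = err; }
    }
    return sign | best;  // two such codes are packed per output byte
}

int main() {
    const float x_global_scale = 1.0f;  // hypothetical; see the dispatch argument above
    const float samples[] = {-5.1f, -0.2f, 0.7f, 2.4f};
    for (float v : samples)
        std::printf("%+.2f -> 0x%X\n", v, fp32_to_e2m1(v * x_global_scale));
    return 0;
}
```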
