
Commit 63ad6b4

fix issue with nvfp4 data format
1 parent 57a670a commit 63ad6b4

File tree

8 files changed: +495 −87 lines changed


csrc/deep_ep.cpp

Lines changed: 19 additions & 13 deletions
@@ -1091,10 +1091,10 @@ std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Te
 Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
                              const std::optional<torch::Tensor>& cumulative_local_expert_recv_stats,
                              const std::optional<torch::Tensor>& dispatch_wait_recv_cost_stats,
-                             const std::optional<torch::Tensor>& x_global_scales,
+                             const std::optional<torch::Tensor>& x_sf_scale,
                              int num_max_dispatch_tokens_per_rank, int num_experts,
                              bool use_fp8, bool round_scale, bool use_ue8m0,
-                             bool use_nvfp4, bool use_ue8m0_for_nvfp4_x_scale,
+                             bool use_nvfp4, bool use_ue8m0_for_sf,
                              bool async, bool return_recv_hook) {
 #ifndef DISABLE_NVSHMEM
     EP_HOST_ASSERT(low_latency_mode);

@@ -1139,9 +1139,8 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
     stream_wait(launch_stream, compute_stream);
 
     // Allocate packed tensors
-    constexpr int NUM_ELEMS_PER_PACK = 8;
-    auto packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, use_nvfp4 ? hidden / NUM_ELEMS_PER_PACK : hidden},
-                                      x.options().dtype(use_nvfp4 ? torch::kUInt32 : (use_fp8 ? torch::kFloat8_e4m3fn: torch::kBFloat16)));
+    auto packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, use_nvfp4 ? hidden / 2 : hidden},
+                                      x.options().dtype(use_nvfp4 ? torch::kUInt8 : (use_fp8 ? torch::kFloat8_e4m3fn: torch::kBFloat16)));
     auto packed_recv_src_info = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank}, torch::dtype(torch::kInt32).device(torch::kCUDA));
     auto packed_recv_layout_range = torch::empty({num_local_experts, num_ranks}, torch::dtype(torch::kInt64).device(torch::kCUDA));
     auto packed_recv_count = torch::empty({num_local_experts}, torch::dtype(torch::kInt32).device(torch::kCUDA));

@@ -1172,19 +1171,26 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
         constexpr int mTileSize_dim_1 = 4;
         constexpr int mTileSize = mTileSize_dim_0 * mTileSize_dim_1;
 
+        assert(hidden % kNumPerChannels == 0);
         auto l = num_local_experts;
         auto m = num_ranks * num_max_dispatch_tokens_per_rank;
         auto rm = (m + 127) / 128;
-        auto rk = hidden / (kNumPerChannels * NUM_SF_ELEMS_PER_PACK);
-        auto scale_dtype = use_ue8m0_for_nvfp4_x_scale ?
-                           torch::dtype(torch::kUInt8) :
-                           torch::dtype(torch::kFloat8_e4m3fn);
+        auto rk = (hidden + (kNumPerChannels * NUM_SF_ELEMS_PER_PACK) -1 ) / (kNumPerChannels * NUM_SF_ELEMS_PER_PACK);
         // The physical layout is (l, rm, rk, 32, 4, 4).
-        packed_recv_x_scales = torch::empty({l, rm, rk, 32, 4, 4},
-                                            scale_dtype.device(torch::kCUDA));
+        if (use_ue8m0_for_sf) {
+            packed_recv_x_scales = torch::empty({l, rm, rk, 32, 4, 4},
+                                                torch::dtype(torch::kInt).device(torch::kCUDA));
+        } else {
+            packed_recv_x_scales = torch::empty({l, rm, rk, 32, 4, 4},
+                                                torch::dtype(torch::kFloat8_e4m3fn).device(torch::kCUDA));
+        }
         // After permute, the logical shape is (32, 4, rm, 4, rk, l)
         packed_recv_x_scales = packed_recv_x_scales.value().permute({3, 4, 1, 5, 2, 0});
 
+        // The physical layout is (l, m, k // 2).
+        // After permute, the logical shape is (m, k // 2, l).
+        packed_recv_x = packed_recv_x.permute({1, 2, 0});
+
         packed_recv_x_scales_ptr = packed_recv_x_scales->data_ptr();
         EP_HOST_ASSERT(packed_recv_x_scales_ptr != nullptr);
     }

@@ -1197,15 +1203,15 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
              packed_recv_count.data_ptr<int>(),
              cumulative_local_expert_recv_stats.has_value() ? cumulative_local_expert_recv_stats->data_ptr<int>() : nullptr,
              dispatch_wait_recv_cost_stats.has_value() ? dispatch_wait_recv_cost_stats->data_ptr<int64_t>() : nullptr,
-             x_global_scales.has_value() ? x_global_scales->data_ptr<float>() : nullptr,
+             x_sf_scale.has_value() ? x_sf_scale->data_ptr<float>() : nullptr,
              buffer.dispatch_rdma_recv_data_buffer, buffer.dispatch_rdma_recv_count_buffer,
              buffer.dispatch_rdma_send_buffer,
              x.data_ptr(), topk_idx.data_ptr<int64_t>(),
              next_clean_meta.first, next_clean_meta.second,
              num_tokens, hidden, num_max_dispatch_tokens_per_rank,
              num_topk, num_experts, rank, num_ranks,
              use_fp8, round_scale, use_ue8m0,
-             use_nvfp4, use_ue8m0_for_nvfp4_x_scale,
+             use_nvfp4, use_ue8m0_for_sf,
              workspace, num_device_sms,
              launch_stream, phases);
 };
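
Editor's note: for reference, a minimal host-side sketch in PyTorch of the buffers the new NVFP4 path allocates. The constants `kNumPerChannels` and `NUM_SF_ELEMS_PER_PACK` are not defined in this hunk, so the values below (16 and 4) are assumptions for illustration only; the shapes and permutes mirror the code above.

```python
import torch

# Assumed values for constants referenced but not shown in this hunk:
# kNumPerChannels (NVFP4 scaling-block size) and NUM_SF_ELEMS_PER_PACK
# (number of scale bytes packed per 32-bit word).
K_NUM_PER_CHANNELS = 16
NUM_SF_ELEMS_PER_PACK = 4

def nvfp4_recv_buffers(num_local_experts: int, num_ranks: int,
                       num_max_dispatch_tokens_per_rank: int, hidden: int,
                       use_ue8m0_for_sf: bool):
    l = num_local_experts
    m = num_ranks * num_max_dispatch_tokens_per_rank

    # Packed data: two FP4 values per uint8 byte, hence `hidden // 2` in the last dim,
    # then permuted from the physical (l, m, k // 2) to the logical (m, k // 2, l).
    packed_recv_x = torch.empty((l, m, hidden // 2), dtype=torch.uint8)
    packed_recv_x = packed_recv_x.permute(1, 2, 0)

    # Scale factors: physical layout (l, rm, rk, 32, 4, 4), with the same rounding-up
    # of rk as the C++ code, then permuted to the logical (32, 4, rm, 4, rk, l).
    rm = (m + 127) // 128
    rk = (hidden + K_NUM_PER_CHANNELS * NUM_SF_ELEMS_PER_PACK - 1) \
         // (K_NUM_PER_CHANNELS * NUM_SF_ELEMS_PER_PACK)
    scale_dtype = torch.int32 if use_ue8m0_for_sf else torch.float8_e4m3fn
    packed_recv_x_scales = torch.empty((l, rm, rk, 32, 4, 4), dtype=scale_dtype)
    packed_recv_x_scales = packed_recv_x_scales.permute(3, 4, 1, 5, 2, 0)
    return packed_recv_x, packed_recv_x_scales

x, sf = nvfp4_recv_buffers(num_local_experts=2, num_ranks=8,
                           num_max_dispatch_tokens_per_rank=128,
                           hidden=7168, use_ue8m0_for_sf=True)
print(x.shape)   # torch.Size([1024, 3584, 2])
print(sf.shape)  # torch.Size([32, 4, 8, 4, 112, 2])
```

The UE8M0 branch now allocates `torch::kInt` elements of the same shape, presumably so that several UE8M0 scale bytes fit in each 32-bit element, while the non-UE8M0 branch keeps one FP8 (E4M3) scale per element.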

csrc/deep_ep.hpp

Lines changed: 2 additions & 2 deletions
@@ -147,10 +147,10 @@ struct Buffer {
     low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
                          const std::optional<torch::Tensor>& cumulative_local_expert_recv_stats,
                          const std::optional<torch::Tensor>& dispatch_wait_recv_cost_stats,
-                         const std::optional<torch::Tensor>& x_global_scales,
+                         const std::optional<torch::Tensor>& x_sf_scale,
                          int num_max_dispatch_tokens_per_rank, int num_experts,
                          bool use_fp8, bool round_scale, bool use_ue8m0,
-                         bool use_nvfp4, bool use_ue8m0_for_nvfp4_x_scale,
+                         bool use_nvfp4, bool use_ue8m0_for_sf,
                          bool async, bool return_recv_hook);
 
     std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>

csrc/kernels/api.cuh

Lines changed: 2 additions & 2 deletions
@@ -144,14 +144,14 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
               int* packed_recv_count,
               int* cumulative_local_expert_recv_stats,
               int64_t* dispatch_wait_recv_cost_stats,
-              const float* x_global_scales,
+              const float* x_sf_scale,
               void* rdma_recv_x, int* rdma_recv_count, void* rdma_x,
               const void* x, const int64_t* topk_idx,
               int* next_clean, int num_next_clean_int,
               int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
               int num_topk, int num_experts, int rank, int num_ranks,
               bool use_fp8, bool round_scale, bool use_ue8m0,
-              bool use_nvfp4, bool use_ue8m0_for_nvfp4_x_scale,
+              bool use_nvfp4, bool use_ue8m0_for_sf,
               void* workspace, int num_device_sms,
               cudaStream_t stream, int phases);
 

csrc/kernels/internode_ll.cu

Lines changed: 24 additions & 17 deletions
@@ -200,7 +200,7 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
          int* packed_recv_count,
          int* cumulative_local_expert_recv_stats,
          int64_t* dispatch_wait_recv_cost_stats,
-         const float* x_global_scales,
+         const float* x_sf_scale,
          void* rdma_recv_x, int* rdma_recv_count, void* rdma_x,
          const void* x, const int64_t* topk_idx,
          int* atomic_counter_per_expert, int* atomic_finish_counter_per_expert,

@@ -275,8 +275,8 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
     float SFScaleVal = 1.0f;
     if constexpr (kUseNVFP4) {
         // Get scaling value;
-        EP_DEVICE_ASSERT(x_global_scales != nullptr);
-        SFScaleVal = *(static_cast<const float*>(x_global_scales));
+        EP_DEVICE_ASSERT(x_sf_scale != nullptr);
+        SFScaleVal = *(static_cast<const float*>(x_sf_scale));
     }
 
     // FP8 or NVFP4 cast

@@ -517,21 +517,28 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
                 recv_x_scales[token_idx * token_stride + pack_idx * pack_stride + elem_idx] = scale;
             }
         } else if constexpr (kUseNVFP4) {
-            // The physical layout is (l, rm, rk, 32, 4, 4).
+            // The physical layout is (l, rm, rk, 32, 4, 4)
            const auto src_scales = reinterpret_cast<uint8_t*>(reinterpret_cast<uint8_t*>(src_data) + hidden_bytes);
            const auto num_elems_per_pack = static_cast<int>(sizeof(packed_t) / sizeof(scale_t));
            const auto token_idx = recv_token_begin_idx + i;
-            const auto token_stride = num_scales * sizeof(scale_t);
-            const auto pack_stride = num_elems_per_pack;
-            const auto rm = token_idx / 128;
-            const auto rm_res = token_idx % 128;
+
+            const auto padded_k = (kHidden + (kNumPerChannels * num_elems_per_pack) -1 ) / (kNumPerChannels * num_elems_per_pack);
+            const auto dim0_stride = 128 * padded_k / kNumPerChannels;
+            const auto dim1_stride = 128 * num_elems_per_pack;
+            const auto dim2_stride = 4 * num_elems_per_pack;
+            const auto dim3_stride = num_elems_per_pack;
+
+            const auto dim0_offset = token_idx / 128;
+            const auto dim2_offset = (token_idx % 128) % 32;
+            const auto dim3_offset = (token_idx % 128) / 32;
+
            #pragma unroll
            for (int j = lane_id; j < num_scales; j += 32) {
-                const auto pack_idx = j / num_elems_per_pack;
-                const auto elem_idx = j % num_elems_per_pack;
+                const auto dim1_offset = j / num_elems_per_pack;
+                const auto dim4_offset = j % num_elems_per_pack;
                auto scale = ld_nc_global(src_scales + j);
-                // recv_x_scales[token_idx * token_stride + pack_idx * pack_stride + elem_idx] = scale;
-                recv_x_scales[rm * token_stride * 128 + pack_idx * pack_stride * 128 + rm_res * pack_stride + elem_idx] = scale;
+                const auto offset = dim0_offset * dim0_stride + dim1_offset * dim1_stride + dim2_offset * dim2_stride + dim3_offset * dim3_stride + dim4_offset;
+                recv_x_scales[offset] = scale;
            }
        }
    }

@@ -543,14 +550,14 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
               int* packed_recv_count,
               int* cumulative_local_expert_recv_stats,
               int64_t* dispatch_wait_recv_cost_stats,
-              const float* x_global_scales,
+              const float* x_sf_scale,
               void* rdma_recv_x, int* rdma_recv_count, void* rdma_x,
               const void* x, const int64_t* topk_idx,
               int* next_clean, int num_next_clean_int,
               int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
               int num_topk, int num_experts, int rank, int num_ranks,
               bool use_fp8, bool round_scale, bool use_ue8m0,
-              bool use_nvfp4, bool use_ue8m0_for_nvfp4_x_scale,
+              bool use_nvfp4, bool use_ue8m0_for_sf,
               void* workspace, int num_device_sms,
               cudaStream_t stream, int phases) {
     constexpr int kNumMaxTopK = 9;

@@ -578,17 +585,17 @@ if (use_fp8 and not use_ue8m0) \
     dispatch_func = dispatch<true, false, false, false, hidden>; \
 if (use_fp8 and use_ue8m0) \
     dispatch_func = dispatch<true, true, false, false, hidden>; \
-if (use_nvfp4 and not use_ue8m0_for_nvfp4_x_scale) \
+if (use_nvfp4 and not use_ue8m0_for_sf) \
     dispatch_func = dispatch<false, false, true, false, hidden>; \
-if (use_nvfp4 and use_ue8m0_for_nvfp4_x_scale) \
+if (use_nvfp4 and use_ue8m0_for_sf) \
     dispatch_func = dispatch<false, false, true, true, hidden>; \
 LAUNCH_KERNEL(&cfg, dispatch_func, \
               packed_recv_x, packed_recv_x_scales, \
               packed_recv_src_info, packed_recv_layout_range, \
               packed_recv_count, \
               cumulative_local_expert_recv_stats, \
               dispatch_wait_recv_cost_stats, \
-              x_global_scales, \
+              x_sf_scale, \
               rdma_recv_x, rdma_recv_count, rdma_x, \
               x, topk_idx, \
               atomic_counter_per_expert, atomic_finish_counter_per_expert, \
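
Editor's note: the new receive-side indexing decomposes each (token, scale) pair into coordinates of the (rm, rk, 32, 4, 4) tile layout instead of the old row-major addressing. The sketch below reproduces that coordinate mapping in plain Python; `NUM_ELEMS_PER_PACK` and the tensor shape are illustrative assumptions (the kernel derives them from `sizeof(packed_t) / sizeof(scale_t)` and the host-side allocation).

```python
import numpy as np

# Illustrative assumption: the kernel computes this as sizeof(packed_t) / sizeof(scale_t).
NUM_ELEMS_PER_PACK = 4

def scale_coords(token_idx: int, j: int):
    """Map (received token index, scale index j) to (rm, rk, 32, 4, 4) coordinates."""
    dim0 = token_idx // 128             # which 128-token tile
    dim1 = j // NUM_ELEMS_PER_PACK      # which pack along the scale dimension
    dim2 = (token_idx % 128) % 32       # row inside the 128-token tile
    dim3 = (token_idx % 128) // 32      # which 32-row quarter of the tile
    dim4 = j % NUM_ELEMS_PER_PACK       # element inside the pack
    return dim0, dim1, dim2, dim3, dim4

# Scatter per-token scale bytes for one local expert into the swizzled layout.
num_tokens, num_scales = 256, 448       # e.g. hidden = 7168 with 16 elements per scale
rm = (num_tokens + 127) // 128
rk = (num_scales + NUM_ELEMS_PER_PACK - 1) // NUM_ELEMS_PER_PACK
recv_x_scales = np.zeros((rm, rk, 32, 4, 4), dtype=np.uint8)
for token_idx in range(num_tokens):
    for j in range(num_scales):
        recv_x_scales[scale_coords(token_idx, j)] = (token_idx + j) % 256  # dummy scale byte
```

The kernel flattens these coordinates into a single offset via the `dim*_stride` values; the sketch uses multi-dimensional indexing directly for readability.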

deep_ep/buffer.py

Lines changed: 7 additions & 7 deletions
@@ -528,9 +528,9 @@ def low_latency_dispatch(self, x: torch.Tensor, topk_idx: torch.Tensor,
                              num_max_dispatch_tokens_per_rank: int, num_experts: int,
                              cumulative_local_expert_recv_stats: Optional[torch.Tensor] = None,
                              dispatch_wait_recv_cost_stats: Optional[torch.Tensor] = None,
-                             x_global_scales: Optional[torch.Tensor] = None,
+                             x_sf_scale: Optional[torch.Tensor] = None,
                              use_fp8: bool = True, round_scale: bool = False, use_ue8m0: bool = False,
-                             use_nvfp4: bool = False, use_ue8m0_for_nvfp4_x_scale: bool = False,
+                             use_nvfp4: bool = False, use_ue8m0_for_sf: bool = False,
                              async_finish: bool = False, return_recv_hook: bool = False) -> \
             Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, Tuple, EventOverlap, Callable]:
         """

@@ -553,12 +553,12 @@ def low_latency_dispatch(self, x: torch.Tensor, topk_idx: torch.Tensor,
             dispatch_wait_recv_cost_stats: a cumulative time spent waiting to receive each token tensor for statistics,
                 which should have shape `[num_ranks, num_ranks]` and be typed as `torch.int64`.
                 This is useful for detecting and pre-cisely localizing slow anomalies.
-            x_global_scales: a float32 tensor with dim() == 0, the scaling factors for the entire dispatch.
+            x_sf_scale: a float32 tensor with dim() == 0, the scaling factors for the entire dispatch.
             use_fp8: whether to enable FP8 casting, with this, the received data will be a tuple of FP8 tensor and scaling factors.
             round_scale: whether round the scaling factors into power of 2.
             use_ue8m0: whether use UE8M0 as scaling factor format (available only with `round_scale=True`).
             use_nvfp4: whether to enable NVFP4 casting, with this, the received data will be a tuple of NVFP4 tensor and scaling factors.
-            use_ue8m0_for_nvfp4_x_scale: whether use UE8M0 as NVFP4 scaling factor format (available only with `use_nvfp4=True`).
+            use_ue8m0_for_sf: whether use UE8M0 as NVFP4 scaling factor format (available only with `use_nvfp4=True`).
             async_finish: the current stream will not wait for the communication kernels to be finished if set.
             return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues,
                 but **without actually receiving the data**. You must call the received hook to make sure the data's arrival.

@@ -591,17 +591,17 @@ def low_latency_dispatch(self, x: torch.Tensor, topk_idx: torch.Tensor,
             self.runtime.low_latency_dispatch(x, topk_idx,
                                               cumulative_local_expert_recv_stats,
                                               dispatch_wait_recv_cost_stats,
-                                              x_global_scales,
+                                              x_sf_scale,
                                               num_max_dispatch_tokens_per_rank, num_experts,
                                               use_fp8, round_scale, use_ue8m0,
-                                              use_nvfp4, use_ue8m0_for_nvfp4_x_scale,
+                                              use_nvfp4, use_ue8m0_for_sf,
                                               async_finish, return_recv_hook)
         handle = (packed_recv_src_info, packed_recv_layout_range, num_max_dispatch_tokens_per_rank, x.size(1), num_experts)
         tensors_to_record = (x, topk_idx,
                              packed_recv_x, packed_recv_x_scales, packed_recv_count,
                              packed_recv_src_info, packed_recv_layout_range,
                              cumulative_local_expert_recv_stats,
-                             x_global_scales)
+                             x_sf_scale)
         if use_fp8 or use_nvfp4:
             packed_recv_x = (packed_recv_x, packed_recv_x_scales)
         return packed_recv_x, packed_recv_count, handle, \
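
Editor's note: a call-site sketch with the renamed keyword arguments, assuming `buffer` is an already-initialized low-latency `deep_ep.Buffer` and `x` / `topk_idx` are the usual BF16 activations and top-k expert indices. The concrete sizes, the scale value, and the choice of `use_fp8=False` alongside NVFP4 are placeholders, not verified defaults.

```python
import torch

# Hypothetical sizes; only the keyword names come from the diff above.
num_max_dispatch_tokens_per_rank, num_experts = 128, 288

# A 0-dim float32 CUDA tensor, as described for `x_sf_scale` in the docstring.
x_sf_scale = torch.tensor(1.0, dtype=torch.float32, device='cuda')

recv_x, recv_count, handle, event, hook = buffer.low_latency_dispatch(
    x, topk_idx,
    num_max_dispatch_tokens_per_rank, num_experts,
    x_sf_scale=x_sf_scale,
    use_fp8=False,                      # assumption: FP8 casting disabled when NVFP4 is used
    use_nvfp4=True, use_ue8m0_for_sf=True,
    return_recv_hook=True)

# With use_nvfp4=True the received data is a tuple of the packed FP4 tensor
# (two values per uint8) and the swizzled scaling factors.
packed_fp4, packed_scales = recv_x
hook()                                  # actually receive the data when return_recv_hook=True
```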
