Commit 82147f2

Rename some parameters and fix the scale dtype:
rename x_sf_scale to x_global_scales; rename use_ue8m0_for_sf to use_ue8m0_for_nvfp4_x_scale; set the x_scale dtype to torch::kFloat8_e4m3fn when use_ue8m0_for_nvfp4_x_scale == False and to torch::kUInt8 when use_ue8m0_for_nvfp4_x_scale == True.
1 parent ccf4eaf commit 82147f2
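
As a quick reference, the dtype rule from the commit message can be sketched in PyTorch as below; the tensor shape is a placeholder and only the dtype selection mirrors the change:

import torch

def pick_nvfp4_x_scale_dtype(use_ue8m0_for_nvfp4_x_scale: bool) -> torch.dtype:
    # UE8M0 scales are stored as raw 8-bit exponent bytes (uint8); otherwise the
    # per-block scales are stored as FP8 E4M3 values.
    return torch.uint8 if use_ue8m0_for_nvfp4_x_scale else torch.float8_e4m3fn

# Placeholder shape; the real buffer uses the (l, rm, rk, 32, 4, 4) layout on CUDA.
scales = torch.empty(16, dtype=pick_nvfp4_x_scale_dtype(False))
print(scales.dtype)  # torch.float8_e4m3fn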

File tree

7 files changed, +54 -45 lines changed


csrc/deep_ep.cpp

Lines changed: 8 additions & 5 deletions
@@ -1091,10 +1091,10 @@ std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Te
 Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
                              const std::optional<torch::Tensor>& cumulative_local_expert_recv_stats,
                              const std::optional<torch::Tensor>& dispatch_wait_recv_cost_stats,
-                             const std::optional<torch::Tensor>& x_sf_scale,
+                             const std::optional<torch::Tensor>& x_global_scales,
                              int num_max_dispatch_tokens_per_rank, int num_experts,
                              bool use_fp8, bool round_scale, bool use_ue8m0,
-                             bool use_nvfp4, bool use_ue8m0_for_sf,
+                             bool use_nvfp4, bool use_ue8m0_for_nvfp4_x_scale,
                              bool async, bool return_recv_hook) {
 #ifndef DISABLE_NVSHMEM
     EP_HOST_ASSERT(low_latency_mode);
@@ -1176,9 +1176,12 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
     auto m = num_ranks * num_max_dispatch_tokens_per_rank;
     auto rm = (m + 127) / 128;
     auto rk = hidden / (kNumPerChannels * NUM_SF_ELEMS_PER_PACK);
+    auto scale_dtype = use_ue8m0_for_nvfp4_x_scale ?
+                       torch::dtype(torch::kUInt8) :
+                       torch::dtype(torch::kFloat8_e4m3fn);
     // The physical layout is (l, rm, rk, 32, 4, 4).
     packed_recv_x_scales = torch::empty({l, rm, rk, 32, 4, 4},
-                                        torch::dtype(torch::kUInt8).device(torch::kCUDA));
+                                        scale_dtype.device(torch::kCUDA));
     // After permute, the logical shape is (32, 4, rm, 4, rk, l)
     packed_recv_x_scales = packed_recv_x_scales.value().permute({3, 4, 1, 5, 2, 0});
@@ -1194,15 +1197,15 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
         packed_recv_count.data_ptr<int>(),
         cumulative_local_expert_recv_stats.has_value() ? cumulative_local_expert_recv_stats->data_ptr<int>() : nullptr,
         dispatch_wait_recv_cost_stats.has_value() ? dispatch_wait_recv_cost_stats->data_ptr<int64_t>() : nullptr,
-        x_sf_scale.has_value() ? x_sf_scale->data_ptr<float>() : nullptr,
+        x_global_scales.has_value() ? x_global_scales->data_ptr<float>() : nullptr,
         buffer.dispatch_rdma_recv_data_buffer, buffer.dispatch_rdma_recv_count_buffer,
         buffer.dispatch_rdma_send_buffer,
         x.data_ptr(), topk_idx.data_ptr<int64_t>(),
         next_clean_meta.first, next_clean_meta.second,
         num_tokens, hidden, num_max_dispatch_tokens_per_rank,
         num_topk, num_experts, rank, num_ranks,
         use_fp8, round_scale, use_ue8m0,
-        use_nvfp4, use_ue8m0_for_sf,
+        use_nvfp4, use_ue8m0_for_nvfp4_x_scale,
         workspace, num_device_sms,
         launch_stream, phases);
     };
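
To make the scale-buffer layout above easier to follow, here is a small PyTorch sketch with made-up l/rm/rk sizes: it allocates the physical (l, rm, rk, 32, 4, 4) buffer with the dtype chosen by use_ue8m0_for_nvfp4_x_scale and applies the same permute to obtain the logical (32, 4, rm, 4, rk, l) view. It is illustrative only and runs on CPU.

import torch

l, rm, rk = 2, 3, 4  # made-up sizes for illustration
use_ue8m0_for_nvfp4_x_scale = False
scale_dtype = torch.uint8 if use_ue8m0_for_nvfp4_x_scale else torch.float8_e4m3fn

# Physical layout (l, rm, rk, 32, 4, 4), as allocated in deep_ep.cpp.
packed_recv_x_scales = torch.empty((l, rm, rk, 32, 4, 4), dtype=scale_dtype)

# Same permute as above: the logical shape becomes (32, 4, rm, 4, rk, l).
logical = packed_recv_x_scales.permute(3, 4, 1, 5, 2, 0)
print(logical.shape)  # torch.Size([32, 4, 3, 4, 4, 2])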

csrc/deep_ep.hpp

Lines changed: 2 additions & 2 deletions
@@ -147,10 +147,10 @@ struct Buffer {
     low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
                          const std::optional<torch::Tensor>& cumulative_local_expert_recv_stats,
                          const std::optional<torch::Tensor>& dispatch_wait_recv_cost_stats,
-                         const std::optional<torch::Tensor>& x_sf_scale,
+                         const std::optional<torch::Tensor>& x_global_scales,
                          int num_max_dispatch_tokens_per_rank, int num_experts,
                          bool use_fp8, bool round_scale, bool use_ue8m0,
-                         bool use_nvfp4, bool use_ue8m0_for_sf,
+                         bool use_nvfp4, bool use_ue8m0_for_nvfp4_x_scale,
                          bool async, bool return_recv_hook);

     std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>

csrc/kernels/api.cuh

Lines changed: 2 additions & 2 deletions
@@ -144,14 +144,14 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
               int* packed_recv_count,
               int* cumulative_local_expert_recv_stats,
               int64_t* dispatch_wait_recv_cost_stats,
-              const float* x_sf_scale,
+              const float* x_global_scales,
               void* rdma_recv_x, int* rdma_recv_count, void* rdma_x,
               const void* x, const int64_t* topk_idx,
               int* next_clean, int num_next_clean_int,
               int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
               int num_topk, int num_experts, int rank, int num_ranks,
               bool use_fp8, bool round_scale, bool use_ue8m0,
-              bool use_nvfp4, bool use_ue8m0_for_sf,
+              bool use_nvfp4, bool use_ue8m0_for_nvfp4_x_scale,
               void* workspace, int num_device_sms,
               cudaStream_t stream, int phases);

csrc/kernels/internode_ll.cu

Lines changed: 21 additions & 16 deletions
@@ -81,10 +81,11 @@ __device__ inline uint8_t float_to_e2m1(float f) {
     return (sign << 3) | (exp << 1) | mant;
 }

-
 // Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t).
 inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) {
-#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
+// The PTX instructions used here require sm100a.
+#if CUDA_VERSION >= 12080
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) && __CUDA_ARCH_HAS_FEATURE__(SM100_ALL)
     uint32_t val;
     asm volatile(
         "{\n"
@@ -99,13 +100,16 @@ inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) {
         "mov.b32 %0, {byte0, byte1, byte2, byte3};\n"
         "}"
         : "=r"(val)
-        : "f"(array[0].x), "f"(array[0].y), "f"(array[1].x), "f"(array[1].y), "f"(array[2].x),
-          "f"(array[2].y), "f"(array[3].x), "f"(array[3].y));
+        : "f"(array[0].x),
+          "f"(array[0].y),
+          "f"(array[1].x),
+          "f"(array[1].y),
+          "f"(array[2].x),
+          "f"(array[2].y),
+          "f"(array[3].x),
+          "f"(array[3].y));
     return val;
 #else
-#if !(defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000))
-#pragma message("warning: this architecture does not support cvt.rn.satfinite.e2m1x2.f32, use float_to_e2m1 instead.")
-#endif
     uint32_t val = 0;
     float2* data = reinterpret_cast<float2*>(&array[0]);
     for (int i = 0; i < 4; ++i) {
@@ -114,7 +118,8 @@ inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) {
     }
     return val;
 #endif
-}
+#endif
+}

 constexpr int CVT_ELTS_PER_THREAD = 8;
 // Quantizes the provided PackedVec into the uint32_t output
@@ -195,7 +200,7 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
          int* packed_recv_count,
          int* cumulative_local_expert_recv_stats,
          int64_t* dispatch_wait_recv_cost_stats,
-         const float* x_sf_scale,
+         const float* x_global_scales,
          void* rdma_recv_x, int* rdma_recv_count, void* rdma_x,
          const void* x, const int64_t* topk_idx,
          int* atomic_counter_per_expert, int* atomic_finish_counter_per_expert,
@@ -270,8 +275,8 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
         float SFScaleVal = 1.0f;
         if constexpr (kUseNVFP4) {
             // Get scaling value;
-            EP_DEVICE_ASSERT(x_sf_scale != nullptr);
-            SFScaleVal = *(static_cast<const float*>(x_sf_scale));
+            EP_DEVICE_ASSERT(x_global_scales != nullptr);
+            SFScaleVal = *(static_cast<const float*>(x_global_scales));
         }

         // FP8 or NVFP4 cast
@@ -537,14 +542,14 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
               int* packed_recv_count,
               int* cumulative_local_expert_recv_stats,
               int64_t* dispatch_wait_recv_cost_stats,
-              const float* x_sf_scale,
+              const float* x_global_scales,
               void* rdma_recv_x, int* rdma_recv_count, void* rdma_x,
               const void* x, const int64_t* topk_idx,
               int* next_clean, int num_next_clean_int,
               int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
               int num_topk, int num_experts, int rank, int num_ranks,
               bool use_fp8, bool round_scale, bool use_ue8m0,
-              bool use_nvfp4, bool use_ue8m0_for_sf,
+              bool use_nvfp4, bool use_ue8m0_for_nvfp4_x_scale,
               void* workspace, int num_device_sms,
               cudaStream_t stream, int phases) {
     constexpr int kNumMaxTopK = 9;
@@ -572,17 +577,17 @@ if (use_fp8 and not use_ue8m0) \
         dispatch_func = dispatch<true, false, false, false, hidden>; \
     if (use_fp8 and use_ue8m0) \
         dispatch_func = dispatch<true, true, false, false, hidden>; \
-    if (use_nvfp4 and not use_ue8m0_for_sf) \
+    if (use_nvfp4 and not use_ue8m0_for_nvfp4_x_scale) \
         dispatch_func = dispatch<false, false, true, false, hidden>; \
-    if (use_nvfp4 and use_ue8m0_for_sf) \
+    if (use_nvfp4 and use_ue8m0_for_nvfp4_x_scale) \
         dispatch_func = dispatch<false, false, true, true, hidden>; \
     LAUNCH_KERNEL(&cfg, dispatch_func, \
                   packed_recv_x, packed_recv_x_scales, \
                   packed_recv_src_info, packed_recv_layout_range, \
                   packed_recv_count, \
                   cumulative_local_expert_recv_stats, \
                   dispatch_wait_recv_cost_stats, \
-                  x_sf_scale, \
+                  x_global_scales, \
                   rdma_recv_x, rdma_recv_count, rdma_x, \
                   x, topk_idx, \
                   atomic_counter_per_expert, atomic_finish_counter_per_expert, \
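
For readers without an sm100a build, a rough Python reference of the fallback path: quantize each float to the nearest representable E2M1 magnitude (0, 0.5, 1, 1.5, 2, 3, 4, 6 with a sign bit) and pack eight 4-bit codes into one 32-bit word. Rounding, tie-breaking, and the nibble packing order here are assumptions and may differ from the CUDA float_to_e2m1 helper; this is only a sketch.

E2M1_VALUES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]  # magnitudes for codes 0..7

def float_to_e2m1_ref(f: float) -> int:
    # 4-bit code laid out as (sign << 3) | (exp << 1) | mant.
    sign = 1 if f < 0 else 0
    mag = min(abs(f), 6.0)  # saturate at the largest E2M1 value
    code = min(range(8), key=lambda i: abs(E2M1_VALUES[i] - mag))
    return (sign << 3) | code

def pack8_e2m1(vals) -> int:
    # 8 values -> one 32-bit word, 4 bits per value (assumed low nibble first).
    assert len(vals) == 8
    word = 0
    for i, v in enumerate(vals):
        word |= float_to_e2m1_ref(v) << (4 * i)
    return word

print(hex(pack8_e2m1([0.0, 0.4, 1.2, -1.6, 2.4, -3.3, 5.0, 7.0])))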

deep_ep/buffer.py

Lines changed: 7 additions & 7 deletions
@@ -528,9 +528,9 @@ def low_latency_dispatch(self, x: torch.Tensor, topk_idx: torch.Tensor,
                              num_max_dispatch_tokens_per_rank: int, num_experts: int,
                              cumulative_local_expert_recv_stats: Optional[torch.Tensor] = None,
                              dispatch_wait_recv_cost_stats: Optional[torch.Tensor] = None,
-                             x_sf_scale: Optional[torch.Tensor] = None,
+                             x_global_scales: Optional[torch.Tensor] = None,
                              use_fp8: bool = True, round_scale: bool = False, use_ue8m0: bool = False,
-                             use_nvfp4: bool = False, use_ue8m0_for_sf: bool = False,
+                             use_nvfp4: bool = False, use_ue8m0_for_nvfp4_x_scale: bool = False,
                              async_finish: bool = False, return_recv_hook: bool = False) -> \
            Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, Tuple, EventOverlap, Callable]:
        """
@@ -553,12 +553,12 @@ def low_latency_dispatch(self, x: torch.Tensor, topk_idx: torch.Tensor,
            dispatch_wait_recv_cost_stats: a cumulative time spent waiting to receive each token tensor for statistics,
                which should have shape `[num_ranks, num_ranks]` and be typed as `torch.int64`.
                This is useful for detecting and precisely localizing slow anomalies.
-           x_sf_scale: a float32 tensor with dim() == 0, the scaling factors for the entire dispatch.
+           x_global_scales: a float32 tensor with dim() == 0, the scaling factors for the entire dispatch.
            use_fp8: whether to enable FP8 casting, with this, the received data will be a tuple of FP8 tensor and scaling factors.
            round_scale: whether round the scaling factors into power of 2.
            use_ue8m0: whether use UE8M0 as scaling factor format (available only with `round_scale=True`).
            use_nvfp4: whether to enable NVFP4 casting, with this, the received data will be a tuple of NVFP4 tensor and scaling factors.
-           use_ue8m0_for_sf: whether use UE8M0 as NVFP4 scaling factor format (available only with `use_nvfp4=True`).
+           use_ue8m0_for_nvfp4_x_scale: whether use UE8M0 as NVFP4 scaling factor format (available only with `use_nvfp4=True`).
            async_finish: the current stream will not wait for the communication kernels to be finished if set.
            return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues,
                but **without actually receiving the data**. You must call the received hook to make sure the data's arrival.
@@ -591,17 +591,17 @@ def low_latency_dispatch(self, x: torch.Tensor, topk_idx: torch.Tensor,
            self.runtime.low_latency_dispatch(x, topk_idx,
                                              cumulative_local_expert_recv_stats,
                                              dispatch_wait_recv_cost_stats,
-                                             x_sf_scale,
+                                             x_global_scales,
                                              num_max_dispatch_tokens_per_rank, num_experts,
                                              use_fp8, round_scale, use_ue8m0,
-                                             use_nvfp4, use_ue8m0_for_sf,
+                                             use_nvfp4, use_ue8m0_for_nvfp4_x_scale,
                                              async_finish, return_recv_hook)
        handle = (packed_recv_src_info, packed_recv_layout_range, num_max_dispatch_tokens_per_rank, x.size(1), num_experts)
        tensors_to_record = (x, topk_idx,
                             packed_recv_x, packed_recv_x_scales, packed_recv_count,
                             packed_recv_src_info, packed_recv_layout_range,
                             cumulative_local_expert_recv_stats,
-                            x_sf_scale)
+                            x_global_scales)
        if use_fp8 or use_nvfp4:
            packed_recv_x = (packed_recv_x, packed_recv_x_scales)
        return packed_recv_x, packed_recv_count, handle, \
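
For context, a hedged call-site sketch using the renamed keyword arguments; `buffer`, `current_x`, `topk_idx`, the integer sizes, and the MAX_E4M3/MAX_NVFP4 constants are assumed to be set up as in tests/test_low_latency.py and are not defined here:

# Assumed setup: buffer, current_x, topk_idx, MAX_E4M3, MAX_NVFP4,
# num_max_dispatch_tokens_per_rank, and num_experts already exist.
x_global_scales = (MAX_E4M3 * MAX_NVFP4) / torch.max(torch.abs(current_x)).to(torch.float32)

packed_recv_x, packed_recv_count, handle, event, hook = \
    buffer.low_latency_dispatch(current_x, topk_idx,
                                num_max_dispatch_tokens_per_rank, num_experts,
                                x_global_scales=x_global_scales,
                                use_fp8=False, use_nvfp4=True,
                                use_ue8m0_for_nvfp4_x_scale=False,
                                async_finish=False, return_recv_hook=False)
recv_x, recv_x_scales = packed_recv_x  # with use_ue8m0_for_nvfp4_x_scale=False, scales are float8_e4m3fn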

tests/test_low_latency.py

Lines changed: 8 additions & 7 deletions
@@ -54,21 +54,21 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
     for dispatch_data_type in ('bf16', 'fp8', 'nvfp4'):
         dispatch_use_fp8 = dispatch_data_type == 'fp8'
         dispatch_use_nvfp4 = dispatch_data_type == 'nvfp4'
-        use_ue8m0_for_sf = False
+        use_ue8m0_for_nvfp4_x_scale = False
         for round_scale in (False, True) if dispatch_use_fp8 else (False, ):
             for use_ue8m0 in (False, True) if round_scale else (False, ):
                 num_times += 1
                 for i in range((num_times % 2) + 1):
                     cumulative_local_expert_recv_stats = torch.zeros((num_local_experts, ), dtype=torch.int, device='cuda')
                     x_max = torch.max(torch.abs(current_x))
-                    x_sf_scale = (MAX_E4M3 * MAX_NVFP4) / x_max.to(torch.float32)
-                    dist.all_reduce(x_sf_scale, op=dist.ReduceOp.MIN, group=group)
+                    x_global_scales = (MAX_E4M3 * MAX_NVFP4) / x_max.to(torch.float32)
+                    dist.all_reduce(x_global_scales, op=dist.ReduceOp.MIN, group=group)
                     packed_recv_x, packed_recv_count, handle, event, hook = \
                         buffer.low_latency_dispatch(current_x, topk_idx, num_tokens, num_experts,
                                                     use_fp8=dispatch_use_fp8, round_scale=round_scale, use_ue8m0=use_ue8m0,
-                                                    use_nvfp4=dispatch_use_nvfp4, use_ue8m0_for_sf=use_ue8m0_for_sf,
+                                                    use_nvfp4=dispatch_use_nvfp4, use_ue8m0_for_nvfp4_x_scale=use_ue8m0_for_nvfp4_x_scale,
                                                     cumulative_local_expert_recv_stats=cumulative_local_expert_recv_stats,
-                                                    x_sf_scale=x_sf_scale,
+                                                    x_global_scales=x_global_scales,
                                                     async_finish=not return_recv_hook, return_recv_hook=return_recv_hook)
                     hook() if return_recv_hook else event.current_stream_wait()
                     if dispatch_use_fp8:
@@ -77,9 +77,10 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
                     elif dispatch_use_nvfp4:
                         recv_x_scale_view = packed_recv_x[1]
                         recv_x_scale_view = recv_x_scale_view.permute(5, 2, 0, 1, 4, 3)
+                        print(f'for num_times: {num_times}, recv_x_scale_view.shape: {recv_x_scale_view.shape}')
                         recv_x_scale_view = recv_x_scale_view.contiguous().view(num_local_experts, int(num_ranks * num_tokens), hidden // 16)
                         packed_recv_x = (packed_recv_x[0], recv_x_scale_view)
-                        simulated_gemm_x = per_token_cast_back(packed_recv_x[0], packed_recv_x[1], x_sf_scale, use_ue8m0_for_sf=use_ue8m0_for_sf, src_data_format='nvfp4')
+                        simulated_gemm_x = per_token_cast_back(packed_recv_x[0], packed_recv_x[1], x_global_scales, use_ue8m0_for_nvfp4_x_scale=use_ue8m0_for_nvfp4_x_scale, src_data_format='nvfp4')
                     else:
                         packed_recv_x = packed_recv_x
                         simulated_gemm_x = packed_recv_x.clone()
@@ -157,7 +158,7 @@ def test_func(return_recv_hook: bool):
         recv_x, recv_count, handle, event, hook = \
             buffer.low_latency_dispatch(current_x, topk_idx, num_tokens, num_experts,
                                         cumulative_local_expert_recv_stats=cumulative_local_expert_recv_stats,
-                                        use_fp8=False, use_nvfp4=True, x_sf_scale=x_sf_scale,
+                                        use_fp8=False, use_nvfp4=True, x_global_scales=x_global_scales,
                                         async_finish=False, return_recv_hook=return_recv_hook)
         large_gemm_with_hook(hook) if return_recv_hook else None
         combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, topk_idx, topk_weights, handle,
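
MAX_E4M3 and MAX_NVFP4 in the test are assumed here to be 448 and 6 (the E4M3 and E2M1 maxima), so the global scale maps the tensor's absolute maximum onto the combined E4M3 * E2M1 range; the MIN all-reduce then makes every rank use the same, most conservative scale. A tiny worked sketch with made-up numbers:

import torch

MAX_E4M3, MAX_NVFP4 = 448.0, 6.0      # assumed values of the test's constants
x_max = torch.tensor(3.0)             # made-up local absolute maximum
x_global_scales = (MAX_E4M3 * MAX_NVFP4) / x_max
print(float(x_global_scales))         # 2688 / 3.0 = 896.0
# In the test, dist.all_reduce(x_global_scales, op=dist.ReduceOp.MIN, group=group)
# follows, so all ranks agree on the smallest (safest) scale.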
