
Commit cb1757a

change name from use_ue8m0_for_nvfp4_sf to use_ue8m0_for_sf
1 parent 0cfe452 commit cb1757a

7 files changed: +26, -17 lines


csrc/deep_ep.cpp

Lines changed: 2 additions & 2 deletions
@@ -1094,7 +1094,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
 const std::optional<torch::Tensor>& x_sf_scale,
 int num_max_dispatch_tokens_per_rank, int num_experts,
 bool use_fp8, bool round_scale, bool use_ue8m0,
-bool use_nvfp4, bool use_ue8m0_for_nvfp4_sf,
+bool use_nvfp4, bool use_ue8m0_for_sf,
 bool async, bool return_recv_hook) {
 #ifndef DISABLE_NVSHMEM
 EP_HOST_ASSERT(low_latency_mode);
@@ -1200,7 +1200,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
 num_tokens, hidden, num_max_dispatch_tokens_per_rank,
 num_topk, num_experts, rank, num_ranks,
 use_fp8, round_scale, use_ue8m0,
-use_nvfp4, use_ue8m0_for_nvfp4_sf,
+use_nvfp4, use_ue8m0_for_sf,
 workspace, num_device_sms,
 launch_stream, phases);
 };

csrc/deep_ep.hpp

Lines changed: 1 addition & 1 deletion
@@ -150,7 +150,7 @@ struct Buffer {
 const std::optional<torch::Tensor>& x_sf_scale,
 int num_max_dispatch_tokens_per_rank, int num_experts,
 bool use_fp8, bool round_scale, bool use_ue8m0,
-bool use_nvfp4, bool use_ue8m0_for_nvfp4_sf,
+bool use_nvfp4, bool use_ue8m0_for_sf,
 bool async, bool return_recv_hook);

 std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>

csrc/kernels/api.cuh

Lines changed: 1 addition & 1 deletion
@@ -151,7 +151,7 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
 int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
 int num_topk, int num_experts, int rank, int num_ranks,
 bool use_fp8, bool round_scale, bool use_ue8m0,
-bool use_nvfp4, bool use_ue8m0_for_nvfp4_sf,
+bool use_nvfp4, bool use_ue8m0_for_sf,
 void* workspace, int num_device_sms,
 cudaStream_t stream, int phases);

csrc/kernels/internode_ll.cu

Lines changed: 3 additions & 3 deletions
@@ -545,7 +545,7 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
 int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
 int num_topk, int num_experts, int rank, int num_ranks,
 bool use_fp8, bool round_scale, bool use_ue8m0,
-bool use_nvfp4, bool use_ue8m0_for_nvfp4_sf,
+bool use_nvfp4, bool use_ue8m0_for_sf,
 void* workspace, int num_device_sms,
 cudaStream_t stream, int phases) {
 constexpr int kNumMaxTopK = 9;
@@ -573,9 +573,9 @@ if (use_fp8 and not use_ue8m0) \
 dispatch_func = dispatch<true, false, false, false, hidden>; \
 if (use_fp8 and use_ue8m0) \
 dispatch_func = dispatch<true, true, false, false, hidden>; \
-if (use_nvfp4 and not use_ue8m0_for_nvfp4_sf) \
+if (use_nvfp4 and not use_ue8m0_for_sf) \
 dispatch_func = dispatch<false, false, true, false, hidden>; \
-if (use_nvfp4 and use_ue8m0_for_nvfp4_sf) \
+if (use_nvfp4 and use_ue8m0_for_sf) \
 dispatch_func = dispatch<false, false, true, true, hidden>; \
 LAUNCH_KERNEL(&cfg, dispatch_func, \
 packed_recv_x, packed_recv_x_scales, \

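The selector above picks one precompiled kernel instantiation per flag combination. As a quick reference, a minimal Python model of that selection logic (the returned tuple mirrors the first four template arguments of dispatch<...>; the function name and the final BF16 fallback are illustrative, not taken from this hunk):

    # Hypothetical model of the kernel-variant selection in
    # csrc/kernels/internode_ll.cu; tuples mirror the first four template
    # arguments of dispatch<...>.
    def select_dispatch_variant(use_fp8, use_ue8m0, use_nvfp4, use_ue8m0_for_sf):
        if use_fp8 and not use_ue8m0:
            return (True, False, False, False)    # FP8 data, default scale format
        if use_fp8 and use_ue8m0:
            return (True, True, False, False)     # FP8 data, UE8M0 scale format
        if use_nvfp4 and not use_ue8m0_for_sf:
            return (False, False, True, False)    # NVFP4 data, default scale format
        if use_nvfp4 and use_ue8m0_for_sf:
            return (False, False, True, True)     # NVFP4 data, UE8M0 scale format
        return (False, False, False, False)       # plain BF16 path (assumed)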
deep_ep/buffer.py

Lines changed: 3 additions & 3 deletions
@@ -530,7 +530,7 @@ def low_latency_dispatch(self, x: torch.Tensor, topk_idx: torch.Tensor,
 dispatch_wait_recv_cost_stats: Optional[torch.Tensor] = None,
 x_sf_scale: Optional[torch.Tensor] = None,
 use_fp8: bool = True, round_scale: bool = False, use_ue8m0: bool = False,
-use_nvfp4: bool = False, use_ue8m0_for_nvfp4_sf: bool = False,
+use_nvfp4: bool = False, use_ue8m0_for_sf: bool = False,
 async_finish: bool = False, return_recv_hook: bool = False) -> \
 Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, Tuple, EventOverlap, Callable]:
 """
@@ -558,7 +558,7 @@ def low_latency_dispatch(self, x: torch.Tensor, topk_idx: torch.Tensor,
 round_scale: whether round the scaling factors into power of 2.
 use_ue8m0: whether use UE8M0 as scaling factor format (available only with `round_scale=True`).
 use_nvfp4: whether to enable NVFP4 casting, with this, the received data will be a tuple of NVFP4 tensor and scaling factors.
-use_ue8m0_for_nvfp4_sf: whether use UE8M0 as NVFP4 scaling factor format (available only with `use_nvfp4=True`).
+use_ue8m0_for_sf: whether use UE8M0 as NVFP4 scaling factor format (available only with `use_nvfp4=True`).
 async_finish: the current stream will not wait for the communication kernels to be finished if set.
 return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues,
 but **without actually receiving the data**. You must call the received hook to make sure the data's arrival.
@@ -590,7 +590,7 @@ def low_latency_dispatch(self, x: torch.Tensor, topk_idx: torch.Tensor,
 x_sf_scale,
 num_max_dispatch_tokens_per_rank, num_experts,
 use_fp8, round_scale, use_ue8m0,
-use_nvfp4, use_ue8m0_for_nvfp4_sf,
+use_nvfp4, use_ue8m0_for_sf,
 async_finish, return_recv_hook)
 handle = (packed_recv_src_info, packed_recv_layout_range, num_max_dispatch_tokens_per_rank, x.size(1), num_experts)
 tensors_to_record = (x, topk_idx,

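On the caller side, only the keyword name changes. A minimal sketch of an NVFP4 dispatch call with the renamed keyword, assuming an already constructed buffer and the x, topk_idx, and x_sf_scale tensors as in tests/test_low_latency.py (not a complete runnable script):

    # Sketch only: `buffer`, `x`, `topk_idx`, `x_sf_scale`, `num_tokens` and
    # `num_experts` are assumed to exist as in tests/test_low_latency.py.
    packed_recv_x, packed_recv_count, handle, event, hook = \
        buffer.low_latency_dispatch(x, topk_idx, num_tokens, num_experts,
                                    use_fp8=False, use_nvfp4=True,
                                    use_ue8m0_for_sf=True,   # renamed from use_ue8m0_for_nvfp4_sf
                                    x_sf_scale=x_sf_scale,
                                    async_finish=True, return_recv_hook=False)
    event.current_stream_wait()
    # With use_nvfp4=True, packed_recv_x is a tuple of (NVFP4 data, scaling factors).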
tests/test_low_latency.py

Lines changed: 5 additions & 5 deletions
@@ -9,7 +9,7 @@
 from typing import Optional

 import deep_ep
-from utils import init_dist, bench, bench_kineto, calc_diff, hash_tensor, cast_fp8_to_fp32, cast_nvfp4_to_fp32
+from utils import init_dist, bench, bench_kineto, calc_diff, hash_tensor, per_token_cast_back

 MAX_E4M3 = 448
 MAX_NVFP4 = 6.0
@@ -54,7 +54,7 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
 for dispatch_data_type in ('bf16', 'fp8', 'nvfp4'):
 dispatch_use_fp8 = dispatch_data_type == 'fp8'
 dispatch_use_nvfp4 = dispatch_data_type == 'nvfp4'
-use_ue8m0_for_nvfp4_sf = False
+use_ue8m0_for_sf = False
 for round_scale in (False, True) if dispatch_use_fp8 else (False, ):
 for use_ue8m0 in (False, True) if round_scale else (False, ):
 num_times += 1
@@ -66,20 +66,20 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
 packed_recv_x, packed_recv_count, handle, event, hook = \
 buffer.low_latency_dispatch(current_x, topk_idx, num_tokens, num_experts,
 use_fp8=dispatch_use_fp8, round_scale=round_scale, use_ue8m0=use_ue8m0,
-use_nvfp4=dispatch_use_nvfp4, use_ue8m0_for_nvfp4_sf=use_ue8m0_for_nvfp4_sf,
+use_nvfp4=dispatch_use_nvfp4, use_ue8m0_for_sf=use_ue8m0_for_sf,
 cumulative_local_expert_recv_stats=cumulative_local_expert_recv_stats,
 x_sf_scale=x_sf_scale,
 async_finish=not return_recv_hook, return_recv_hook=return_recv_hook)
 hook() if return_recv_hook else event.current_stream_wait()
 if dispatch_use_fp8:
 packed_recv_x = (packed_recv_x[0], packed_recv_x[1].contiguous())
-simulated_gemm_x = cast_fp8_to_fp32(packed_recv_x[0].view(-1, hidden), packed_recv_x[1].view(-1, hidden // 128)).view(packed_recv_x[0].shape)
+simulated_gemm_x = per_token_cast_back(packed_recv_x[0].view(-1, hidden), packed_recv_x[1].view(-1, hidden // 128)).view(packed_recv_x[0].shape)
 elif dispatch_use_nvfp4:
 recv_x_scale_view = packed_recv_x[1]
 recv_x_scale_view = recv_x_scale_view.permute(5, 2, 0, 1, 4, 3)
 recv_x_scale_view = recv_x_scale_view.contiguous().view(num_local_experts, int(num_ranks * num_tokens), hidden // 16)
 packed_recv_x = (packed_recv_x[0], recv_x_scale_view)
-simulated_gemm_x = cast_nvfp4_to_fp32(packed_recv_x[0], packed_recv_x[1], x_sf_scale, use_ue8m0_for_nvfp4_sf=use_ue8m0_for_nvfp4_sf)
+simulated_gemm_x = per_token_cast_back(packed_recv_x[0], packed_recv_x[1], x_sf_scale, use_ue8m0_for_sf=use_ue8m0_for_sf, src_data_format='nvfp4')
 else:
 packed_recv_x = packed_recv_x
 simulated_gemm_x = packed_recv_x.clone()

tests/utils.py

Lines changed: 11 additions & 2 deletions
@@ -90,10 +90,10 @@ def int32_to_8floats_lookup(tensor: torch.Tensor, table: torch.Tensor) -> torch.
 return out


-def cast_nvfp4_to_fp32(x_nvfp4: torch.Tensor, x_scales: torch.Tensor, x_sf_scale: float, use_ue8m0_for_nvfp4_sf: bool = False):
+def cast_nvfp4_to_fp32(x_nvfp4: torch.Tensor, x_scales: torch.Tensor, x_sf_scale: float, use_ue8m0_for_sf: bool = False):
 assert(x_sf_scale.dim() == 0, f"expect x_sf_scale.dim() == 0, but got {x_sf_scale.dim()}")
 NVFP4_TABLE = torch.tensor([0, 0.5, 1, 1.5, 2, 3, 4, 6, 0, -0.5, -1.0, -1.5, -2, -3, -4, -6], dtype=torch.float32, device='cuda')
-if use_ue8m0_for_nvfp4_sf:
+if use_ue8m0_for_sf:
 x_scales = x_scales.view(dtype=torch.int8).to(torch.int) << 23
 x_scales = x_scales.view(dtype=torch.float)
 else:
@@ -111,6 +111,15 @@ def cast_nvfp4_to_fp32(x_nvfp4: torch.Tensor, x_scales: torch.Tensor, x_sf_scale
 return x_fp32


+def per_token_cast_back(x: torch.Tensor, x_scales: torch.Tensor, x_sf_scale: torch.Tensor = None, use_ue8m0_for_sf: bool = False, src_data_format: str = 'fp8'):
+    if src_data_format == 'fp8':
+        return cast_fp8_to_fp32(x, x_scales)
+    elif src_data_format == 'nvfp4':
+        return cast_nvfp4_to_fp32(x, x_scales, x_sf_scale, use_ue8m0_for_sf)
+    else:
+        raise ValueError(f"Unsupported src_data_format: {src_data_format}")
+
+
 def inplace_unique(x: torch.Tensor, num_slots: int):
 assert x.dim() == 2
 mask = x < 0

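The new per_token_cast_back helper is what lets the test use a single de-quantization entry point for both formats. A short usage sketch mirroring those call sites (recv_x, recv_scales, x_sf_scale, hidden, and use_ue8m0_for_sf are assumed to come from a prior low_latency_dispatch, as in tests/test_low_latency.py):

    # Sketch of the two call forms of per_token_cast_back from tests/utils.py;
    # the input tensors are assumed to come from low_latency_dispatch.
    from utils import per_token_cast_back

    # FP8 path (the default): data plus per-128-element block scales.
    x_fp32 = per_token_cast_back(recv_x.view(-1, hidden),
                                 recv_scales.view(-1, hidden // 128))

    # NVFP4 path: also needs the zero-dim global scale, the scaling-factor
    # format flag, and an explicit source format.
    x_fp32 = per_token_cast_back(recv_x, recv_scales, x_sf_scale,
                                 use_ue8m0_for_sf=use_ue8m0_for_sf,
                                 src_data_format='nvfp4')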