from typing import Optional

import deep_ep
-from utils import init_dist, bench, bench_kineto, calc_diff, hash_tensor, cast_fp8_to_fp32, cast_nvfp4_to_fp32
+from utils import init_dist, bench, bench_kineto, calc_diff, hash_tensor, per_token_cast_back

MAX_E4M3 = 448
MAX_NVFP4 = 6.0
@@ -54,7 +54,7 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
    for dispatch_data_type in ('bf16', 'fp8', 'nvfp4'):
        dispatch_use_fp8 = dispatch_data_type == 'fp8'
        dispatch_use_nvfp4 = dispatch_data_type == 'nvfp4'
-        use_ue8m0_for_nvfp4_sf = False
+        use_ue8m0_for_sf = False
        for round_scale in (False, True) if dispatch_use_fp8 else (False, ):
            for use_ue8m0 in (False, True) if round_scale else (False, ):
                num_times += 1
@@ -66,20 +66,20 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
                packed_recv_x, packed_recv_count, handle, event, hook = \
                    buffer.low_latency_dispatch(current_x, topk_idx, num_tokens, num_experts,
                                                use_fp8=dispatch_use_fp8, round_scale=round_scale, use_ue8m0=use_ue8m0,
-                                                use_nvfp4=dispatch_use_nvfp4, use_ue8m0_for_nvfp4_sf=use_ue8m0_for_nvfp4_sf,
+                                                use_nvfp4=dispatch_use_nvfp4, use_ue8m0_for_sf=use_ue8m0_for_sf,
                                                cumulative_local_expert_recv_stats=cumulative_local_expert_recv_stats,
                                                x_sf_scale=x_sf_scale,
                                                async_finish=not return_recv_hook, return_recv_hook=return_recv_hook)
                hook() if return_recv_hook else event.current_stream_wait()
                if dispatch_use_fp8:
                    packed_recv_x = (packed_recv_x[0], packed_recv_x[1].contiguous())
-                    simulated_gemm_x = cast_fp8_to_fp32(packed_recv_x[0].view(-1, hidden), packed_recv_x[1].view(-1, hidden // 128)).view(packed_recv_x[0].shape)
+                    simulated_gemm_x = per_token_cast_back(packed_recv_x[0].view(-1, hidden), packed_recv_x[1].view(-1, hidden // 128)).view(packed_recv_x[0].shape)
                elif dispatch_use_nvfp4:
                    recv_x_scale_view = packed_recv_x[1]
                    recv_x_scale_view = recv_x_scale_view.permute(5, 2, 0, 1, 4, 3)
                    recv_x_scale_view = recv_x_scale_view.contiguous().view(num_local_experts, int(num_ranks * num_tokens), hidden // 16)
                    packed_recv_x = (packed_recv_x[0], recv_x_scale_view)
-                    simulated_gemm_x = cast_nvfp4_to_fp32(packed_recv_x[0], packed_recv_x[1], x_sf_scale, use_ue8m0_for_nvfp4_sf=use_ue8m0_for_nvfp4_sf)
+                    simulated_gemm_x = per_token_cast_back(packed_recv_x[0], packed_recv_x[1], x_sf_scale, use_ue8m0_for_sf=use_ue8m0_for_sf, src_data_format='nvfp4')
                else:
                    packed_recv_x = packed_recv_x
                    simulated_gemm_x = packed_recv_x.clone()
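This change folds the two format-specific dequantization helpers (cast_fp8_to_fp32 and cast_nvfp4_to_fp32) into a single per_token_cast_back entry point in utils. The helper's implementation is not part of this diff; as a rough sketch of what the FP8 branch is expected to do, assuming E4M3 data with one float32 scale per 128-element block (the function name below is hypothetical and used only for illustration):

import torch

def per_token_cast_back_fp8_sketch(x_fp8: torch.Tensor, x_scales: torch.Tensor) -> torch.Tensor:
    # Each row of `x_fp8` is split into 128-element blocks; `x_scales`
    # holds one scale per block, so dequantization is a per-block multiply.
    x_fp32 = x_fp8.to(torch.float32).view(x_fp8.size(0), -1, 128)
    x_scales = x_scales.to(torch.float32).view(x_fp8.size(0), -1, 1)
    return (x_fp32 * x_scales).view(x_fp8.size(0), -1)

The hidden // 128 reshape of the scale tensor in the FP8 branch above matches this per-128-element blocking.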
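For the NVFP4 branch, per_token_cast_back receives the packed 4-bit payload, the per-16-element scale-factor view built above (hidden // 16 entries per token), the global x_sf_scale, and src_data_format='nvfp4'. That code path is not shown in this diff either; a minimal sketch of the arithmetic, assuming E2M1 codes packed two per byte (low nibble first) and ignoring the use_ue8m0_for_sf variant (where block scales would be decoded as pure power-of-two exponents), could look like the following. All names here are illustrative, not the utils API:

import torch

# E2M1 code table: eight magnitudes (note MAX_NVFP4 = 6.0) and their negatives.
NVFP4_LUT = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
                          -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0])

def cast_back_nvfp4_sketch(x_packed: torch.Tensor, x_sf: torch.Tensor,
                           x_sf_scale: torch.Tensor) -> torch.Tensor:
    # `x_packed`: uint8, two 4-bit codes per byte, last dim = hidden // 2.
    # `x_sf`: one block scale per 16 elements, last dim = hidden // 16.
    # `x_sf_scale`: global per-tensor scale chosen at dispatch time.
    lo = x_packed & 0xF
    hi = (x_packed >> 4) & 0xF
    codes = torch.stack((lo, hi), dim=-1).view(*x_packed.shape[:-1], -1).long()
    values = NVFP4_LUT.to(x_packed.device)[codes]
    block_scales = x_sf.to(torch.float32).repeat_interleave(16, dim=-1)
    return values * block_scales * x_sf_scale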