@@ -70,21 +70,13 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
             x_sf_scale=x_sf_scale,
             async_finish=not return_recv_hook, return_recv_hook=return_recv_hook)
         hook() if return_recv_hook else event.current_stream_wait()
-        if dispatch_use_nvfp4:
-            print(f"rank {rank}, num_times {num_times}, i: {i}, current_x: {current_x}, topk_idx: {topk_idx}, packed_recv_x: {packed_recv_x}")
         if dispatch_use_fp8:
             packed_recv_x = (packed_recv_x[0], packed_recv_x[1].contiguous())
         elif dispatch_use_nvfp4:
-            recv_x_scale_packed = packed_recv_x[1].clone()
-            recv_x_scale_view = recv_x_scale_packed.clone()
-            print(f"rank {rank}, num_times {num_times}, i: {i}, recv_x_scale_packed.shape:{recv_x_scale_packed.shape}, recv_x_scale_packed.dtype: {recv_x_scale_packed.dtype}")
+            recv_x_scale_view = packed_recv_x[1].clone()
             recv_x_scale_view = recv_x_scale_view.permute(5, 2, 0, 1, 4, 3)
-            print(f"rank {rank}, num_times {num_times}, i: {i}, after permute, recv_x_scale_view.shape: {recv_x_scale_view.shape}, recv_x_scale_view.dtype: {recv_x_scale_view.dtype}")
             recv_x_scale_view = recv_x_scale_view.contiguous().view(torch.int32)
-            print(f"rank {rank}, num_times {num_times}, i: {i}, after view to change dtype, recv_x_scale_view.shape: {recv_x_scale_view.shape}, recv_x_scale_view.dtype: {recv_x_scale_view.dtype}")
             recv_x_scale_view = recv_x_scale_view.contiguous().view(num_local_experts, int(num_ranks * num_tokens), hidden // (16 * 4))
-            print(f"rank {rank}, num_times {num_times}, i: {i}, after view to change shape, recv_x_scale_view.shape: {recv_x_scale_view.shape}, recv_x_scale_view.dtype: {recv_x_scale_view.dtype}")
-            print(f"rank {rank}, num_times {num_times}, i: {i}, recv_x_scale_packed.shape:{recv_x_scale_packed.shape}, recv_x_scale_packed.dtype: {recv_x_scale_packed.dtype}, recv_x_scale_view.shape: {recv_x_scale_view.shape}, recv_x_scale_view.dtype: {recv_x_scale_view.dtype}, recv_x_scale_view: {recv_x_scale_view}")
             packed_recv_x = (packed_recv_x[0], recv_x_scale_view, packed_recv_x[2].contiguous())
         else:
             packed_recv_x = packed_recv_x
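Note on the NVFP4 branch above: the reshaping leans on PyTorch's `Tensor.view(dtype)` reinterpretation. A minimal, self-contained sketch of just that packing step, using toy shapes that are not the real dispatch layout:

```python
import torch

# Toy shapes only (not the actual NVFP4 scale layout): 2 rows of 8 one-byte
# scale factors. Viewing the contiguous uint8 tensor as int32 packs every 4
# consecutive bytes along the last axis into one int32, so (2, 8) -> (2, 2).
scales_u8 = torch.arange(2 * 8, dtype=torch.uint8).reshape(2, 8)
packed = scales_u8.contiguous().view(torch.int32)
assert packed.shape == (2, 2)
# In the test, permute(...).contiguous() first reorders the 6-D swizzled
# layout so the 4 bytes belonging to one packed scale are adjacent in memory,
# and the final view(...) flattens the result to
# (num_local_experts, num_ranks * num_tokens, hidden // 64).
```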
@@ -122,12 +114,9 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
         recv_x_amin = recv_x[:, :-128].amin(dim=-1)
         recv_x_amax = recv_x[:, :-128].amax(dim=-1)
         recv_src_info = recv_src_info[:num_valid_tokens]
-        if dispatch_use_nvfp4:
-            print(f"rank {rank}, num_times {num_times}, expert_id: {expert_id}, recv_x: {recv_x}, recv_x_amin:{recv_x_amin}, recv_x_amax:{recv_x_amax}, recv_x[:, -1]: {recv_x[:, -1]}, recv_src_info.view(-1): {recv_src_info.view(-1)}")
         assert torch.equal(recv_x_amin, recv_x_amax), f'recv_x_amin: {recv_x_amin}, recv_x_amax: {recv_x_amax}'
         diff = calc_diff(recv_x[:, -1], recv_src_info.view(-1))
         if dispatch_use_nvfp4:
-            print(f"rank {rank}, num_times {num_times}, expert_id: {expert_id}, diff after dispatch: {diff}")
             assert diff < 0.007, f"rank {rank}, num_times {num_times}, expert_id: {expert_id}, diff: {diff}"
         elif round_scale:
             assert diff < 0.007, f"rank {rank}, num_times {num_times}, expert_id: {expert_id}, diff: {diff}"
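Context for the `recv_x_amin`/`recv_x_amax` assertion retained above: each dispatched token's payload (everything except its trailing metadata slots) is expected to be constant, so the per-row min and max must match exactly. A toy illustration with made-up sizes:

```python
import torch

# Stand-in for recv_x: 3 tokens with a constant payload, last slot holding
# per-token metadata (e.g. the source index checked via recv_src_info).
recv_x = torch.full((3, 16), 2.0)
recv_x[:, -1] = torch.tensor([5.0, 6.0, 7.0])
# Excluding the trailing slot, per-row min == max iff the payload is constant.
assert torch.equal(recv_x[:, :-1].amin(dim=-1), recv_x[:, :-1].amax(dim=-1))
```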
@@ -156,8 +145,6 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
         hook() if return_recv_hook else event.current_stream_wait()
         if do_check:
             diff = calc_diff(current_x * topk_weights.masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1), combined_x)
-            if dispatch_use_nvfp4:
-                print(f"rank {rank}, num_times {num_times}, diff after combine: {diff}")
             assert torch.isnan(combined_x).sum().item() == 0
             assert diff < (1 if (dispatch_use_fp8 or dispatch_use_nvfp4) else 1e-5), f'Error: {diff=}, {dispatch_use_fp8=}, {dispatch_use_nvfp4=}, {zero_copy=}'
             hash_value ^= hash_tensor(combined_x)
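The combine check above compares against a closed-form reference: after the dispatch/combine round trip, each token should come back as itself scaled by the sum of its valid top-k weights (slots with `topk_idx == -1` are masked out). A hedged sketch of just that reference computation, with toy shapes and values:

```python
import torch

# Toy sizes and random values: 4 tokens, hidden = 8, top-2 routing.
x = torch.randn(4, 8)
topk_idx = torch.tensor([[0, 1], [2, -1], [1, 3], [-1, -1]])  # -1 marks an unused slot
topk_weights = torch.rand(4, 2)
# Zero the weights of invalid slots, sum per token, broadcast over hidden.
ref = x * topk_weights.masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1)
```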