@@ -75,18 +75,15 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
7575 if dispatch_use_fp8 :
7676 packed_recv_x = (packed_recv_x [0 ], packed_recv_x [1 ].contiguous ())
7777 elif dispatch_use_nvfp4 :
78- recv_x_scale_packed = packed_recv_x [1 ].clone (). contiguous ()
78+ recv_x_scale_packed = packed_recv_x [1 ].clone ()
7979 recv_x_scale_view = recv_x_scale_packed .clone ()
8080 print (f"rank { rank } , num_times { num_times } , i: { i } , recv_x_scale_packed.shape:{ recv_x_scale_packed .shape } , recv_x_scale_packed.dtype: { recv_x_scale_packed .dtype } " )
81- recv_x_scale_view = recv_x_scale_view .contiguous ().view (num_local_experts , int (num_ranks * num_tokens ) // 128 , hidden // (16 * 4 ), 32 , 4 , 4 )
82- recv_x_scale_view = recv_x_scale_view .permute (3 , 4 , 1 , 5 , 2 , 0 )
83- print (f"rank { rank } , num_times { num_times } , i: { i } , after first permute, recv_x_scale_view.shape: { recv_x_scale_view .shape } , recv_x_scale_view.dtype: { recv_x_scale_view .dtype } " )
8481 recv_x_scale_view = recv_x_scale_view .permute (5 , 2 , 0 , 1 , 4 , 3 )
85- print (f"rank { rank } , num_times { num_times } , i: { i } , after second permute, recv_x_scale_view.shape: { recv_x_scale_view .shape } , recv_x_scale_view.dtype: { recv_x_scale_view .dtype } " )
86- recv_x_scale_view = recv_x_scale_view .view (torch .int32 )
87- print (f"rank { rank } , num_times { num_times } , i: { i } , after view change dtype, recv_x_scale_view.shape: { recv_x_scale_view .shape } , recv_x_scale_view.dtype: { recv_x_scale_view .dtype } " )
82+ print (f"rank { rank } , num_times { num_times } , i: { i } , after permute, recv_x_scale_view.shape: { recv_x_scale_view .shape } , recv_x_scale_view.dtype: { recv_x_scale_view .dtype } " )
83+ recv_x_scale_view = recv_x_scale_view .contiguous (). view (torch .int32 )
84+ print (f"rank { rank } , num_times { num_times } , i: { i } , after view to change dtype, recv_x_scale_view.shape: { recv_x_scale_view .shape } , recv_x_scale_view.dtype: { recv_x_scale_view .dtype } " )
8885 recv_x_scale_view = recv_x_scale_view .contiguous ().view (num_local_experts , int (num_ranks * num_tokens ), hidden // (16 * 4 ))
89- print (f"rank { rank } , num_times { num_times } , i: { i } , after view change shape, recv_x_scale_view.shape: { recv_x_scale_view .shape } , recv_x_scale_view.dtype: { recv_x_scale_view .dtype } " )
86+ print (f"rank { rank } , num_times { num_times } , i: { i } , after view to change shape, recv_x_scale_view.shape: { recv_x_scale_view .shape } , recv_x_scale_view.dtype: { recv_x_scale_view .dtype } " )
9087 print (f"rank { rank } , num_times { num_times } , i: { i } , recv_x_scale_packed.shape:{ recv_x_scale_packed .shape } , recv_x_scale_packed.dtype: { recv_x_scale_packed .dtype } , recv_x_scale_view.shape: { recv_x_scale_view .shape } , recv_x_scale_view.dtype: { recv_x_scale_view .dtype } , recv_x_scale_view: { recv_x_scale_view } " )
9188 packed_recv_x = (packed_recv_x [0 ], recv_x_scale_view , packed_recv_x [2 ].contiguous ())
9289 else :
0 commit comments