Skip to content

Commit 78a07e0

Browse files
committed
move permute from test function into dispatch C++ function
1 parent dd9374c commit 78a07e0

File tree

2 files changed

+9
-8
lines changed

2 files changed

+9
-8
lines changed

csrc/deep_ep.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1177,8 +1177,12 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
11771177
auto m = num_ranks * num_max_dispatch_tokens_per_rank;
11781178
auto rm = (m + 127) / 128;
11791179
auto rk = hidden / (kNumPerChannels * NUM_SF_ELEMS_PER_PACK);
1180+
// The physical layout is (l, rm, rk, 32, 4, 4).
11801181
packed_recv_x_scales = torch::empty({l, rm, rk, 32, 4, 4},
11811182
torch::dtype(torch::kInt8).device(torch::kCUDA));
1183+
// After permute, the logical shape is (32, 4, rm, 4, rk, l)
1184+
packed_recv_x_scales = packed_recv_x_scales.value().permute({3, 4, 1, 5, 2, 0});
1185+
11821186
packed_recv_x_scales_ptr = packed_recv_x_scales->data_ptr();
11831187
packed_recv_x_sf_scale = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank}, torch::dtype(torch::kFloat32).device(torch::kCUDA));
11841188
packed_recv_x_sf_scale_ptr = packed_recv_x_sf_scale->data_ptr();

tests/test_low_latency.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -75,18 +75,15 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
7575
if dispatch_use_fp8:
7676
packed_recv_x = (packed_recv_x[0], packed_recv_x[1].contiguous())
7777
elif dispatch_use_nvfp4:
78-
recv_x_scale_packed = packed_recv_x[1].clone().contiguous()
78+
recv_x_scale_packed = packed_recv_x[1].clone()
7979
recv_x_scale_view = recv_x_scale_packed.clone()
8080
print(f"rank {rank}, num_times {num_times}, i: {i}, recv_x_scale_packed.shape:{recv_x_scale_packed.shape}, recv_x_scale_packed.dtype: {recv_x_scale_packed.dtype}")
81-
recv_x_scale_view = recv_x_scale_view.contiguous().view(num_local_experts, int(num_ranks * num_tokens) // 128, hidden // (16 * 4), 32, 4, 4)
82-
recv_x_scale_view = recv_x_scale_view.permute(3, 4, 1, 5, 2, 0)
83-
print(f"rank {rank}, num_times {num_times}, i: {i}, after first permute, recv_x_scale_view.shape: {recv_x_scale_view.shape}, recv_x_scale_view.dtype: {recv_x_scale_view.dtype}")
8481
recv_x_scale_view = recv_x_scale_view.permute(5, 2, 0, 1, 4, 3)
85-
print(f"rank {rank}, num_times {num_times}, i: {i}, after second permute, recv_x_scale_view.shape: {recv_x_scale_view.shape}, recv_x_scale_view.dtype: {recv_x_scale_view.dtype}")
86-
recv_x_scale_view = recv_x_scale_view.view(torch.int32)
87-
print(f"rank {rank}, num_times {num_times}, i: {i}, after view change dtype, recv_x_scale_view.shape: {recv_x_scale_view.shape}, recv_x_scale_view.dtype: {recv_x_scale_view.dtype}")
82+
print(f"rank {rank}, num_times {num_times}, i: {i}, after permute, recv_x_scale_view.shape: {recv_x_scale_view.shape}, recv_x_scale_view.dtype: {recv_x_scale_view.dtype}")
83+
recv_x_scale_view = recv_x_scale_view.contiguous().view(torch.int32)
84+
print(f"rank {rank}, num_times {num_times}, i: {i}, after view to change dtype, recv_x_scale_view.shape: {recv_x_scale_view.shape}, recv_x_scale_view.dtype: {recv_x_scale_view.dtype}")
8885
recv_x_scale_view = recv_x_scale_view.contiguous().view(num_local_experts, int(num_ranks * num_tokens), hidden // (16 * 4))
89-
print(f"rank {rank}, num_times {num_times}, i: {i}, after view change shape, recv_x_scale_view.shape: {recv_x_scale_view.shape}, recv_x_scale_view.dtype: {recv_x_scale_view.dtype}")
86+
print(f"rank {rank}, num_times {num_times}, i: {i}, after view to change shape, recv_x_scale_view.shape: {recv_x_scale_view.shape}, recv_x_scale_view.dtype: {recv_x_scale_view.dtype}")
9087
print(f"rank {rank}, num_times {num_times}, i: {i}, recv_x_scale_packed.shape:{recv_x_scale_packed.shape}, recv_x_scale_packed.dtype: {recv_x_scale_packed.dtype}, recv_x_scale_view.shape: {recv_x_scale_view.shape}, recv_x_scale_view.dtype: {recv_x_scale_view.dtype}, recv_x_scale_view: {recv_x_scale_view}")
9188
packed_recv_x = (packed_recv_x[0], recv_x_scale_view, packed_recv_x[2].contiguous())
9289
else:

0 commit comments

Comments
 (0)