Skip to content

Commit 74b631a

Browse files
committed
change layout of dispatch output x_scales
1 parent 5cd59de commit 74b631a

File tree

3 files changed

+44
-9
lines changed

3 files changed

+44
-9
lines changed

csrc/deep_ep.cpp

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1159,12 +1159,26 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
11591159
if (not use_ue8m0) {
11601160
packed_recv_x_scales = torch::empty({num_local_experts, hidden / 128, num_ranks * num_max_dispatch_tokens_per_rank},
11611161
torch::dtype(torch::kFloat32).device(torch::kCUDA));
1162+
packed_recv_x_scales = torch::transpose(packed_recv_x_scales.value(), 1, 2);
11621163
} else {
11631164
EP_HOST_ASSERT(round_scale);
1164-
packed_recv_x_scales = torch::empty({num_local_experts, hidden / 512, num_ranks * num_max_dispatch_tokens_per_rank},
1165-
torch::dtype(torch::kInt).device(torch::kCUDA));
1165+
// The blockscale tensor is in FP8-E4M3, with logical shape (32, 4, rm, 4, rk, l),
1166+
// but the physical layout is (l, rm, rk, 32, 4, 4).
1167+
// For the shape of output_scales, `32 * 4 * rm` is m padded to the nearest multiple of 128.
1168+
// `4 * rk` is `k // 16` padded to the nearest multiple of 4. These layout constants are
1169+
// required by the NVIDIA Blackwell MMA operations.
1170+
// So we need to allocate the tensor with shape (l, rm, rk, 32, 4, 4) and transpose it to the logical layout.
1171+
auto rm = (hidden + 127) / 128;
1172+
auto rk = (hidden + 15) / 16;
1173+
auto l = num_local_experts;
1174+
packed_recv_x_scales = torch::empty({l, rm, rk, 32, 4, 4},
1175+
torch::dtype(torch::kFloat8_e4m3fn).device(torch::kCUDA));
1176+
// Reshape and permute the packed_recv_x_scales tensor to match the required layout
1177+
packed_recv_x_scales = packed_recv_x_scales.value().view(torch::kFloat8_e4m3fn).view(
1178+
{l, rm, rk, 32, 4, 4}
1179+
);
1180+
packed_recv_x_scales = packed_recv_x_scales.value().permute({3, 4, 1, 5, 2, 0});
11661181
}
1167-
packed_recv_x_scales = torch::transpose(packed_recv_x_scales.value(), 1, 2);
11681182
packed_recv_x_scales_ptr = packed_recv_x_scales->data_ptr();
11691183
}else if (use_nvfp4) {
11701184
constexpr int SF_VEC_SIZE = 16;

csrc/kernels/internode_ll.cu

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -517,18 +517,27 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales, void* packed_recv_x_sf
517517
}
518518
} else if constexpr (kUseNVFP4) {
519519
// Equivalent CuTe layout:
520-
// (num_tokens, (num_packed, num_elems_per_pack)):(num_elems_per_pack, (num_tokens * num_elems_per_pack, 1))
521520
const auto src_scales = reinterpret_cast<uint8_t*>(reinterpret_cast<uint8_t*>(src_data) + hidden_bytes);
522521
const auto num_elems_per_pack = static_cast<int>(sizeof(packed_t) / sizeof(scale_t));
522+
const auto max_rk = (hidden_int4 * 4 + 15) / 16;
523523
const auto token_idx = recv_token_begin_idx + i;
524-
const auto token_stride = num_elems_per_pack;
525-
const auto pack_stride = num_ranks * num_max_dispatch_tokens_per_rank * num_elems_per_pack;
524+
525+
const int stride_dim_0 = 1;
526+
const int stride_dim_1 = num_elems_per_pack;
527+
const int stride_dim_2 = 4*4;
528+
const int stride_rk = 32*4*4;
529+
const int stride_rm = max_rk*32*4*4;
530+
531+
const auto index_rm = token_idx / (32 * 4);
526532
#pragma unroll
527533
for (int j = lane_id; j < num_scales; j += 32) {
528-
const auto pack_idx = j / num_elems_per_pack;
529-
const auto elem_idx = j % num_elems_per_pack;
534+
auto index_rk = j / num_elems_per_pack;
535+
auto index_dim_2 = (token_idx % (32)) / 4;
536+
auto index_dim_1 = token_idx % (32 * 4);
537+
auto index_dim_0 = j % num_elems_per_pack;
538+
auto offset = index_dim_0 * stride_dim_0 + index_dim_1 * stride_dim_1 + index_dim_2 * stride_dim_2 + index_rk * stride_rk + index_rm * stride_rm;
530539
auto scale = ld_nc_global(src_scales + j);
531-
recv_x_scales[token_idx * token_stride + pack_idx * pack_stride + elem_idx] = scale;
540+
recv_x_scales[offset] = scale;
532541
}
533542
}
534543
}

tests/utils.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@ def per_token_cast_to_fp8(x: torch.Tensor):
5252

5353

5454
def cast_fp8_to_fp32(x_fp8: torch.Tensor, x_scales: torch.Tensor):
55+
# TODO(shifangx): remove print after debugging
56+
print(f"in cast_fp8_to_fp32, x_fp8.shape: {x_fp8.shape}, x_scales.shape: {x_scales.shape}")
5557
if x_fp8.numel() == 0:
5658
return x_fp8.to(torch.bfloat16)
5759
if x_scales.dtype == torch.int:
@@ -91,11 +93,21 @@ def int32_to_8floats_lookup(tensor: torch.Tensor, table: torch.Tensor) -> torch.
9193

9294

9395
def cast_nvfp4_to_fp32(x_nvfp4: torch.Tensor, x_scales: torch.Tensor, x_sf_scale: float, use_ue8m0_for_nvfp4_sf: bool = False):
96+
# TODO(shifangx): remove print after debugging
97+
print(f"in cast_nvfp4_to_fp32, x_nvfp4.shape: {x_nvfp4.shape}, x_scales.shape: {x_scales.shape}")
9498
NVFP4_TABLE = torch.tensor([0, 0.5, 1, 1.5, 2, 3, 4, 6, 0, -0.5, -1.0, -1.5, -2, -3, -4, -6], dtype=torch.float32, device='cuda')
9599
if use_ue8m0_for_nvfp4_sf:
96100
x_scales = x_scales.view(dtype=torch.int8).to(torch.int) << 23
97101
x_scales = x_scales.view(dtype=torch.float)
98102
else:
103+
# shape of x_scales: (32, 4, rm, 4, rk, l)
104+
dim_0, dim_1, dim_2, dim_3, dim_4, dim_5 = x_scales.shape
105+
assert dim_0 == 32 and dim_1 == 4 and dim_3 == 4 , "x_scales must be in the shape of (32, 4, rm, 4, rk, l)"
106+
rm = dim_2
107+
rk = dim_4
108+
l = dim_5
109+
x_scales = x_scales.view(dtype=torch.float8_e4m3fn).permute(5, 2, 0, 1, 4, 3)  # shape of x_scales: (l, rm, 32, 4, rk, 4)
110+
x_scales = x_scales.reshape(l, rm * 32 * 4, rk * 4)  # shape of x_scales: (l, m, k)
99111
x_scales = x_scales.view(dtype=torch.float8_e4m3fn).to(torch.float32)
100112
x_sf_scale = x_sf_scale.view(*x_sf_scale.shape, 1)
101113
x_scales = x_scales * (1 / x_sf_scale)

0 commit comments

Comments
 (0)