
Commit 57a670a

committed
fix issue with padding m
1 parent: 82147f2

File tree

csrc/kernels/internode_ll.cu
tests/test_low_latency.py
tests/utils.py

3 files changed: 3 additions, 4 deletions


csrc/kernels/internode_ll.cu

Lines changed: 2 additions & 1 deletion
@@ -450,7 +450,8 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
     const auto recv_src_info = packed_recv_src_info + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank;
     const auto recv_range = packed_recv_layout_range + local_expert_idx * num_ranks;
     const auto num_aligned_scales = align<int>(num_scales, sizeof(float) / sizeof(scale_t));
-    const auto recv_x_scales = static_cast<scale_t*>(packed_recv_x_scales) + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_aligned_scales;
+    const auto num_aligned_tokens = align<int>(num_ranks * num_max_dispatch_tokens_per_rank, 128);
+    const auto recv_x_scales = static_cast<scale_t*>(packed_recv_x_scales) + local_expert_idx * num_aligned_tokens * num_aligned_scales;

     // Shared between sub-warps in warp groups
     __shared__ int shared_num_recv_tokens[kNumMaxWarpGroups], shared_recv_token_begin_idx[kNumMaxWarpGroups];
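The old stride used the raw token count; the new one first pads num_ranks * num_max_dispatch_tokens_per_rank up to a multiple of 128, so each expert's pointer into packed_recv_x_scales matches a scales buffer whose token (M) dimension is 128-aligned per expert. A minimal sketch of the arithmetic, with made-up sizes (none of these concrete numbers come from the commit):

def align(x: int, alignment: int) -> int:
    # Round x up to a multiple of `alignment`, mirroring the CUDA align<int> helper.
    return (x + alignment - 1) // alignment * alignment

num_ranks = 8                            # illustration value only
num_max_dispatch_tokens_per_rank = 100   # illustration value only
num_aligned_scales = 14                  # illustration value only

num_tokens = num_ranks * num_max_dispatch_tokens_per_rank   # 800
num_aligned_tokens = align(num_tokens, 128)                 # 896

for local_expert_idx in range(3):
    old_offset = local_expert_idx * num_tokens * num_aligned_scales
    new_offset = local_expert_idx * num_aligned_tokens * num_aligned_scales
    # If the scales buffer pads each expert's token dimension to 128, the old
    # stride under-shoots for every expert after the first (0 vs 0, 11200 vs
    # 12544, 22400 vs 25088), landing inside the previous expert's padding.
    print(local_expert_idx, old_offset, new_offset)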

tests/test_low_latency.py

Lines changed: 0 additions & 1 deletion
@@ -77,7 +77,6 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
    elif dispatch_use_nvfp4:
        recv_x_scale_view = packed_recv_x[1]
        recv_x_scale_view = recv_x_scale_view.permute(5, 2, 0, 1, 4, 3)
-        print(f'for num_times: {num_times}, recv_x_scale_view.shape: {recv_x_scale_view.shape}')
        recv_x_scale_view = recv_x_scale_view.contiguous().view(num_local_experts, int(num_ranks * num_tokens), hidden // 16)
        packed_recv_x = (packed_recv_x[0], recv_x_scale_view)
        simulated_gemm_x = per_token_cast_back(packed_recv_x[0], packed_recv_x[1], x_global_scales, use_ue8m0_for_nvfp4_x_scale=use_ue8m0_for_nvfp4_x_scale, src_data_format='nvfp4')
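For context on the unchanged lines around the deleted debug print: the permute-then-view idiom takes the kernel's blocked 6-D scale layout and flattens it to (num_local_experts, num_ranks * num_tokens, hidden // 16); .contiguous() materializes the permuted order so .view can reinterpret the storage. A toy-shape sketch of the mechanics only (all dimension sizes invented; the kernel's actual 6-D layout is not spelled out in this diff):

import torch

num_local_experts, num_ranks, num_tokens, hidden = 2, 2, 4, 64  # invented sizes

# A 6-D tensor whose axes line up so that permute(5, 2, 0, 1, 4, 3) puts the
# expert axis first; the axis meanings here are placeholders, not the real layout.
t = torch.arange(64, dtype=torch.float32).view(2, 2, 4, 1, 2, 2)
v = t.permute(5, 2, 0, 1, 4, 3)        # shape (2, 4, 2, 2, 2, 1), non-contiguous
flat = v.contiguous().view(num_local_experts, num_ranks * num_tokens, hidden // 16)
print(flat.shape)                      # torch.Size([2, 8, 4])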

tests/utils.py

Lines changed: 1 addition & 2 deletions
@@ -90,8 +90,7 @@ def int32_to_8floats_lookup(tensor: torch.Tensor, table: torch.Tensor) -> torch.Tensor:
    return out


-def cast_nvfp4_to_fp32(x_nvfp4: torch.Tensor, x_scales: torch.Tensor, x_global_scales: float, use_ue8m0_for_nvfp4_x_scale: bool = False):
-    assert(x_global_scales.dim() == 0, f"expect x_global_scales.dim() == 0, but got {x_global_scales.dim()}")
+def cast_nvfp4_to_fp32(x_nvfp4: torch.Tensor, x_scales: torch.Tensor, x_global_scales: torch.Tensor, use_ue8m0_for_nvfp4_x_scale: bool = False):
    NVFP4_TABLE = torch.tensor([0, 0.5, 1, 1.5, 2, 3, 4, 6, 0, -0.5, -1.0, -1.5, -2, -3, -4, -6], dtype=torch.float32, device='cuda')
    if use_ue8m0_for_nvfp4_x_scale:
        x_scales = x_scales.view(dtype=torch.int8).to(torch.int) << 23
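Two notes on the replaced lines. The dropped assert used the tuple form assert(cond, msg), which Python evaluates as a non-empty tuple and therefore never fails, and its .dim() call already contradicted the old float annotation, hence the switch to torch.Tensor. Separately, the << 23 line works because a UE8M0 scale is a bare 8-bit IEEE-754 exponent: shifting it into bits 30..23 of an int32 and reinterpreting as float32 yields exactly 2**(e - 127). A hedged sketch of both the exponent trick and the nibble lookup (the low-nibble-first packing and the final float32 reinterpret are assumptions, not shown in this hunk):

import torch

# UE8M0 -> float32: exponent bits shifted into place give exact powers of two.
e = torch.tensor([126, 127, 128, 130], dtype=torch.uint8)  # unsigned for clarity
scales = (e.to(torch.int32) << 23).view(torch.float32)
print(scales)  # tensor([0.5000, 1.0000, 2.0000, 8.0000])

# Decoding two packed FP4 codes per byte with the 16-entry table from the diff
# (low-nibble-first packing is an assumption; the real unpacking lives in
# int32_to_8floats_lookup).
NVFP4_TABLE = torch.tensor([0, 0.5, 1, 1.5, 2, 3, 4, 6,
                            0, -0.5, -1.0, -1.5, -2, -3, -4, -6], dtype=torch.float32)
byte = torch.tensor([0x3A], dtype=torch.uint8)  # nibbles 0xA (low) and 0x3 (high)
lo, hi = byte & 0xF, byte >> 4
print(NVFP4_TABLE[lo.long()], NVFP4_TABLE[hi.long()])  # tensor([-1.]) tensor([1.5000])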
