
Commit ccf4eaf

change test case
1 parent: 9d9e395

File tree: 4 files changed (+14, -7 lines)

csrc/deep_ep.cpp

Lines changed: 2 additions & 0 deletions
@@ -1151,6 +1151,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
     void* packed_recv_x_scales_ptr = nullptr;
     EP_HOST_ASSERT((num_ranks * num_max_dispatch_tokens_per_rank) % 4 == 0 and "TMA requires the number of tokens to be multiple of 4");
 
+    EP_HOST_ASSERT(not (use_fp8 and use_nvfp4));
     if (use_fp8) {
         // TODO: support unaligned cases
         EP_HOST_ASSERT(hidden % 512 == 0);
@@ -1182,6 +1183,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
         packed_recv_x_scales = packed_recv_x_scales.value().permute({3, 4, 1, 5, 2, 0});
 
         packed_recv_x_scales_ptr = packed_recv_x_scales->data_ptr();
+        EP_HOST_ASSERT(packed_recv_x_scales_ptr != nullptr);
     }
 
     // Kernel launch
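The two new host-side assertions encode a caller contract: a dispatch may request at most one low-precision format, and the NVFP4 path must have a valid scales output pointer. A minimal Python-level sketch of that contract (the helper name below is hypothetical and not part of DeepEP; the keyword names follow deep_ep/buffer.py):

# Hypothetical helper mirroring the new EP_HOST_ASSERT checks; not part of DeepEP.
def validate_dispatch_flags(use_fp8: bool, use_nvfp4: bool, x_sf_scale=None) -> None:
    # FP8 and NVFP4 quantization are mutually exclusive per dispatch call.
    assert not (use_fp8 and use_nvfp4), "use_fp8 and use_nvfp4 cannot both be set"
    # The NVFP4 path also requires a global scale (see internode_ll.cu below).
    if use_nvfp4:
        assert x_sf_scale is not None, "NVFP4 dispatch needs x_sf_scale"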

csrc/kernels/internode_ll.cu

Lines changed: 3 additions & 4 deletions
@@ -269,10 +269,9 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
     thread_id == 0 ? (*rdma_x_src_idx = token_idx) : 0;
     float SFScaleVal = 1.0f;
     if constexpr (kUseNVFP4) {
-        // Get scaling value: if x_sf_scale is nullptr, use 1.0f;
-        if (x_sf_scale != nullptr) {
-            SFScaleVal = *(static_cast<const float*>(x_sf_scale));
-        }
+        // Get scaling value;
+        EP_DEVICE_ASSERT(x_sf_scale != nullptr);
+        SFScaleVal = *(static_cast<const float*>(x_sf_scale));
     }
 
     // FP8 or NVFP4 cast

deep_ep/buffer.py

Lines changed: 5 additions & 1 deletion
@@ -573,7 +573,11 @@ def low_latency_dispatch(self, x: torch.Tensor, topk_idx: torch.Tensor,
             if `use_ue8m0=False`. With `use_ue8m0=True`, the second one is packed and shaped as
             `[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden // 512]` with type `torch.int`.
             Notice that, the last-two-dimension of the scaling tensors are in column-major for TMA compatibility.
-            With `use_fp8=False`, the result would be a tensor shaped as
+            With `use_nvfp4=True`: the first element is a `torch.Tensor` shaped as
+            `[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden // 4]` with `torch.uint32`.
+            The second tensor is the corresponding scales for the first element with shape
+            `[32, 4, num_max_dispatch_tokens_per_rank * num_ranks // 128, 4, hidden // 64, num_local_experts]` with `torch.uint8`.
+            With `use_fp8=False and use_nvfp4=False`, the result would be a tensor shaped as
             `[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.bfloat16`.
             Moreover, not all tokens are valid, only some of the `num_max_dispatch_tokens_per_rank * num_ranks` are,
             as we do not synchronize CPU received count with GPU (also not incompatible with CUDA graph if synced).
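As a quick sanity check of the NVFP4 shapes documented above, the sketch below evaluates both return shapes for illustrative sizes; the concrete values (hidden=7168, 128 tokens per rank, 8 ranks, 16 local experts) are assumptions for the example, not values taken from this commit:

# Worked example of the NVFP4 return shapes from the docstring; sizes are illustrative.
hidden = 7168
num_max_dispatch_tokens_per_rank, num_ranks, num_local_experts = 128, 8, 16
num_tokens = num_max_dispatch_tokens_per_rank * num_ranks  # 1024

packed_x_shape = (num_local_experts, num_tokens, hidden // 4)
# -> (16, 1024, 1792), dtype torch.uint32
scales_shape = (32, 4, num_tokens // 128, 4, hidden // 64, num_local_experts)
# -> (32, 4, 8, 4, 112, 16), dtype torch.uint8
print(packed_x_shape, scales_shape)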

tests/test_low_latency.py

Lines changed: 4 additions & 2 deletions
@@ -61,8 +61,8 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
     for i in range((num_times % 2) + 1):
         cumulative_local_expert_recv_stats = torch.zeros((num_local_experts, ), dtype=torch.int, device='cuda')
         x_max = torch.max(torch.abs(current_x))
-        dist.all_reduce(x_max, op=dist.ReduceOp.MAX, group=group)
         x_sf_scale = (MAX_E4M3 * MAX_NVFP4) / x_max.to(torch.float32)
+        dist.all_reduce(x_sf_scale, op=dist.ReduceOp.MIN, group=group)
         packed_recv_x, packed_recv_count, handle, event, hook = \
             buffer.low_latency_dispatch(current_x, topk_idx, num_tokens, num_experts,
                                         use_fp8=dispatch_use_fp8, round_scale=round_scale, use_ue8m0=use_ue8m0,
@@ -153,10 +153,12 @@ def large_gemm_with_hook(hook):
 
         # noinspection PyShadowingNames
        def test_func(return_recv_hook: bool):
+            # NOTE: use nvfp4
            recv_x, recv_count, handle, event, hook = \
                buffer.low_latency_dispatch(current_x, topk_idx, num_tokens, num_experts,
                                            cumulative_local_expert_recv_stats=cumulative_local_expert_recv_stats,
-                                            use_fp8=True, async_finish=False, return_recv_hook=return_recv_hook)
+                                            use_fp8=False, use_nvfp4=True, x_sf_scale=x_sf_scale,
+                                            async_finish=False, return_recv_hook=return_recv_hook)
            large_gemm_with_hook(hook) if return_recv_hook else None
            combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, topk_idx, topk_weights, handle,
                                                                 use_logfmt=use_logfmt, return_recv_hook=return_recv_hook)
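The test now derives the NVFP4 global scale on each rank first and all-reduces the scale with MIN, instead of all-reducing the amax with MAX; since (MAX_E4M3 * MAX_NVFP4) / x_max decreases as x_max grows, both orders produce the same global scale. A small worked check below; the constants 448 and 6 are the usual FP8-e4m3 and FP4-e2m1 maxima and are assumed here to match the test's MAX_E4M3 / MAX_NVFP4, and the per-rank amax values are hypothetical:

# Equivalence check for the reordered reduction in the test (illustrative values).
MAX_E4M3, MAX_NVFP4 = 448.0, 6.0   # assumed values of the test's constants
local_x_max = [2.0, 8.0, 4.0]      # hypothetical per-rank amax of |x|

# Old order: all-reduce the amax with MAX, then compute the scale.
scale_old = (MAX_E4M3 * MAX_NVFP4) / max(local_x_max)             # 2688 / 8 = 336.0
# New order: compute per-rank scales, then all-reduce them with MIN.
scale_new = min((MAX_E4M3 * MAX_NVFP4) / m for m in local_x_max)  # min(1344, 336, 672) = 336.0
assert scale_old == scale_new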
