
Commit c358fd5 (parent: 2e6ef2b)
Commit message: code clean

File tree: 3 files changed, +14 -20 lines


csrc/deep_ep.cpp

Lines changed: 0 additions & 12 deletions

@@ -1224,18 +1224,6 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
     std::optional<std::function<void()>> recv_hook = std::nullopt;
     if (return_recv_hook)
         recv_hook = [=]() { launcher(LOW_LATENCY_RECV_PHASE); };
-    // if (not use_fp8 and use_nvfp4) {
-    //     constexpr int kNumPerChannels = 16;
-    //     constexpr int NUM_SF_ELEMS_PER_PACK = 4;
-    //     constexpr int mTileSize_dim_0 = 32;
-    //     constexpr int mTileSize_dim_1 = 4;
-    //     constexpr int mTileSize = mTileSize_dim_0 * mTileSize_dim_1;
-
-    //     auto l = num_local_experts;
-    //     auto m = num_ranks * num_max_dispatch_tokens_per_rank;
-    //     auto rk = hidden / (kNumPerChannels * NUM_SF_ELEMS_PER_PACK);
-    //     packed_recv_x_scales = packed_recv_x_scales.value().contiguous().view(torch::kInt).reshape({l, m, rk});
-    // }
     // Return values
     return {packed_recv_x, packed_recv_x_scales, packed_recv_x_sf_scale, packed_recv_count, packed_recv_src_info, packed_recv_layout_range, event, recv_hook};
 #else
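
The block removed above had only ever been a comment; it sketched how the received NVFP4 scales could be reshaped on the host into an int32 tensor of shape {l, m, rk}. That reshape now happens in tests/test_low_latency.py instead (see the last file below). A small Python sketch of just the shape arithmetic, with made-up sizes (not code from this commit):

# Hypothetical sizes, only to illustrate the l/m/rk arithmetic in the deleted comment.
num_local_experts = 2
num_ranks = 4
num_max_dispatch_tokens_per_rank = 128
hidden = 7168

kNumPerChannels = 16         # 16 hidden values share one FP4 scale
NUM_SF_ELEMS_PER_PACK = 4    # 4 uint8 scales viewed as one int32

l = num_local_experts
m = num_ranks * num_max_dispatch_tokens_per_rank
rk = hidden // (kNumPerChannels * NUM_SF_ELEMS_PER_PACK)
print(l, m, rk)              # 2 512 112 -> target scale shape {l, m, rk}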

csrc/kernels/internode_ll.cu

Lines changed: 9 additions & 8 deletions

@@ -57,7 +57,7 @@ inline __device__ float reciprocal_approximate_ftz(float a) {
     return b;
 }
 
-// float to e2m1 4bit (sign:1, exp:2, mantissa:1) quantization
+// Convert 1 float value into 8 e2m1 values (4bit, sign:1, exp:2, mantissa:1) quantization.
 __device__ inline uint8_t float_to_e2m1(float f) {
     // Get sign
     uint8_t sign = (f < 0);
@@ -92,8 +92,9 @@ inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) {
     return val;
 #else
 #if !(defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000))
-#pragma message("warning: this architecture does not support cvt.rn.satfinite.e2m1x2.f32, use float_to_e2m1 instead.")
-#endif
+#pragma message("warning: this architecture does not support " \
+                "cvt.rn.satfinite.e2m1x2.f32, use user defined " \
+                "float_to_e2m1 to convert float values to e2m1 values.")
     uint32_t val = 0;
     float* data = reinterpret_cast<float*>(&array[0]);
     for (int i = 0; i < 8; ++i) {
@@ -125,7 +126,9 @@ inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) {
     return val;
 #else
 #if !(defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000))
-#pragma message("warning: this architecture does not support cvt.rn.satfinite.e2m1x2.f32, use float_to_e2m1 instead.")
+#pragma message("warning: this architecture does not support " \
+                "cvt.rn.satfinite.e2m1x2.f32, use user defined " \
+                "float_to_e2m1 to convert float values to e2m1 values.")
 #endif
     uint32_t val = 0;
     float* data = reinterpret_cast<float*>(&array[0]);
@@ -542,9 +545,8 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales, void* packed_recv_x_sf
                 auto scale = extract_required_scale_format<kUseUE8M0>(ld_nc_global(src_scales + lane_id + 32));
                 recv_x_scales[token_idx * token_stride + pack_idx * pack_stride + elem_idx] = scale;
             }
-        } else if constexpr (kUseNVFP4) {
-            // Equivalent CuTe layout:
-            // (num_tokens, (num_packed, num_elems_per_pack)):(num_elems_per_pack, (num_tokens * num_elems_per_pack, 1))
+        } else if constexpr (kUseNVFP4) {
+            // The physical layout is (l, rm, rk, 32, 4, 4).
            const auto src_scales = reinterpret_cast<uint8_t*>(reinterpret_cast<uint8_t*>(src_data) + hidden_bytes);
             const auto num_elems_per_pack = static_cast<int>(sizeof(packed_t) / sizeof(scale_t));
             const auto token_idx = recv_token_begin_idx + i;
@@ -557,7 +559,6 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales, void* packed_recv_x_sf
                 const auto pack_idx = j / num_elems_per_pack;
                 const auto elem_idx = j % num_elems_per_pack;
                 auto scale = ld_nc_global(src_scales + j);
-                // recv_x_scales[token_idx * token_stride + pack_idx * pack_stride + elem_idx] = scale;
                 recv_x_scales[rm * token_stride * 128 + pack_idx * pack_stride * 128 + rm_res * pack_stride + elem_idx] = scale;
             }
         }
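
For readers unfamiliar with the format discussed above: e2m1 is a 4-bit float with 1 sign bit, 2 exponent bits and 1 mantissa bit, so the representable magnitudes are 0, 0.5, 1, 1.5, 2, 3, 4 and 6. Below is a tiny pure-Python reference sketch of that format (my own illustration, not code from this commit); tie-breaking and nibble ordering are simplified and may differ from the kernel's float_to_e2m1 and from cvt.rn.satfinite.e2m1x2.f32, which rounds to nearest-even.

E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]  # magnitudes for codes 0..7

def float_to_e2m1_ref(f: float) -> int:
    """Quantize one float to a 4-bit e2m1 code: bit 3 = sign, bits 2..0 = magnitude code."""
    sign = 1 if f < 0 else 0
    mag = min(abs(f), 6.0)  # saturate at the largest finite e2m1 magnitude
    code = min(range(8), key=lambda i: abs(E2M1_MAGNITUDES[i] - mag))  # nearest; ties pick the smaller code
    return (sign << 3) | code

def fp32_vec_to_e2m1_ref(values) -> int:
    """Pack 8 floats into one 32-bit word, 4 bits each (nibble order here is illustrative only)."""
    assert len(values) == 8
    packed = 0
    for i, v in enumerate(values):
        packed |= float_to_e2m1_ref(v) << (4 * i)
    return packed

print(float_to_e2m1_ref(2.4))   # 4  -> nearest representable magnitude is 2.0
print(float_to_e2m1_ref(-7.0))  # 15 -> sign bit set, saturated to the max magnitude 6.0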

tests/test_low_latency.py

Lines changed: 5 additions & 0 deletions

@@ -73,9 +73,14 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
     if dispatch_use_fp8:
         packed_recv_x = (packed_recv_x[0], packed_recv_x[1].contiguous())
     elif dispatch_use_nvfp4:
+        # For the received x_scale, its dtype is int8, its physical layout is (l, rm, rk, 32, 4, 4),
+        # and its logical shape is (32, 4, rm, 4, rk, l).
         recv_x_scale_view = packed_recv_x[1].clone()
+        # After permute, the logical shape will be (l, rm, 32, 4, rk, 4).
         recv_x_scale_view = recv_x_scale_view.permute(5, 2, 0, 1, 4, 3)
+        # After view, the logical shape will be (l, rm, 32, 4, rk), the dtype is int32.
         recv_x_scale_view = recv_x_scale_view.contiguous().view(torch.int32)
+        # After view, the logical shape will be (l, rm * 32 * 4, rk), the dtype is int32.
         recv_x_scale_view = recv_x_scale_view.contiguous().view(num_local_experts, int(num_ranks * num_tokens), hidden // (16 * 4))
         packed_recv_x = (packed_recv_x[0], recv_x_scale_view, packed_recv_x[2].contiguous())
     else:
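
To follow the shape bookkeeping above end to end, here is a small standalone sketch that only mirrors the permute/view sequence added in the test, with made-up sizes (per the comments, rm = num_ranks * num_tokens / 128 and rk = hidden / 64):

import torch

# Hypothetical sizes for illustration: num_ranks * num_tokens = rm * 128, hidden = rk * 64.
l, rm, rk = 2, 3, 14
num_local_experts = l
num_tokens_total = rm * 128          # stands in for num_ranks * num_tokens
hidden = rk * 64

# Logical shape of packed_recv_x[1] per the comments above: (32, 4, rm, 4, rk, l), dtype int8.
scales = torch.zeros((32, 4, rm, 4, rk, l), dtype=torch.int8)

view = scales.permute(5, 2, 0, 1, 4, 3)        # -> (l, rm, 32, 4, rk, 4)
view = view.contiguous().view(torch.int32)     # 4 x int8 -> 1 x int32 along the last dim: (l, rm, 32, 4, rk, 1)
view = view.contiguous().view(num_local_experts, num_tokens_total, hidden // (16 * 4))
print(view.shape)                              # torch.Size([2, 384, 14]) == (l, rm * 32 * 4, rk)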
