@@ -57,7 +57,7 @@ inline __device__ float reciprocal_approximate_ftz(float a) {
5757 return b;
5858}
5959
60- // float to e2m1 4bit ( sign:1, exp:2, mantissa:1) quantization
60+ // Convert 1 float value into 8 e2m1 values (4bit, sign:1, exp:2, mantissa:1) quantization.
6161__device__ inline uint8_t float_to_e2m1 (float f) {
6262 // Get sign
6363 uint8_t sign = (f < 0 );
@@ -92,8 +92,9 @@ inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) {
9292 return val;
9393 #else
9494 #if !(defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000))
95- #pragma message("warning: this architecture does not support cvt.rn.satfinite.e2m1x2.f32, use float_to_e2m1 instead.")
96- #endif
95+ #pragma message("warning: this architecture does not support " \
96+ " cvt.rn.satfinite.e2m1x2.f32, use user defined " \
97+ " float_to_e2m1 to convert float values to e2m1 values." )
9798 uint32_t val = 0 ;
9899 float * data = reinterpret_cast <float *>(&array[0 ]);
99100 for (int i = 0 ; i < 8 ; ++i) {
@@ -125,7 +126,9 @@ inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) {
125126 return val;
126127 #else
127128 #if !(defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000))
128- #pragma message("warning: this architecture does not support cvt.rn.satfinite.e2m1x2.f32, use float_to_e2m1 instead.")
129+ #pragma message("warning: this architecture does not support " \
130+ " cvt.rn.satfinite.e2m1x2.f32, use user defined " \
131+ " float_to_e2m1 to convert float values to e2m1 values." )
129132 #endif
130133 uint32_t val = 0 ;
131134 float * data = reinterpret_cast <float *>(&array[0 ]);
@@ -542,9 +545,8 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales, void* packed_recv_x_sf
542545 auto scale = extract_required_scale_format<kUseUE8M0 >(ld_nc_global (src_scales + lane_id + 32 ));
543546 recv_x_scales[token_idx * token_stride + pack_idx * pack_stride + elem_idx] = scale;
544547 }
545- } else if constexpr (kUseNVFP4 ) {
546- // Equivalent CuTe layout:
547- // (num_tokens, (num_packed, num_elems_per_pack)):(num_elems_per_pack, (num_tokens * num_elems_per_pack, 1))
548+ } else if constexpr (kUseNVFP4 ) {
549+ // The physical layout is (l, rm, rk, 32, 4, 4).
548550 const auto src_scales = reinterpret_cast <uint8_t *>(reinterpret_cast <uint8_t *>(src_data) + hidden_bytes);
549551 const auto num_elems_per_pack = static_cast <int >(sizeof (packed_t ) / sizeof (scale_t ));
550552 const auto token_idx = recv_token_begin_idx + i;
@@ -557,7 +559,6 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales, void* packed_recv_x_sf
557559 const auto pack_idx = j / num_elems_per_pack;
558560 const auto elem_idx = j % num_elems_per_pack;
559561 auto scale = ld_nc_global (src_scales + j);
560- // recv_x_scales[token_idx * token_stride + pack_idx * pack_stride + elem_idx] = scale;
561562 recv_x_scales[rm * token_stride * 128 + pack_idx * pack_stride * 128 + rm_res * pack_stride + elem_idx] = scale;
562563 }
563564 }
0 commit comments