@@ -57,7 +57,7 @@ inline __device__ float reciprocal_approximate_ftz(float a) {
57
57
return b;
58
58
}
59
59
60
- // float to e2m1 4bit ( sign:1, exp:2, mantissa:1) quantization
60
+ // Convert 1 float value into 8 e2m1 values (4bit, sign:1, exp:2, mantissa:1) quantization.
61
61
__device__ inline uint8_t float_to_e2m1 (float f) {
62
62
// Get sign
63
63
uint8_t sign = (f < 0 );
@@ -92,8 +92,9 @@ inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) {
92
92
return val;
93
93
#else
94
94
#if !(defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000))
95
- #pragma message("warning: this architecture does not support cvt.rn.satfinite.e2m1x2.f32, use float_to_e2m1 instead.")
96
- #endif
95
+ #pragma message("warning: this architecture does not support " \
96
+ " cvt.rn.satfinite.e2m1x2.f32, use user defined " \
97
+ " float_to_e2m1 to convert float values to e2m1 values." )
97
98
uint32_t val = 0 ;
98
99
float * data = reinterpret_cast <float *>(&array[0 ]);
99
100
for (int i = 0 ; i < 8 ; ++i) {
@@ -125,7 +126,9 @@ inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) {
125
126
return val;
126
127
#else
127
128
#if !(defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000))
128
- #pragma message("warning: this architecture does not support cvt.rn.satfinite.e2m1x2.f32, use float_to_e2m1 instead.")
129
+ #pragma message("warning: this architecture does not support " \
130
+ " cvt.rn.satfinite.e2m1x2.f32, use user defined " \
131
+ " float_to_e2m1 to convert float values to e2m1 values." )
129
132
#endif
130
133
uint32_t val = 0 ;
131
134
float * data = reinterpret_cast <float *>(&array[0 ]);
@@ -542,9 +545,8 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales, void* packed_recv_x_sf
542
545
auto scale = extract_required_scale_format<kUseUE8M0 >(ld_nc_global (src_scales + lane_id + 32 ));
543
546
recv_x_scales[token_idx * token_stride + pack_idx * pack_stride + elem_idx] = scale;
544
547
}
545
- } else if constexpr (kUseNVFP4 ) {
546
- // Equivalent CuTe layout:
547
- // (num_tokens, (num_packed, num_elems_per_pack)):(num_elems_per_pack, (num_tokens * num_elems_per_pack, 1))
548
+ } else if constexpr (kUseNVFP4 ) {
549
+ // The physical layout is (l, rm, rk, 32, 4, 4).
548
550
const auto src_scales = reinterpret_cast <uint8_t *>(reinterpret_cast <uint8_t *>(src_data) + hidden_bytes);
549
551
const auto num_elems_per_pack = static_cast <int >(sizeof (packed_t ) / sizeof (scale_t ));
550
552
const auto token_idx = recv_token_begin_idx + i;
@@ -557,7 +559,6 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales, void* packed_recv_x_sf
557
559
const auto pack_idx = j / num_elems_per_pack;
558
560
const auto elem_idx = j % num_elems_per_pack;
559
561
auto scale = ld_nc_global (src_scales + j);
560
- // recv_x_scales[token_idx * token_stride + pack_idx * pack_stride + elem_idx] = scale;
561
562
recv_x_scales[rm * token_stride * 128 + pack_idx * pack_stride * 128 + rm_res * pack_stride + elem_idx] = scale;
562
563
}
563
564
}
0 commit comments