@@ -81,10 +81,11 @@ __device__ inline uint8_t float_to_e2m1(float f) {
     return (sign << 3) | (exp << 1) | mant;
 }
 
-
 // Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t).
 inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) {
-#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
+// The PTX instructions used here require sm100a.
+#if CUDA_VERSION >= 12080
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) && __CUDA_ARCH_HAS_FEATURE__(SM100_ALL)
     uint32_t val;
     asm volatile(
         "{\n"
@@ -99,13 +100,16 @@ inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) {
         "mov.b32 %0, {byte0, byte1, byte2, byte3};\n"
         "}"
         : "=r"(val)
-        : "f"(array[0].x), "f"(array[0].y), "f"(array[1].x), "f"(array[1].y), "f"(array[2].x),
-          "f"(array[2].y), "f"(array[3].x), "f"(array[3].y));
+        : "f"(array[0].x),
+          "f"(array[0].y),
+          "f"(array[1].x),
+          "f"(array[1].y),
+          "f"(array[2].x),
+          "f"(array[2].y),
+          "f"(array[3].x),
+          "f"(array[3].y));
     return val;
 #else
-#if !(defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000))
-#pragma message("warning: this architecture does not support cvt.rn.satfinite.e2m1x2.f32, use float_to_e2m1 instead.")
-#endif
     uint32_t val = 0;
     float2* data = reinterpret_cast<float2*>(&array[0]);
     for (int i = 0; i < 4; ++i) {
@@ -114,7 +118,8 @@ inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) {
     }
     return val;
 #endif
-}
+#endif
+}
 
 constexpr int CVT_ELTS_PER_THREAD = 8;
 // Quantizes the provided PackedVec into the uint32_t output
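
For reference, e2m1 is a 4-bit float: 1 sign bit, 2 exponent bits (bias 1), and 1 mantissa bit, so each nibble encodes one of +/-{0, 0.5, 1, 1.5, 2, 3, 4, 6}. A minimal decode sketch for a single nibble (a hypothetical reference helper for illustration, not part of this PR):

    #include <cstdint>
    #include <cmath>

    // Hypothetical reference decoder for one e2m1 nibble (illustration only).
    __host__ __device__ inline float e2m1_to_float(uint8_t v) {
        float sign = (v & 0x8) ? -1.0f : 1.0f;
        int exp  = (v >> 1) & 0x3;  // 2-bit exponent, bias 1
        int mant = v & 0x1;         // 1-bit mantissa
        if (exp == 0)               // subnormal: 0 or 0.5
            return sign * 0.5f * mant;
        return sign * exp2f(float(exp - 1)) * (1.0f + 0.5f * mant);
    }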
@@ -195,7 +200,7 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
          int* packed_recv_count,
          int* cumulative_local_expert_recv_stats,
          int64_t* dispatch_wait_recv_cost_stats,
-         const float* x_sf_scale,
+         const float* x_global_scales,
          void* rdma_recv_x, int* rdma_recv_count, void* rdma_x,
          const void* x, const int64_t* topk_idx,
          int* atomic_counter_per_expert, int* atomic_finish_counter_per_expert,
@@ -270,8 +275,8 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
         float SFScaleVal = 1.0f;
         if constexpr (kUseNVFP4) {
             // Get the scaling value.
-            EP_DEVICE_ASSERT(x_sf_scale != nullptr);
-            SFScaleVal = *(static_cast<const float*>(x_sf_scale));
+            EP_DEVICE_ASSERT(x_global_scales != nullptr);
+            SFScaleVal = *(static_cast<const float*>(x_global_scales));
         }
 
         // FP8 or NVFP4 cast
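
SFScaleVal here is the single global FP32 scale for NVFP4, read once from x_global_scales. Under the usual two-level NVFP4 scheme (per-16-element FP8 E4M3 block scales on top of one global scale), callers often derive it from the tensor's absolute maximum. A hedged host-side sketch of that assumed convention (not shown in this diff):

    // Assumed caller-side convention, for illustration only:
    // 448 = max normal of FP8 E4M3 (block scales), 6 = max magnitude of e2m1.
    inline float compute_nvfp4_global_scale(float amax) {
        return (448.0f * 6.0f) / amax;  // written to the x_global_scales buffer
    }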
@@ -537,14 +542,14 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
               int* packed_recv_count,
               int* cumulative_local_expert_recv_stats,
               int64_t* dispatch_wait_recv_cost_stats,
-              const float* x_sf_scale,
+              const float* x_global_scales,
               void* rdma_recv_x, int* rdma_recv_count, void* rdma_x,
               const void* x, const int64_t* topk_idx,
               int* next_clean, int num_next_clean_int,
               int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
               int num_topk, int num_experts, int rank, int num_ranks,
               bool use_fp8, bool round_scale, bool use_ue8m0,
-              bool use_nvfp4, bool use_ue8m0_for_sf,
+              bool use_nvfp4, bool use_ue8m0_for_nvfp4_x_scale,
               void* workspace, int num_device_sms,
               cudaStream_t stream, int phases) {
     constexpr int kNumMaxTopK = 9;
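
The rename from use_ue8m0_for_sf to use_ue8m0_for_nvfp4_x_scale makes explicit that the UE8M0 option applies to the NVFP4 x-scale factors. UE8M0 stores only an 8-bit biased exponent, forcing each scale to a power of two. A sketch of that encoding (hypothetical helper, assuming round-up semantics so the encoded scale never underestimates the input):

    #include <algorithm>
    #include <cmath>
    #include <cstdint>

    // Illustration only: encode a positive scale as UE8M0, value = 2^(E - 127).
    inline uint8_t float_to_ue8m0(float s) {
        int e;
        float m = std::frexp(s, &e);      // s = m * 2^e, with m in [0.5, 1)
        int k = (m == 0.5f) ? e - 1 : e;  // smallest k such that 2^k >= s
        return (uint8_t)std::clamp(k + 127, 0, 255);
    }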
@@ -572,17 +577,17 @@ if (use_fp8 and not use_ue8m0) \
     dispatch_func = dispatch<true, false, false, false, hidden>; \
 if (use_fp8 and use_ue8m0) \
     dispatch_func = dispatch<true, true, false, false, hidden>; \
-if (use_nvfp4 and not use_ue8m0_for_sf) \
+if (use_nvfp4 and not use_ue8m0_for_nvfp4_x_scale) \
     dispatch_func = dispatch<false, false, true, false, hidden>; \
-if (use_nvfp4 and use_ue8m0_for_sf) \
+if (use_nvfp4 and use_ue8m0_for_nvfp4_x_scale) \
     dispatch_func = dispatch<false, false, true, true, hidden>; \
 LAUNCH_KERNEL(&cfg, dispatch_func, \
               packed_recv_x, packed_recv_x_scales, \
               packed_recv_src_info, packed_recv_layout_range, \
               packed_recv_count, \
               cumulative_local_expert_recv_stats, \
               dispatch_wait_recv_cost_stats, \
-              x_sf_scale, \
+              x_global_scales, \
               rdma_recv_x, rdma_recv_count, rdma_x, \
               x, topk_idx, \
               atomic_counter_per_expert, atomic_finish_counter_per_expert, \
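
For context, the macro above maps runtime booleans onto compile-time template parameters, so each flag combination selects a distinct kernel instantiation. A minimal standalone sketch of the same pattern (illustrative names, not the PR's actual signatures):

    // Flag-to-template dispatch: runtime bools pick a compile-time specialization.
    template <bool kUseFP8, bool kUseNVFP4, bool kUseUE8M0ForXScale>
    __global__ void demo_kernel() { /* body specialized at compile time */ }

    using KernelFn = void (*)();

    inline KernelFn select_kernel(bool use_fp8, bool use_nvfp4, bool ue8m0) {
        if (use_fp8)             return demo_kernel<true, false, false>;
        if (use_nvfp4 && !ue8m0) return demo_kernel<false, true, false>;
        if (use_nvfp4 && ue8m0)  return demo_kernel<false, true, true>;
        return demo_kernel<false, false, false>;
    }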