1717 triton_to_mxfp8_dim1 ,
1818)
1919from torchao .prototype .mx_formats .mx_tensor import to_mx
20+ from torchao .prototype .mx_formats .nvfp4_tensor import NVFP4Tensor
2021
2122torch .manual_seed (0 )
2223
@@ -76,6 +77,18 @@ def to_mx_dim1_reference(
7677 return data_d1 .t (), scale_d1
7778
7879
80+ def to_nvfp4_reference (x_hp ):
81+ nvfp4_tensor = NVFP4Tensor .to_nvfp4 (x_hp , use_triton_kernel = False )
82+ return nvfp4_tensor .qdata , nvfp4_tensor .scale
83+
84+
85+ def to_nvfp4_reference_triton_swizzle (x_hp ):
86+ nvfp4_tensor = NVFP4Tensor .to_nvfp4 (
87+ x_hp , use_triton_kernel = True , is_swizzled_scales = True
88+ )
89+ return nvfp4_tensor .qdata , nvfp4_tensor .scale
90+
91+
7992def benchmark_cuda_function_in_microseconds (f , * args ):
8093 return do_bench (lambda : f (* args ), return_mode = "median" ) * 1e3
8194
@@ -99,6 +112,8 @@ def run(
99112 "dim0_mxfp4_floor" ,
100113 "dim0_mxfp8_rceil" ,
101114 "dim0_mxfp8_triton_floor" ,
115+ "dim0_nvfp4" ,
116+ "dim0_nvfp4_triton_swizzle" ,
102117 "dim1_mxfp8_floor" ,
103118 "dim1_mxfp8_rceil" ,
104119 "dim1_mxfp8_triton_floor" ,
@@ -240,6 +255,37 @@ def run(
240255 bytes_w = (y_d0 .numel () + s_d0 .numel ()) * bytes_per_el_fp8
241256 bps = (bytes_r + bytes_w ) / (time_us / 1e6 )
242257
258+ elif mode == "dim0_nvfp4" :
259+ to_nvfp4_reference_c = torch .compile (to_nvfp4_reference )
260+ y_d0 , s_d0 = to_nvfp4_reference_c (x , use_triton_kernel = False )
261+
262+ for _ in range (2 ):
263+ __ = to_nvfp4_reference_c (x , use_triton_kernel = False )
264+ time_us = benchmark_cuda_function_in_microseconds (
265+ lambda x : to_nvfp4_reference_c (x , use_triton_kernel = False ),
266+ x ,
267+ )
268+ assert y_d0 .dtype == torch .uint8
269+ assert s_d0 .dtype == torch .float8_e4m3fn
270+ bytes_r = x .numel () * bytes_per_el_bf16
271+ bytes_w = (y_d0 .numel () + s_d0 .numel ()) * bytes_per_el_fp8
272+ bps = (bytes_r + bytes_w ) / (time_us / 1e6 )
273+
274+ elif mode == "dim0_nvfp4_triton_swizzle" :
275+ y_d0 , s_d0 = to_nvfp4_reference_triton_swizzle (x )
276+
277+ for _ in range (2 ):
278+ __ = to_nvfp4_reference_triton_swizzle (x )
279+ time_us = benchmark_cuda_function_in_microseconds (
280+ lambda x : to_nvfp4_reference_triton_swizzle (x ),
281+ x ,
282+ )
283+ assert y_d0 .dtype == torch .uint8
284+ assert s_d0 .dtype == torch .float8_e4m3fn
285+ bytes_r = x .numel () * bytes_per_el_bf16
286+ bytes_w = (y_d0 .numel () + s_d0 .numel ()) * bytes_per_el_fp8
287+ bps = (bytes_r + bytes_w ) / (time_us / 1e6 )
288+
243289 elif mode == "dim1_mxfp8_floor" :
244290 to_mx_dim1_reference_c = torch .compile (to_mx_dim1_reference )
245291 y_d1 , s_d1 = to_mx_dim1_reference_c (x , BLOCK_SIZE )
0 commit comments