Skip to content

Commit e8471c6

Browse files
authored
add nvfp4 cast benchmarks (#3188)
Update [ghstack-poisoned]
1 parent 01232b9 commit e8471c6

File tree

1 file changed

+46
-0
lines changed

1 file changed

+46
-0
lines changed

benchmarks/mx_formats/cast_bench.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
triton_to_mxfp8_dim1,
1818
)
1919
from torchao.prototype.mx_formats.mx_tensor import to_mx
20+
from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor
2021

2122
torch.manual_seed(0)
2223

@@ -76,6 +77,18 @@ def to_mx_dim1_reference(
7677
return data_d1.t(), scale_d1
7778

7879

80+
def to_nvfp4_reference(x_hp, use_triton_kernel=False):
    """Cast ``x_hp`` to NVFP4 and return ``(qdata, scale)``.

    Args:
        x_hp: high-precision input tensor (the benchmark feeds bf16).
        use_triton_kernel: forwarded to ``NVFP4Tensor.to_nvfp4``; defaults to
            False (the eager / torch.compile path). The ``dim0_nvfp4`` branch
            of ``run()`` passes this keyword explicitly, so the signature must
            accept it — the previous one-argument signature raised TypeError.

    Returns:
        Tuple of the quantized payload (``.qdata``, uint8 per the caller's
        assert) and the blockwise scales (``.scale``, float8_e4m3fn per the
        caller's assert).
    """
    nvfp4_tensor = NVFP4Tensor.to_nvfp4(x_hp, use_triton_kernel=use_triton_kernel)
    return nvfp4_tensor.qdata, nvfp4_tensor.scale
83+
84+
85+
def to_nvfp4_reference_triton_swizzle(x_hp):
    """Cast ``x_hp`` to NVFP4 using the triton kernel with swizzled scales.

    Returns the quantized payload (``.qdata``) and the scale tensor
    (``.scale``) of the resulting ``NVFP4Tensor``.
    """
    quantized = NVFP4Tensor.to_nvfp4(
        x_hp,
        use_triton_kernel=True,
        is_swizzled_scales=True,
    )
    return quantized.qdata, quantized.scale
90+
91+
7992
def benchmark_cuda_function_in_microseconds(f, *args):
8093
return do_bench(lambda: f(*args), return_mode="median") * 1e3
8194

@@ -99,6 +112,8 @@ def run(
99112
"dim0_mxfp4_floor",
100113
"dim0_mxfp8_rceil",
101114
"dim0_mxfp8_triton_floor",
115+
"dim0_nvfp4",
116+
"dim0_nvfp4_triton_swizzle",
102117
"dim1_mxfp8_floor",
103118
"dim1_mxfp8_rceil",
104119
"dim1_mxfp8_triton_floor",
@@ -240,6 +255,37 @@ def run(
240255
bytes_w = (y_d0.numel() + s_d0.numel()) * bytes_per_el_fp8
241256
bps = (bytes_r + bytes_w) / (time_us / 1e6)
242257

258+
elif mode == "dim0_nvfp4":
259+
to_nvfp4_reference_c = torch.compile(to_nvfp4_reference)
260+
y_d0, s_d0 = to_nvfp4_reference_c(x, use_triton_kernel=False)
261+
262+
for _ in range(2):
263+
__ = to_nvfp4_reference_c(x, use_triton_kernel=False)
264+
time_us = benchmark_cuda_function_in_microseconds(
265+
lambda x: to_nvfp4_reference_c(x, use_triton_kernel=False),
266+
x,
267+
)
268+
assert y_d0.dtype == torch.uint8
269+
assert s_d0.dtype == torch.float8_e4m3fn
270+
bytes_r = x.numel() * bytes_per_el_bf16
271+
bytes_w = (y_d0.numel() + s_d0.numel()) * bytes_per_el_fp8
272+
bps = (bytes_r + bytes_w) / (time_us / 1e6)
273+
274+
elif mode == "dim0_nvfp4_triton_swizzle":
275+
y_d0, s_d0 = to_nvfp4_reference_triton_swizzle(x)
276+
277+
for _ in range(2):
278+
__ = to_nvfp4_reference_triton_swizzle(x)
279+
time_us = benchmark_cuda_function_in_microseconds(
280+
lambda x: to_nvfp4_reference_triton_swizzle(x),
281+
x,
282+
)
283+
assert y_d0.dtype == torch.uint8
284+
assert s_d0.dtype == torch.float8_e4m3fn
285+
bytes_r = x.numel() * bytes_per_el_bf16
286+
bytes_w = (y_d0.numel() + s_d0.numel()) * bytes_per_el_fp8
287+
bps = (bytes_r + bytes_w) / (time_us / 1e6)
288+
243289
elif mode == "dim1_mxfp8_floor":
244290
to_mx_dim1_reference_c = torch.compile(to_mx_dim1_reference)
245291
y_d1, s_d1 = to_mx_dim1_reference_c(x, BLOCK_SIZE)

0 commit comments

Comments
 (0)