     NVFP4MMConfig,
 )
 from torchao.prototype.mx_formats.nvfp4_tensor import (
+    NVFP4Tensor,
     QuantizeTensorToNVFP4Kwargs,
+    per_tensor_amax_to_scale,
+    unpack_uint4,
 )
 from torchao.quantization.utils import compute_error
 from torchao.testing.utils import skip_if_rocm
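
Note: with the imports hoisted to module scope, the helpers are available to every test directly. A minimal round-trip sketch using the hoisted names (illustration only, not part of the diff; the per_tensor_scale keyword and the to_dtype method are assumptions inferred from the tests below, not confirmed by this hunk):

    import torch
    from torchao.prototype.mx_formats.nvfp4_tensor import (
        NVFP4Tensor,
        per_tensor_amax_to_scale,
    )
    from torchao.quantization.utils import compute_error

    x = torch.randn(128, 64, dtype=torch.bfloat16, device="cuda")
    # Derive an fp32 per-tensor scale from the tensor's absolute maximum.
    scale = per_tensor_amax_to_scale(torch.max(torch.abs(x)))
    x_nvfp4 = NVFP4Tensor.to_nvfp4(x, per_tensor_scale=scale)
    # Dequantize and report the reconstruction SQNR in dB.
    print(compute_error(x, x_nvfp4.to_dtype(torch.bfloat16)))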
@@ -45,11 +48,6 @@
     not torch_version_at_least("2.8.0"), reason="torch.compile requires PyTorch 2.8+"
 )
 def test_nvfp4_reconstruction(dtype, shape, use_per_tensor_scale):
-    from torchao.prototype.mx_formats.nvfp4_tensor import (
-        NVFP4Tensor,
-        per_tensor_amax_to_scale,
-    )
-
     x = torch.randn(shape, dtype=dtype, device="cuda")
     if use_per_tensor_scale:
         tensor_amax = torch.max(torch.abs(x))
@@ -115,7 +113,6 @@ def test_nvfp4_swizzled_scales_construction(is_swizzled_scales, shape):
     Test that NVFP4Tensor can be constructed with swizzled scales and
     that the _is_swizzled_scales flag is set correctly.
     """
-    from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor
 
     M, K = shape
     data = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
@@ -153,7 +150,6 @@ def test_nvfp4_swizzled_scales_slicing(slice_dim, slice_spec):
     Test that slicing works correctly with swizzled scales and maintains
     the swizzled state in the output tensor.
     """
-    from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor
 
     # Use larger tensor sizes that align with swizzled requirements
     if slice_dim == 0:
@@ -247,7 +243,6 @@ def test_nvfp4_swizzled_scales_slicing_errors(slice_dim, slice_spec, expected_er
     """
     Test that slicing raises appropriate errors for misaligned boundaries.
     """
-    from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor
 
     M, K = 256, 4096
     data = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
@@ -268,7 +263,6 @@ def test_nvfp4_swizzled_scales_view_semantics():
     """
     Test that slicing maintains proper view semantics where possible.
     """
-    from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor
 
     M, K = 256, 4096
     data = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
@@ -295,7 +289,6 @@ def test_nvfp4_swizzled_scales_serialization():
     """
     Test that tensor flatten/unflatten preserves the swizzled scales state.
     """
-    from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor
 
     M, K = 32, 64
     data = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
@@ -337,7 +330,6 @@ def test_nvfp4_swizzled_scales_get_scales_method():
     """
     Test that the get_scales() method correctly unswizzles scales when needed.
    """
-    from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor
 
     M, K = 32, 64
     data = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
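
Note: a short sketch of the swizzled-scales path these tests exercise (illustration only, not part of the diff; the is_swizzled_scales keyword is an assumption based on the _is_swizzled_scales flag checked above):

    import torch
    from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor

    data = torch.randn(32, 64, device="cuda", dtype=torch.bfloat16)
    t = NVFP4Tensor.to_nvfp4(data, is_swizzled_scales=True)
    assert t._is_swizzled_scales
    # get_scales() should return scales in row-major (unswizzled) layout
    # regardless of how they are stored internally.
    scales = t.get_scales()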
@@ -372,11 +364,6 @@ def test_nvfp4_swizzled_scales_get_scales_method():
 @torch.no_grad()
 def test_triton_nvfp4_quantize_equivalence(M, N, use_per_tensor_scale, dtype):
     """Test that Triton and PyTorch NVFP4 quantization produce equivalent results."""
-    from torchao.prototype.mx_formats.nvfp4_tensor import (
-        NVFP4Tensor,
-        per_tensor_amax_to_scale,
-        unpack_uint4,
-    )
 
     torch.manual_seed(42)
     x = torch.randn(M, N, dtype=dtype, device="cuda")
@@ -462,11 +449,6 @@ def test_nvfp4_matmul_with_amax(
     use_triton_kernel: bool,
     shapes: tuple,
 ):
-    from torchao.prototype.mx_formats.nvfp4_tensor import (
-        NVFP4Tensor,
-        per_tensor_amax_to_scale,
-    )
-
     # DYNAMIC mode requires SM100+, but WEIGHT_ONLY works on older GPUs
     if mm_config == NVFP4MMConfig.DYNAMIC and not is_sm_at_least_100():
         pytest.skip("CUDA capability >= 10.0 required for DYNAMIC float4 gemm")
@@ -530,8 +512,6 @@ def test_nvfp4_matmul_with_amax(
     not torch_version_at_least("2.8.0"), reason="NVFP4 requires PyTorch 2.8+"
 )
 def test_nvfp4_to_copy():
-    from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor
-
     x = NVFP4Tensor.to_nvfp4(torch.randn((32, 128))).cuda()
     y = torch.ops.aten._to_copy(x, dtype=torch.bfloat16)
     assert torch.equal(x.qdata, y.qdata)
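
Note: the assertion above pins down the semantics of aten._to_copy on NVFP4Tensor: the packed 4-bit payload (qdata) is carried over untouched, with no requantization. A usage sketch (the y.dtype check is an assumption consistent with that behavior, not asserted in this diff):

    x = NVFP4Tensor.to_nvfp4(torch.randn((32, 128))).cuda()
    y = torch.ops.aten._to_copy(x, dtype=torch.bfloat16)
    assert torch.equal(x.qdata, y.qdata)  # same packed payload
    assert y.dtype == torch.bfloat16      # only the advertised dtype changed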