Commit 75944ab

nvfp4tensor: move imports to top of file (#3091)
1 parent 7d6962b commit 75944ab
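
The change is mechanical: every test in this file previously re-imported NVFP4Tensor and its helpers inside the test body, and this commit hoists them into the existing module-level import. A minimal before/after sketch of the pattern (abbreviated; the actual changes are in the hunks below):

# Before: each test carried its own function-local import
def test_nvfp4_reconstruction(dtype, shape, use_per_tensor_scale):
    from torchao.prototype.mx_formats.nvfp4_tensor import (
        NVFP4Tensor,
        per_tensor_amax_to_scale,
    )
    ...  # test body

# After: one module-level import shared by all tests
from torchao.prototype.mx_formats.nvfp4_tensor import (
    NVFP4Tensor,
    QuantizeTensorToNVFP4Kwargs,
    per_tensor_amax_to_scale,
    unpack_uint4,
)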

File tree

1 file changed: +3, -23 lines

test/prototype/mx_formats/test_nvfp4_tensor.py

Lines changed: 3 additions & 23 deletions
@@ -16,7 +16,10 @@
     NVFP4MMConfig,
 )
 from torchao.prototype.mx_formats.nvfp4_tensor import (
+    NVFP4Tensor,
     QuantizeTensorToNVFP4Kwargs,
+    per_tensor_amax_to_scale,
+    unpack_uint4,
 )
 from torchao.quantization.utils import compute_error
 from torchao.testing.utils import skip_if_rocm
@@ -45,11 +48,6 @@
     not torch_version_at_least("2.8.0"), reason="torch.compile requires PyTorch 2.8+"
 )
 def test_nvfp4_reconstruction(dtype, shape, use_per_tensor_scale):
-    from torchao.prototype.mx_formats.nvfp4_tensor import (
-        NVFP4Tensor,
-        per_tensor_amax_to_scale,
-    )
-
     x = torch.randn(shape, dtype=dtype, device="cuda")
     if use_per_tensor_scale:
         tensor_amax = torch.max(torch.abs(x))
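
For readers unfamiliar with the helpers being hoisted: per_tensor_amax_to_scale turns an absolute maximum into a quantization scale, and this test round-trips through NVFP4 and checks reconstruction error. A minimal sketch, assuming to_nvfp4 accepts a per_tensor_scale keyword and that dequantization goes through aten._to_copy (only the amax computation is visible in this hunk):

import torch
from torchao.prototype.mx_formats.nvfp4_tensor import (
    NVFP4Tensor,
    per_tensor_amax_to_scale,
)
from torchao.quantization.utils import compute_error

x = torch.randn(128, 64, dtype=torch.bfloat16, device="cuda")

# Derive a per-tensor scale from the tensor's absolute maximum, as the test does.
tensor_amax = torch.max(torch.abs(x))
scale = per_tensor_amax_to_scale(tensor_amax)

# Quantize to NVFP4, then dequantize back to bfloat16.
# per_tensor_scale= is an assumption; this hunk does not show the call.
x_q = NVFP4Tensor.to_nvfp4(x, per_tensor_scale=scale)
x_dq = torch.ops.aten._to_copy(x_q, dtype=torch.bfloat16)

# compute_error reports an SQNR-style figure; higher means closer reconstruction.
print(compute_error(x, x_dq))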
@@ -115,7 +113,6 @@ def test_nvfp4_swizzled_scales_construction(is_swizzled_scales, shape):
     Test that NVFP4Tensor can be constructed with swizzled scales and
     that the _is_swizzled_scales flag is set correctly.
     """
-    from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor
 
     M, K = shape
     data = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
@@ -153,7 +150,6 @@ def test_nvfp4_swizzled_scales_slicing(slice_dim, slice_spec):
     Test that slicing works correctly with swizzled scales and maintains
     the swizzled state in the output tensor.
     """
-    from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor
 
     # Use larger tensor sizes that align with swizzled requirements
     if slice_dim == 0:
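
This test pairs with the error test below: slices landing on swizzle-tile boundaries succeed and keep _is_swizzled_scales set, while misaligned ones raise. A sketch under assumptions (the is_swizzled_scales keyword and the 128-row boundary are inferred from the test names and sizes, not shown in this hunk):

import torch
from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor

# Sizes mirror the test's aligned shapes.
data = torch.randn(256, 4096, device="cuda", dtype=torch.bfloat16)
t = NVFP4Tensor.to_nvfp4(data, is_swizzled_scales=True)  # keyword assumed

t_half = t[0:128]  # aligned slice: expected to succeed
assert t_half._is_swizzled_scales

# Something like t[0:100] would hit a misaligned boundary and is
# expected to raise, per test_nvfp4_swizzled_scales_slicing_errors.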
@@ -247,7 +243,6 @@ def test_nvfp4_swizzled_scales_slicing_errors(slice_dim, slice_spec, expected_er
     """
     Test that slicing raises appropriate errors for misaligned boundaries.
     """
-    from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor
 
     M, K = 256, 4096
     data = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
@@ -268,7 +263,6 @@ def test_nvfp4_swizzled_scales_view_semantics():
     """
     Test that slicing maintains proper view semantics where possible.
     """
-    from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor
 
     M, K = 256, 4096
     data = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
@@ -295,7 +289,6 @@ def test_nvfp4_swizzled_scales_serialization():
     """
     Test that tensor flatten/unflatten preserves the swizzled scales state.
     """
-    from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor
 
     M, K = 32, 64
     data = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
@@ -337,7 +330,6 @@ def test_nvfp4_swizzled_scales_get_scales_method():
     """
     Test that the get_scales() method correctly unswizzles scales when needed.
     """
-    from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor
 
     M, K = 32, 64
     data = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
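
Per the docstring above, get_scales() unswizzles on demand so callers always see linear-layout scales regardless of internal storage. A short sketch (the returned shape is an assumption; only the method name and input sizes appear in the hunk):

import torch
from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor

M, K = 32, 64
data = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
t = NVFP4Tensor.to_nvfp4(data)

# Regardless of (possibly swizzled) internal storage, get_scales()
# should return scales in linear layout, one per quantization block.
scales = t.get_scales()
print(scales.shape)  # assumed M x (K // 16), with NVFP4's 16-element blocks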
@@ -372,11 +364,6 @@ def test_nvfp4_swizzled_scales_get_scales_method():
 @torch.no_grad()
 def test_triton_nvfp4_quantize_equivalence(M, N, use_per_tensor_scale, dtype):
     """Test that Triton and PyTorch NVFP4 quantization produce equivalent results."""
-    from torchao.prototype.mx_formats.nvfp4_tensor import (
-        NVFP4Tensor,
-        per_tensor_amax_to_scale,
-        unpack_uint4,
-    )
 
     torch.manual_seed(42)
     x = torch.randn(M, N, dtype=dtype, device="cuda")
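
unpack_uint4 is imported here because NVFP4 packs two 4-bit codes per uint8, so comparing Triton and PyTorch outputs element-wise requires unpacking first. A sketch assuming to_nvfp4 exposes a use_triton_kernel toggle (the flag appears as a test parameter elsewhere in this file, but the call itself is not in this hunk):

import torch
from torchao.prototype.mx_formats.nvfp4_tensor import (
    NVFP4Tensor,
    unpack_uint4,
)

torch.manual_seed(42)
x = torch.randn(128, 64, dtype=torch.bfloat16, device="cuda")

# Quantize via both code paths; the keyword is an assumption.
ref = NVFP4Tensor.to_nvfp4(x)
tri = NVFP4Tensor.to_nvfp4(x, use_triton_kernel=True)

# Unpack the two-codes-per-byte payload before comparing element-wise.
assert torch.equal(unpack_uint4(ref.qdata), unpack_uint4(tri.qdata))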
@@ -462,11 +449,6 @@ def test_nvfp4_matmul_with_amax(
     use_triton_kernel: bool,
     shapes: tuple,
 ):
-    from torchao.prototype.mx_formats.nvfp4_tensor import (
-        NVFP4Tensor,
-        per_tensor_amax_to_scale,
-    )
-
     # DYNAMIC mode requires SM100+, but WEIGHT_ONLY works on older GPUs
     if mm_config == NVFP4MMConfig.DYNAMIC and not is_sm_at_least_100():
         pytest.skip("CUDA capability >= 10.0 required for DYNAMIC float4 gemm")
@@ -530,8 +512,6 @@ def test_nvfp4_matmul_with_amax(
     not torch_version_at_least("2.8.0"), reason="NVFP4 requires PyTorch 2.8+"
 )
 def test_nvfp4_to_copy():
-    from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor
-
     x = NVFP4Tensor.to_nvfp4(torch.randn((32, 128))).cuda()
     y = torch.ops.aten._to_copy(x, dtype=torch.bfloat16)
     assert torch.equal(x.qdata, y.qdata)
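
The final hunk doubles as a usage note: aten._to_copy on an NVFP4Tensor changes only the advertised dtype and leaves the packed payload (qdata) untouched, which the assertion pins down. The same check as a standalone snippet (the y.dtype assertion is an extra assumption, not in the test):

import torch
from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor

x = NVFP4Tensor.to_nvfp4(torch.randn((32, 128))).cuda()
y = torch.ops.aten._to_copy(x, dtype=torch.bfloat16)

# Payload untouched; only the dtype metadata should differ.
assert torch.equal(x.qdata, y.qdata)
assert y.dtype == torch.bfloat16  # assumed: the wrapper reports the new dtype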
