87 changes: 84 additions & 3 deletions test/prototype/mx_formats/test_nvfp4_tensor.py
@@ -21,6 +21,7 @@
per_tensor_amax_to_scale,
unpack_uint4,
)
from torchao.prototype.mx_formats.utils import ceil_div
from torchao.quantization.utils import compute_error
from torchao.testing.utils import skip_if_rocm
from torchao.utils import (
@@ -127,8 +128,8 @@ def test_nvfp4_swizzled_scales_construction(is_swizzled_scales, shape):
"slice_dim,slice_spec",
[
# Row slicing - must align with 128-row boundaries
pytest.param(0, slice(0, 128), id="slice_rows[0:128]"),
pytest.param(0, slice(128, 256), id="slice_rows[128:256]"),
# pytest.param(0, slice(0, 128), id="slice_rows[0:128]"),
# pytest.param(0, slice(128, 256), id="slice_rows[128:256]"),
# Column slicing - must align with 64-column boundaries (4 scale columns * 16 block_size)
pytest.param(1, slice(0, 64), id="slice_cols[0:64]"),
pytest.param(1, slice(64, 128), id="slice_cols[64:128]"),
@@ -162,12 +163,29 @@ def test_nvfp4_swizzled_scales_slicing(slice_dim, slice_spec):
data = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)

tensor = NVFP4Tensor.to_nvfp4(data, is_swizzled_scales=True)
# tensor.to_dtype(torch.bfloat16)
assert tensor._is_swizzled_scales == True

print(
"before",
tensor.shape,
tensor.qdata.shape,
tensor._scale_e4m3.shape,
tensor._scale_e4m3.is_contiguous(),
)
print(tensor._scale_e4m3[0:128, 0:8])
if slice_dim == 0:
sliced_tensor = tensor[slice_spec, :]
else:
sliced_tensor = tensor[:, slice_spec]
print(
"after",
sliced_tensor.shape,
sliced_tensor.qdata.shape,
sliced_tensor._scale_e4m3.shape,
tensor._scale_e4m3.is_contiguous(),
)
print(tensor._scale_e4m3[0:128, 0:8])

# Verify sliced tensor maintains swizzled state
assert sliced_tensor._is_swizzled_scales == True
@@ -421,7 +439,8 @@ def test_triton_nvfp4_quantize_equivalence(M, N, use_per_tensor_scale, dtype):
@pytest.mark.parametrize("compile", [False])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("inpt_dtype", [torch.bfloat16, torch.float32])
@pytest.mark.parametrize("use_triton_kernel", [True, False])
# @pytest.mark.parametrize("use_triton_kernel", [True, False])
@pytest.mark.parametrize("use_triton_kernel", [False])
@pytest.mark.parametrize(
"shapes",
[
@@ -525,3 +544,65 @@ def test_nvfp4_to_copy():
assert x.act_quant_kwargs == y.act_quant_kwargs
assert x.dtype == torch.float32
assert y.dtype == torch.bfloat16


@pytest.mark.parametrize("transpose", [False, True])
# @pytest.mark.parametrize("transpose", [True])
# @pytest.mark.parametrize("transpose", [False])
@pytest.mark.parametrize("use_triton_kernel", [False, True])
# @pytest.mark.parametrize("use_triton_kernel", [False])
# @pytest.mark.parametrize("use_triton_kernel", [True])
@pytest.mark.parametrize("is_swizzled_scales", [False, True])
# @pytest.mark.parametrize("is_swizzled_scales", [True])
@pytest.mark.parametrize(
"mk",
(
(128, 64),
(128 + 16, 64),
(128, 64 + 16),
(128 + 16, 64 + 16),
),
)
# @pytest.mark.parametrize("mk", ((128 + 16, 64),))
def test_scale_shape_matches_qdata(
transpose, use_triton_kernel, is_swizzled_scales, mk
):
if use_triton_kernel and not is_swizzled_scales:
pytest.skip("triton kernel requires swizzled scales")

M, K = mk

block_size = 16

# TODO(this PR): test larger tensors that don't exactly map to (128, 64) tiles,
# to test the padding logic
# context: https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout
x_hp = torch.randn(M, K, device="cuda")
x = NVFP4Tensor.to_nvfp4(
x_hp, is_swizzled_scales=is_swizzled_scales, use_triton_kernel=use_triton_kernel
)

m_dim, k_dim = 0, 1
if transpose:
x_hp = x_hp.t()
x = x.t()
m_dim, k_dim = 1, 0

orig_m = x_hp.shape[m_dim]
expected_padded_m = orig_m
if is_swizzled_scales:
expected_padded_m = ceil_div(orig_m, 128) * 128
actual_padded_m = x._scale_e4m3.shape[m_dim]
assert expected_padded_m == actual_padded_m, (
f"incompatible padded shape for dim {m_dim}: {x.shape} and {x._scale_e4m3.shape}"
)

orig_k = x_hp.shape[k_dim]
expected_padded_k = orig_k // block_size
if is_swizzled_scales:
expected_padded_k = ceil_div(orig_k // block_size, 4) * 4
actual_padded_k = x._scale_e4m3.shape[k_dim]

assert expected_padded_k == actual_padded_k, (
f"incompatible padded shape for dim {k_dim}: {x.shape} and {x._scale_e4m3.shape}"
)
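
For reference, the shapes this new test asserts reduce to a small piece of arithmetic. Below is a minimal standalone sketch (not part of the PR) of that expectation, assuming the (128, 4) scale-tile padding from the cuBLAS block-scaling layout linked in the TODO:

def ceil_div(a: int, b: int) -> int:
    return -(-a // b)

def expected_scale_shape(M: int, K: int, block_size: int = 16,
                         is_swizzled_scales: bool = True) -> tuple:
    # one e4m3 scale per (1, block_size) chunk of qdata
    scale_m, scale_k = M, K // block_size
    if is_swizzled_scales:
        # swizzled scales are padded up to whole (128, 4) tiles
        scale_m = ceil_div(scale_m, 128) * 128
        scale_k = ceil_div(scale_k, 4) * 4
    return (scale_m, scale_k)

# e.g. the (128 + 16, 64 + 16) case above pads to (256, 8)
assert expected_scale_shape(128 + 16, 64 + 16) == (256, 8)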
64 changes: 41 additions & 23 deletions torchao/prototype/mx_formats/nvfp4_tensor.py
@@ -183,9 +183,17 @@ def to_nvfp4(
if is_swizzled_scales:
M, K = data_hp.shape[0], data_hp.shape[1]
scale_shape = (M, K // block_size)
blockwise_scales = to_blocked(
blockwise_scales.view(scale_shape)
).flatten()
# print(1, blockwise_scales.shape)
blockwise_scales = blockwise_scales.view(scale_shape)
# print(2, blockwise_scales.shape, blockwise_scales)
blockwise_scales = to_blocked(blockwise_scales)
# print(3, blockwise_scales.shape, blockwise_scales)

# match shape of data_hp
scale_M = ceil_div(data_hp.shape[0], 128) * 128
scale_K = ceil_div(data_hp.shape[1] // 16, 4) * 4
blockwise_scales = blockwise_scales.view(scale_M, scale_K)
# print(4, blockwise_scales.shape, blockwise_scales)

return NVFP4Tensor(
data_lp,
@@ -220,6 +228,7 @@ def to_dtype(self, target_dtype: torch.dtype) -> torch.Tensor:
data_unpacked = unpack_uint4(data.contiguous().view(torch.uint8))
data_f32 = f4_unpacked_to_f32(data_unpacked)

# next: debug scale shape here
data_f32 = data_f32.view(M, K // self._block_size, self._block_size)
scale_e4m3_reshaped = self.get_hp_scales().view(M, K // self._block_size, 1)
data_scaled = data_f32 * scale_e4m3_reshaped.to(torch.float32)
Expand All @@ -237,15 +246,17 @@ def get_hp_scales(self) -> torch.Tensor:
torch.Tensor: Scales of the NVFP4Tensor
"""
is_transposed = self.qdata.stride(0) < self.qdata.stride(1)
print("is_transposed", is_transposed)
if is_transposed:
M, K = self.shape[1], self.shape[0]
scale_e4m3 = self._scale_e4m3.t()
else:
M, K = self.shape[0], self.shape[1]
scale_e4m3 = self._scale_e4m3

if self._is_swizzled_scales:
scale_e4m3 = from_blocked(self._scale_e4m3, M, K // self._block_size)
else:
scale_e4m3 = self._scale_e4m3
# import pdb; pdb.set_trace()
scale_e4m3 = from_blocked(scale_e4m3, M, K // self._block_size)

return (
scale_e4m3.to(self._orig_dtype)
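
The stride comparison above is the standard way to detect a transposed view: t() swaps strides instead of moving data. A tiny standalone illustration (plain tensors, not torchao code):

import torch

x = torch.zeros(128, 32)
assert x.stride() == (32, 1)    # row-major: stride(0) > stride(1)
x_t = x.t()
assert x_t.stride() == (1, 32)  # transposed view: stride(0) < stride(1)
assert x_t.stride(0) < x_t.stride(1)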
@@ -366,6 +377,7 @@ def nvfp4_slice(func, types, args, kwargs):
raise ValueError("Only support aten.slice with step=1")

assert x.qdata.is_contiguous(), "Only support contiguous data for now"
assert x._scale_e4m3.is_contiguous(), "Only support contiguous scale for now"

M, K = x.shape[0], x.shape[1]

@@ -407,7 +419,7 @@ def nvfp4_slice(func, types, args, kwargs):
else None
)

sliced_scale = aten.slice.Tensor(x._scale_e4m3, 0, start_idx, end_idx, 1)
sliced_scale = aten.slice.Tensor(x._scale_e4m3, 0, start, end, 1)
sliced_data = aten.slice.Tensor(x.qdata, 0, start, end, step)

elif dim == 1:
@@ -452,20 +464,24 @@ def nvfp4_slice(func, types, args, kwargs):
# Full width - no slicing needed
sliced_scale = x._scale_e4m3
else:
# Extract specific column blocks from each row block
# Each row block in swizzled format contains n_col_blocks chunks of (32, 16)
elements_per_row_block = n_col_blocks * elements_per_block

# Build list of slices to extract
slices_to_extract = []
for row_block in range(n_row_blocks):
row_start = row_block * elements_per_row_block
col_start = row_start + start_col_block * elements_per_block
col_end = row_start + end_col_block * elements_per_block
slices_to_extract.append(x._scale_e4m3[col_start:col_end])

# Concatenate all the slices
sliced_scale = torch.cat(slices_to_extract, dim=0)
if False:
# Extract specific column blocks from each row block
# Each row block in swizzled format contains n_col_blocks chunks of (32, 16)
elements_per_row_block = n_col_blocks * elements_per_block

# Build list of slices to extract
slices_to_extract = []
for row_block in range(n_row_blocks):
row_start = row_block * elements_per_row_block
col_start = row_start + start_col_block * elements_per_block
col_end = row_start + end_col_block * elements_per_block
slices_to_extract.append(x._scale_e4m3[col_start:col_end])

# Concatenate all the slices
sliced_scale = torch.cat(slices_to_extract, dim=0)
sliced_scale = aten.slice.Tensor(
x._scale_e4m3, dim, start_scale_col, end_scale_col, step
).contiguous()

# Slice the data tensor
packed_start = None if start is None else start // 2
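
The index mapping used here, sketched standalone (illustrative helpers, not PR code): two fp4 values are packed per uint8, so data column c lands at packed column c // 2, and with block_size=16 at scale column c // 16; the 64-data-column alignment keeps scale slices on whole 4-column swizzle blocks.

BLOCK_SIZE = 16

def packed_col(c: int) -> int:
    # two fp4 values per uint8 byte
    return c // 2

def scale_col(c: int) -> int:
    # one e4m3 scale per BLOCK_SIZE data columns
    return c // BLOCK_SIZE

# slicing data columns [64:128) touches packed columns [32:64)
# and scale columns [4:8), i.e. exactly one 4-column swizzle block
assert (packed_col(64), packed_col(128)) == (32, 64)
assert (scale_col(64), scale_col(128)) == (4, 8)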
@@ -537,7 +553,7 @@ def nvfp4_t(func, types, args, kwargs):
old = args[0]
new = NVFP4Tensor(
old.qdata.t(),
old._scale_e4m3,
old._scale_e4m3.t(),
old._block_size,
old._orig_dtype,
old._per_tensor_scale,
@@ -577,6 +593,8 @@ def _addmm_nvfp4_dispatch(
"""
assert a.qdata.is_contiguous()
assert b.qdata.t().is_contiguous()
assert a._scale_e4m3.is_contiguous()
assert b._scale_e4m3.t().is_contiguous()
assert a._block_size == 16, f"NVFP4 requires block_size=16, got {a._block_size}"
assert b._block_size == 16, f"NVFP4 requires block_size=16, got {b._block_size}"

@@ -615,7 +633,7 @@ def _addmm_nvfp4_dispatch(
a.qdata.view(torch.float4_e2m1fn_x2),
b.qdata.view(torch.float4_e2m1fn_x2),
a_scale_blocked.view(torch.float8_e4m3fn),
b_scale_blocked.view(torch.float8_e4m3fn),
b_scale_blocked.t().view(torch.float8_e4m3fn),
bias=None if should_add_bias_separately else bias,
out_dtype=a._orig_dtype,
# scale_result=scale_result, # Not supported yet
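
The nvfp4_t and _addmm_nvfp4_dispatch changes encode one invariant: scales travel with their data through transposes, so a linear weight's scales arrive at the matmul as a transposed view of a contiguous tensor. A standalone sketch of the contiguity relationship the new asserts check (a plain tensor standing in for _scale_e4m3; not PR code):

import torch

w_scales = torch.zeros(256, 4)   # weight's (N, padded_K // block_size) scales
b_scales = w_scales.t()          # nvfp4_t now transposes scales with qdata
assert not b_scales.is_contiguous()
assert b_scales.t().is_contiguous()  # mirrors b._scale_e4m3.t().is_contiguous()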
5 changes: 3 additions & 2 deletions torchao/prototype/mx_formats/utils.py
@@ -26,7 +26,7 @@ def ceil_div(a, b):

def to_blocked(input_matrix, use_triton_kernel: bool = False) -> Tensor:
"""
Rearrange a large matrix by breaking it into blocks and applying the rearrangement pattern.
Rearrange a large matrix by breaking it into blocks, padding, and applying the rearrangement pattern.

See:
https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout
@@ -37,7 +37,7 @@ def to_blocked(input_matrix, use_triton_kernel: bool = False) -> Tensor:
torch.compile

Returns:
Rearranged tensor of shape (32*ceil_div(H,128), 16*ceil_div(W,4))
Rearranged tensor of shape (32*ceil_div(H,128) * 16*ceil_div(W,4))
"""
if use_triton_kernel:
return triton_mx_block_rearrange(input_matrix).flatten()
@@ -83,6 +83,7 @@ def from_blocked(
"""
n_row_blocks = ceil_div(original_rows, 128)
n_col_blocks = ceil_div(original_cols, 4)
# import pdb; pdb.set_trace()

rearranged = blocked_tensor.view(n_row_blocks * n_col_blocks, 32, 16)
