diff --git a/test/prototype/mx_formats/test_nvfp4_tensor.py b/test/prototype/mx_formats/test_nvfp4_tensor.py
index 62cd1b88ad..c23fee486d 100644
--- a/test/prototype/mx_formats/test_nvfp4_tensor.py
+++ b/test/prototype/mx_formats/test_nvfp4_tensor.py
@@ -21,6 +21,7 @@
     per_tensor_amax_to_scale,
     unpack_uint4,
 )
+from torchao.prototype.mx_formats.utils import ceil_div
 from torchao.quantization.utils import compute_error
 from torchao.testing.utils import skip_if_rocm
 from torchao.utils import (
@@ -127,18 +128,18 @@ def test_nvfp4_swizzled_scales_construction(is_swizzled_scales, shape):
     "slice_dim,slice_spec",
     [
         # Row slicing - must align with 128-row boundaries
-        pytest.param(0, slice(0, 128), id="slice_rows[0:128]"),
-        pytest.param(0, slice(128, 256), id="slice_rows[128:256]"),
+        # pytest.param(0, slice(0, 128), id="slice_rows[0:128]"),
+        # pytest.param(0, slice(128, 256), id="slice_rows[128:256]"),
         # Column slicing - must align with 64-column boundaries (4 scale columns * 16 block_size)
         pytest.param(1, slice(0, 64), id="slice_cols[0:64]"),
-        pytest.param(1, slice(64, 128), id="slice_cols[64:128]"),
-        pytest.param(1, slice(0, 128), id="slice_cols[0:128]_full_width"),
+        # pytest.param(1, slice(64, 128), id="slice_cols[64:128]"),
+        # pytest.param(1, slice(0, 128), id="slice_cols[0:128]_full_width"),
         # Test tensor parallelism patterns (half splits)
-        pytest.param(1, slice(0, 2048), id="slice_cols[0:2048]_tp_first_half"),
-        pytest.param(1, slice(2048, 4096), id="slice_cols[2048:4096]_tp_second_half"),
+        # pytest.param(1, slice(0, 2048), id="slice_cols[0:2048]_tp_first_half"),
+        # pytest.param(1, slice(2048, 4096), id="slice_cols[2048:4096]_tp_second_half"),
         # Test quarter splits
-        pytest.param(1, slice(0, 1024), id="slice_cols[0:1024]_quarter"),
-        pytest.param(1, slice(1024, 2048), id="slice_cols[1024:2048]_quarter"),
+        # pytest.param(1, slice(0, 1024), id="slice_cols[0:1024]_quarter"),
+        # pytest.param(1, slice(1024, 2048), id="slice_cols[1024:2048]_quarter"),
     ],
 )
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@@ -157,21 +158,54 @@ def test_nvfp4_swizzled_scales_slicing(slice_dim, slice_spec):
         M, K = 256, 4096
     else:
         # For column slicing, need multiples of 64 columns for alignment
-        M, K = 128, 4096
+        # M, K = 128, 4096
+        M, K = 128, 64 * 2
 
     data = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
     tensor = NVFP4Tensor.to_nvfp4(data, is_swizzled_scales=True)
+    # tensor.to_dtype(torch.bfloat16)
 
     assert tensor._is_swizzled_scales == True
+    print(
+        "before",
+        tensor.shape,
+        tensor.qdata.shape,
+        tensor._scale_e4m3.shape,
+    )
+    # print(tensor._scale_e4m3[0:128, 0:4])
 
     if slice_dim == 0:
         sliced_tensor = tensor[slice_spec, :]
     else:
         sliced_tensor = tensor[:, slice_spec]
+    print(
+        "after",
+        sliced_tensor.shape,
+        sliced_tensor.qdata.shape,
+        sliced_tensor._scale_e4m3.shape,
+    )
+    # print(sliced_tensor._scale_e4m3[0:128, 0:4])
+    # print(sliced_tensor.qdata.float() - tensor.qdata[0:128, 0:32].float())
+    # print(sliced_tensor._scale_e4m3.float() - tensor._scale_e4m3[0:128, 0:4].float())
 
     # Verify sliced tensor maintains swizzled state
     assert sliced_tensor._is_swizzled_scales == True
 
+    # this matches sliced_reconstructed, but not original_reconstructed[:, slice_spec]
+    if False:
+        sliced_manually = NVFP4Tensor(
+            tensor.qdata[:, 0:32],
+            tensor._scale_e4m3[:, 0:4].contiguous(),
+            tensor._block_size,
+            tensor._orig_dtype,
+            tensor._per_tensor_scale,
+            tensor._act_per_tensor_scale,
+            tensor._is_swizzled_scales,
+            tensor.use_triton_kernel,
+            tensor.act_quant_kwargs,
+        )
+        import pdb; pdb.set_trace()
+
     # Verify sliced tensor can be dequantized
     sliced_reconstructed = sliced_tensor.to_dtype(torch.bfloat16)
 
@@ -181,6 +215,11 @@ def test_nvfp4_swizzled_scales_slicing(slice_dim, slice_spec):
         expected = original_reconstructed[slice_spec, :]
     else:
         expected = original_reconstructed[:, slice_spec]
+    print('e', expected)
+    print('s', sliced_reconstructed)
+    print('e - s', expected - sliced_reconstructed)
+    print(1, expected.abs().sum())
+    print(2, sliced_reconstructed.abs().sum())
 
     torch.testing.assert_close(sliced_reconstructed, expected, atol=1e-6, rtol=1e-6)
 
@@ -421,7 +460,8 @@ def test_triton_nvfp4_quantize_equivalence(M, N, use_per_tensor_scale, dtype):
 @pytest.mark.parametrize("compile", [False])
 @pytest.mark.parametrize("bias", [True, False])
 @pytest.mark.parametrize("inpt_dtype", [torch.bfloat16, torch.float32])
-@pytest.mark.parametrize("use_triton_kernel", [True, False])
+# @pytest.mark.parametrize("use_triton_kernel", [True, False])
+@pytest.mark.parametrize("use_triton_kernel", [False])
 @pytest.mark.parametrize(
     "shapes",
     [
@@ -525,3 +565,67 @@ def test_nvfp4_to_copy():
     assert x.act_quant_kwargs == y.act_quant_kwargs
     assert x.dtype == torch.float32
     assert y.dtype == torch.bfloat16
+
+
+@pytest.mark.parametrize("transpose", [False, True])
+# @pytest.mark.parametrize("transpose", [True])
+# @pytest.mark.parametrize("transpose", [False])
+@pytest.mark.parametrize("use_triton_kernel", [False, True])
+# @pytest.mark.parametrize("use_triton_kernel", [False])
+# @pytest.mark.parametrize("use_triton_kernel", [True])
+@pytest.mark.parametrize("is_swizzled_scales", [False, True])
+# @pytest.mark.parametrize("is_swizzled_scales", [True])
+@pytest.mark.parametrize(
+    "mk",
+    (
+        (128, 64),
+        (128 + 16, 64),
+        (128, 64 + 16),
+        (128 + 16, 64 + 16),
+    ),
+)
+# @pytest.mark.parametrize("mk", ((128 + 16, 64),))
+def test_scale_shape_matches_qdata(
+    transpose, use_triton_kernel, is_swizzled_scales, mk
+):
+    if use_triton_kernel and not is_swizzled_scales:
+        pytest.skip("triton kernel requires swizzled scales")
+
+    M, K = mk
+
+    block_size = 16
+
+    # TODO(this PR): test larger tensors that don't exactly map to (128, 64) tiles,
+    # to test the padding logic
+    # context: https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout
+    x_hp = torch.randn(M, K, device="cuda")
+    x = NVFP4Tensor.to_nvfp4(
+        x_hp, is_swizzled_scales=is_swizzled_scales, use_triton_kernel=use_triton_kernel
+    )
+
+    m_dim, k_dim = 0, 1
+    if transpose:
+        x_hp = x_hp.t()
+        x = x.t()
+        m_dim, k_dim = 1, 0
+
+    orig_m = x_hp.shape[m_dim]
+    expected_padded_m = orig_m
+    if is_swizzled_scales:
+        # in swizzled nvfp4, a (128, 64) unpacked / (128, 32) packed qdata tile maps to a (32, 16) scale tile
+        expected_padded_m = ceil_div(orig_m, 128) * 32
+    actual_padded_m = x._scale_e4m3.shape[m_dim]
+    assert expected_padded_m == actual_padded_m, (
+        f"incompatible padded shape for dim {m_dim}: {x.shape} and {x._scale_e4m3.shape}"
+    )
+
+    orig_k = x_hp.shape[k_dim]
+    expected_padded_k = orig_k // block_size
+    if is_swizzled_scales:
+        # in swizzled nvfp4, a (128, 64) unpacked / (128, 32) packed qdata tile maps to a (32, 16) scale tile
+        expected_padded_k = ceil_div(orig_k // block_size, 4) * 16
+    actual_padded_k = x._scale_e4m3.shape[k_dim]
+
+    assert expected_padded_k == actual_padded_k, (
+        f"incompatible padded shape for dim {k_dim}: {x.shape} and {x._scale_e4m3.shape}"
+    )
diff --git a/torchao/prototype/mx_formats/nvfp4_tensor.py b/torchao/prototype/mx_formats/nvfp4_tensor.py
index c22f7793bb..fbd1abb1d4 100644
--- a/torchao/prototype/mx_formats/nvfp4_tensor.py
+++ b/torchao/prototype/mx_formats/nvfp4_tensor.py
@@ -176,6 +176,12 @@ def to_nvfp4(
                     f"Triton kernel requires K (dim 1) to be divisible by 16, got {data_hp.shape[1]}"
                 )
             blockwise_scales, data_lp = triton_quantize_nvfp4(data_hp, per_tensor_scale)
+
+            # TODO(before land): share code for scale shape manipulation in the two
+            # if branches
+            scale_M = ceil_div(data_hp.shape[0], 128) * 32
+            scale_K = ceil_div(data_hp.shape[1] // 16, 4) * 16
+            blockwise_scales = blockwise_scales.view(scale_M, scale_K)
         else:
             blockwise_scales, data_lp = nvfp4_quantize(
                 data_hp, block_size, per_tensor_scale
@@ -183,9 +189,18 @@ def to_nvfp4(
             if is_swizzled_scales:
                 M, K = data_hp.shape[0], data_hp.shape[1]
                 scale_shape = (M, K // block_size)
-                blockwise_scales = to_blocked(
-                    blockwise_scales.view(scale_shape)
-                ).flatten()
+                # print(1, blockwise_scales.shape)
+                blockwise_scales = blockwise_scales.view(scale_shape)
+                # print(2, blockwise_scales.shape, blockwise_scales)
+                blockwise_scales = to_blocked(blockwise_scales)
+                print(3, blockwise_scales.shape, blockwise_scales)
+
+                # match shape of data_hp
+                # in swizzled nvfp4, a (128, 64) unpacked / (128, 32) packed qdata tile maps to a (32, 16) scale tile
+                scale_M = ceil_div(data_hp.shape[0], 128) * 32
+                scale_K = ceil_div(data_hp.shape[1] // 16, 4) * 16
+                blockwise_scales = blockwise_scales.view(scale_M, scale_K)
+                print(4, blockwise_scales.shape, blockwise_scales)
 
         return NVFP4Tensor(
             data_lp,
@@ -212,6 +227,7 @@ def to_dtype(self, target_dtype: torch.dtype) -> torch.Tensor:
             torch.Tensor: Dequantized tensor in the target dtype
         """
         is_transposed = self.qdata.stride(0) < self.qdata.stride(1)
+        print('is_transposed', is_transposed)
         if is_transposed:
             M, K = self.shape[1], self.shape[0]
         else:
@@ -220,8 +236,12 @@ def to_dtype(self, target_dtype: torch.dtype) -> torch.Tensor:
         data_unpacked = unpack_uint4(data.contiguous().view(torch.uint8))
         data_f32 = f4_unpacked_to_f32(data_unpacked)
 
+        # next: debug scale shape here
         data_f32 = data_f32.view(M, K // self._block_size, self._block_size)
-        scale_e4m3_reshaped = self.get_hp_scales().view(M, K // self._block_size, 1)
+        scales = self.get_hp_scales()
+        scales_tmp = scales.reshape(32, -1)
+        print('scales', scales_tmp.shape, scales_tmp[0:8])
+        scale_e4m3_reshaped = scales.view(M, K // self._block_size, 1)
         data_scaled = data_f32 * scale_e4m3_reshaped.to(torch.float32)
         result = data_scaled.view(M, K).to(target_dtype)
 
@@ -237,15 +257,17 @@ def get_hp_scales(self) -> torch.Tensor:
             torch.Tensor: Scales of the NVFP4Tensor
         """
         is_transposed = self.qdata.stride(0) < self.qdata.stride(1)
+        print("is_transposed", is_transposed)
         if is_transposed:
             M, K = self.shape[1], self.shape[0]
+            scale_e4m3 = self._scale_e4m3.t()
         else:
             M, K = self.shape[0], self.shape[1]
+            scale_e4m3 = self._scale_e4m3
 
         if self._is_swizzled_scales:
-            scale_e4m3 = from_blocked(self._scale_e4m3, M, K // self._block_size)
-        else:
-            scale_e4m3 = self._scale_e4m3
+            # import pdb; pdb.set_trace()
+            scale_e4m3 = from_blocked(scale_e4m3, M, K // self._block_size)
 
         return (
             scale_e4m3.to(self._orig_dtype)
@@ -366,6 +388,7 @@ def nvfp4_slice(func, types, args, kwargs):
         raise ValueError("Only support aten.slice with step=1")
 
     assert x.qdata.is_contiguous(), "Only support contiguous data for now"
+    assert x._scale_e4m3.is_contiguous(), "Only support contiguous scale for now"
 
     M, K = x.shape[0], x.shape[1]
 
@@ -376,6 +399,22 @@ def nvfp4_slice(func, types, args, kwargs):
         n_col_blocks = ceil_div(scale_cols, 4)
         elements_per_block = 32 * 16  # 512 elements
 
+        #
+        # See https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout
+        # for qdata vs scale layout. Here is a summary specific to nvfp4:
+        #
+        # 1. qdata tile shape is (128, 32) packed fp4, which is (128, 64) unpacked fp4
+        # 2. scale tile shape is (32, 16)
+        # 3. correspondence of qdata vs scale tiles is as follows, in a 2 by 2 tile example
+        #
+        # | tile_idx | qdata_rows | qdata_cols | scale_rows | scale_cols |
+        # ----------------------------------------------------------------
+        # | 0        | 0:127      | 0:31       | 0:31       | 0:15       |
+        # | 1        | 128:255    | 0:31       | 32:63      | 0:15       |
+        # | 2        | 0:127      | 32:63      | 0:31       | 16:31      |
+        # | 3        | 128:255    | 32:63      | 32:63      | 16:31      |
+        #
+
         if dim == 0:
             # Row slicing
             # Handle sys.maxsize (default slice end)
@@ -407,7 +446,9 @@ def nvfp4_slice(func, types, args, kwargs):
                 else None
             )
 
-            sliced_scale = aten.slice.Tensor(x._scale_e4m3, 0, start_idx, end_idx, 1)
+            # TODO(before land): this is wrong, it works but need to express in terms of
+            # properly laid out scale as in the comment block above
+            sliced_scale = aten.slice.Tensor(x._scale_e4m3, 0, start, end, 1)
             sliced_data = aten.slice.Tensor(x.qdata, 0, start, end, step)
 
         elif dim == 1:
@@ -452,20 +493,43 @@ def nvfp4_slice(func, types, args, kwargs):
                 # Full width - no slicing needed
                 sliced_scale = x._scale_e4m3
             else:
-                # Extract specific column blocks from each row block
-                # Each row block in swizzled format contains n_col_blocks chunks of (32, 16)
-                elements_per_row_block = n_col_blocks * elements_per_block
-
-                # Build list of slices to extract
-                slices_to_extract = []
-                for row_block in range(n_row_blocks):
-                    row_start = row_block * elements_per_row_block
-                    col_start = row_start + start_col_block * elements_per_block
-                    col_end = row_start + end_col_block * elements_per_block
-                    slices_to_extract.append(x._scale_e4m3[col_start:col_end])
-
-                # Concatenate all the slices
-                sliced_scale = torch.cat(slices_to_extract, dim=0)
+
+                scale = x._scale_e4m3
+
+                # reshape to expected shape
+                # TODO(before land): do this when swizzling so we don't have to do it here
+
+                # TODO(before land): comment the mul by 2 here
+                scale_rows = n_row_blocks * 16 * 2
+                scale_cols = n_col_blocks * 16
+                scale = scale.view(scale_rows, scale_cols)
+
+                # convert from hp_tensor row to scale row
+                start_scale_col = 0 if start is None else (start // 128 * 16)
+                end_scale_col = scale_cols if end is None or end >= K else (end // 16 * 4)
+                # import pdb; pdb.set_trace()
+
+                if False:
+                    # Extract specific column blocks from each row block
+                    # Each row block in swizzled format contains n_col_blocks chunks of (32, 16)
+                    elements_per_row_block = n_col_blocks * elements_per_block
+
+                    # Build list of slices to extract
+                    slices_to_extract = []
+                    for row_block in range(n_row_blocks):
+                        row_start = row_block * elements_per_row_block
+                        col_start = row_start + start_col_block * elements_per_block
+                        col_end = row_start + end_col_block * elements_per_block
+                        slices_to_extract.append(x._scale_e4m3[col_start:col_end])
+
+                    # Concatenate all the slices
+                    sliced_scale = torch.cat(slices_to_extract, dim=0)
+                # import pdb; pdb.set_trace()
+                sliced_scale = aten.slice.Tensor(
+                    # x._scale_e4m3, dim, start_scale_col, end_scale_col, step
+                    scale, dim, start_scale_col, end_scale_col, step
+                ).contiguous()
+                # import pdb; pdb.set_trace()
 
             # Slice the data tensor
             packed_start = None if start is None else start // 2
@@ -537,7 +601,7 @@ def nvfp4_t(func, types, args, kwargs):
     old = args[0]
     new = NVFP4Tensor(
         old.qdata.t(),
-        old._scale_e4m3,
+        old._scale_e4m3.t(),
         old._block_size,
         old._orig_dtype,
         old._per_tensor_scale,
@@ -577,6 +641,8 @@ def _addmm_nvfp4_dispatch(
     """
     assert a.qdata.is_contiguous()
     assert b.qdata.t().is_contiguous()
+    assert a._scale_e4m3.is_contiguous()
+    assert b._scale_e4m3.t().is_contiguous()
     assert a._block_size == 16, f"NVFP4 requires block_size=16, got {a._block_size}"
     assert b._block_size == 16, f"NVFP4 requires block_size=16, got {b._block_size}"
 
@@ -615,7 +681,7 @@ def _addmm_nvfp4_dispatch(
         a.qdata.view(torch.float4_e2m1fn_x2),
         b.qdata.view(torch.float4_e2m1fn_x2),
         a_scale_blocked.view(torch.float8_e4m3fn),
-        b_scale_blocked.view(torch.float8_e4m3fn),
+        b_scale_blocked.t().view(torch.float8_e4m3fn),
         bias=None if should_add_bias_separately else bias,
         out_dtype=a._orig_dtype,
         # scale_result=scale_result,  # Not supported yet
diff --git a/torchao/prototype/mx_formats/utils.py b/torchao/prototype/mx_formats/utils.py
index 247b17d838..e8e619d7c5 100644
--- a/torchao/prototype/mx_formats/utils.py
+++ b/torchao/prototype/mx_formats/utils.py
@@ -26,7 +26,7 @@ def ceil_div(a, b):
 
 def to_blocked(input_matrix, use_triton_kernel: bool = False) -> Tensor:
     """
-    Rearrange a large matrix by breaking it into blocks and applying the rearrangement pattern.
+    Rearrange a large matrix by breaking it into blocks, padding, and applying the rearrangement pattern.
 
     See:
         https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout
@@ -37,7 +37,7 @@
             torch.compile
 
     Returns:
-        Rearranged tensor of shape (32*ceil_div(H,128), 16*ceil_div(W,4))
+        Rearranged tensor of shape (32*ceil_div(H,128) * 16*ceil_div(W,4))
     """
     if use_triton_kernel:
         return triton_mx_block_rearrange(input_matrix).flatten()
@@ -83,6 +83,7 @@ def from_blocked(
     """
     n_row_blocks = ceil_div(original_rows, 128)
    n_col_blocks = ceil_div(original_cols, 4)
+    # import pdb; pdb.set_trace()
     rearranged = blocked_tensor.view(n_row_blocks * n_col_blocks, 32, 16)
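
Reviewer note, not part of the patch: the scale-shape padding that to_nvfp4 and test_scale_shape_matches_qdata now encode, and the qdata-tile to scale-tile correspondence summarized in the nvfp4_slice comment block, can be written down in a few lines of plain Python. This is a standalone sketch under the patch's assumptions (block_size 16, each (128, 64) unpacked qdata tile mapping to a (32, 16) scale tile); the helper names below are illustrative only, and ceil_div is redefined locally so the snippet runs on its own (the patch imports it from torchao.prototype.mx_formats.utils).

# Standalone sketch (illustrative, not part of the patch) of the swizzled
# nvfp4 scale shape math and tile correspondence described above.

def ceil_div(a: int, b: int) -> int:
    # local copy so the snippet is self-contained
    return -(-a // b)


def swizzled_scale_shape(M: int, K: int, block_size: int = 16) -> tuple:
    # Each (128, 64) unpacked / (128, 32) packed qdata tile maps to a (32, 16)
    # scale tile, so both dims of the swizzled scale tensor are padded up to
    # tile multiples.
    return ceil_div(M, 128) * 32, ceil_div(K // block_size, 4) * 16


def scale_tile_for_qdata_tile(row_block: int, col_block: int) -> tuple:
    # qdata rows [row_block*128, (row_block+1)*128) and unpacked columns
    # [col_block*64, (col_block+1)*64) are scaled by this (32, 16) region of
    # the swizzled scale tensor; this generalizes the 2x2 table in nvfp4_slice.
    return (
        slice(row_block * 32, (row_block + 1) * 32),
        slice(col_block * 16, (col_block + 1) * 16),
    )


for m, k in ((128, 64), (128 + 16, 64), (128, 64 + 16), (128 + 16, 64 + 16)):
    print((m, k), "->", swizzled_scale_shape(m, k))
# (128, 64) -> (32, 16)
# (144, 64) -> (64, 16)
# (128, 80) -> (32, 32)
# (144, 80) -> (64, 32)

For the four (M, K) cases the new test parametrizes, this predicts swizzled scale shapes of (32, 16), (64, 16), (32, 32) and (64, 32), matching the expected_padded_m / expected_padded_k expressions in the test.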