Commit 838ac58

[wip] make nvfp4 scale shape match qdata
Summary:

does not work yet, need to fix tests

we'll need this to stitch MoE weights in vLLM

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 20a2ad3
ghstack-comment-id: 3348895157
Pull-Request: #3094
1 parent 19ae1ad commit 838ac58

File tree: 3 files changed, +128 -28 lines changed


test/prototype/mx_formats/test_nvfp4_tensor.py

Lines changed: 84 additions & 3 deletions
@@ -21,6 +21,7 @@
     per_tensor_amax_to_scale,
     unpack_uint4,
 )
+from torchao.prototype.mx_formats.utils import ceil_div
 from torchao.quantization.utils import compute_error
 from torchao.testing.utils import skip_if_rocm
 from torchao.utils import (
@@ -127,8 +128,8 @@ def test_nvfp4_swizzled_scales_construction(is_swizzled_scales, shape):
     "slice_dim,slice_spec",
     [
         # Row slicing - must align with 128-row boundaries
-        pytest.param(0, slice(0, 128), id="slice_rows[0:128]"),
-        pytest.param(0, slice(128, 256), id="slice_rows[128:256]"),
+        # pytest.param(0, slice(0, 128), id="slice_rows[0:128]"),
+        # pytest.param(0, slice(128, 256), id="slice_rows[128:256]"),
         # Column slicing - must align with 64-column boundaries (4 scale columns * 16 block_size)
         pytest.param(1, slice(0, 64), id="slice_cols[0:64]"),
         pytest.param(1, slice(64, 128), id="slice_cols[64:128]"),
@@ -162,12 +163,29 @@ def test_nvfp4_swizzled_scales_slicing(slice_dim, slice_spec):
     data = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
 
     tensor = NVFP4Tensor.to_nvfp4(data, is_swizzled_scales=True)
+    # tensor.to_dtype(torch.bfloat16)
     assert tensor._is_swizzled_scales == True
 
+    print(
+        "before",
+        tensor.shape,
+        tensor.qdata.shape,
+        tensor._scale_e4m3.shape,
+        tensor._scale_e4m3.is_contiguous(),
+    )
+    print(tensor._scale_e4m3[0:128, 0:8])
     if slice_dim == 0:
         sliced_tensor = tensor[slice_spec, :]
     else:
         sliced_tensor = tensor[:, slice_spec]
+    print(
+        "after",
+        sliced_tensor.shape,
+        sliced_tensor.qdata.shape,
+        sliced_tensor._scale_e4m3.shape,
+        tensor._scale_e4m3.is_contiguous(),
+    )
+    print(tensor._scale_e4m3[0:128, 0:8])
 
     # Verify sliced tensor maintains swizzled state
     assert sliced_tensor._is_swizzled_scales == True
@@ -421,7 +439,8 @@ def test_triton_nvfp4_quantize_equivalence(M, N, use_per_tensor_scale, dtype):
 @pytest.mark.parametrize("compile", [False])
 @pytest.mark.parametrize("bias", [True, False])
 @pytest.mark.parametrize("inpt_dtype", [torch.bfloat16, torch.float32])
-@pytest.mark.parametrize("use_triton_kernel", [True, False])
+# @pytest.mark.parametrize("use_triton_kernel", [True, False])
+@pytest.mark.parametrize("use_triton_kernel", [False])
 @pytest.mark.parametrize(
     "shapes",
     [
@@ -525,3 +544,65 @@ def test_nvfp4_to_copy():
     assert x.act_quant_kwargs == y.act_quant_kwargs
     assert x.dtype == torch.float32
     assert y.dtype == torch.bfloat16
+
+
+@pytest.mark.parametrize("transpose", [False, True])
+# @pytest.mark.parametrize("transpose", [True])
+# @pytest.mark.parametrize("transpose", [False])
+@pytest.mark.parametrize("use_triton_kernel", [False, True])
+# @pytest.mark.parametrize("use_triton_kernel", [False])
+# @pytest.mark.parametrize("use_triton_kernel", [True])
+@pytest.mark.parametrize("is_swizzled_scales", [False, True])
+# @pytest.mark.parametrize("is_swizzled_scales", [True])
+@pytest.mark.parametrize(
+    "mk",
+    (
+        (128, 64),
+        (128 + 16, 64),
+        (128, 64 + 16),
+        (128 + 16, 64 + 16),
+    ),
+)
+# @pytest.mark.parametrize("mk", ((128 + 16, 64),))
+def test_scale_shape_matches_qdata(
+    transpose, use_triton_kernel, is_swizzled_scales, mk
+):
+    if use_triton_kernel and not is_swizzled_scales:
+        pytest.skip("triton kernel requires swizzled scales")
+
+    M, K = mk
+
+    block_size = 16
+
+    # TODO(this PR): test larger tensors that don't exactly map to (128, 64) tiles,
+    # to test the padding logic
+    # context: https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout
+    x_hp = torch.randn(M, K, device="cuda")
+    x = NVFP4Tensor.to_nvfp4(
+        x_hp, is_swizzled_scales=is_swizzled_scales, use_triton_kernel=use_triton_kernel
+    )
+
+    m_dim, k_dim = 0, 1
+    if transpose:
+        x_hp = x_hp.t()
+        x = x.t()
+        m_dim, k_dim = 1, 0
+
+    orig_m = x_hp.shape[m_dim]
+    expected_padded_m = orig_m
+    if is_swizzled_scales:
+        expected_padded_m = ceil_div(orig_m, 128) * 128
+    actual_padded_m = x._scale_e4m3.shape[m_dim]
+    assert expected_padded_m == actual_padded_m, (
+        f"incompatible padded shape for dim {m_dim}: {x.shape} and {x._scale_e4m3.shape}"
+    )
+
+    orig_k = x_hp.shape[k_dim]
+    expected_padded_k = orig_k // block_size
+    if is_swizzled_scales:
+        expected_padded_k = ceil_div(orig_k // block_size, 4) * 4
+    actual_padded_k = x._scale_e4m3.shape[k_dim]
+
+    assert expected_padded_k == actual_padded_k, (
+        f"incompatible padded shape for dim {k_dim}: {x.shape} and {x._scale_e4m3.shape}"
+    )
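
For reference, the shape relationship the new test_scale_shape_matches_qdata asserts can be written out as a standalone sketch. This is illustrative only and not part of the diff: the helper name expected_scale_shape is hypothetical, while the ceil_div semantics, block_size=16, and the (128, 4) swizzle tile sizes come from the changes above.

def ceil_div(a, b):
    return (a + b - 1) // b

def expected_scale_shape(M, K, block_size=16, is_swizzled_scales=True):
    # One e4m3 scale per `block_size` elements along K.
    scale_m, scale_k = M, K // block_size
    if is_swizzled_scales:
        # The swizzled (cuBLAS block-scaling) layout pads to whole (128, 4) scale tiles.
        scale_m = ceil_div(M, 128) * 128
        scale_k = ceil_div(K // block_size, 4) * 4
    return (scale_m, scale_k)

# For the (128 + 16, 64 + 16) case from the parametrization above:
assert expected_scale_shape(144, 80) == (256, 8)
assert expected_scale_shape(144, 80, is_swizzled_scales=False) == (144, 5)

With transpose=True the test checks the same relationship with the two dims swapped, since the nvfp4_t change below now transposes _scale_e4m3 together with qdata.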

torchao/prototype/mx_formats/nvfp4_tensor.py

Lines changed: 41 additions & 23 deletions
@@ -183,9 +183,17 @@ def to_nvfp4(
         if is_swizzled_scales:
             M, K = data_hp.shape[0], data_hp.shape[1]
             scale_shape = (M, K // block_size)
-            blockwise_scales = to_blocked(
-                blockwise_scales.view(scale_shape)
-            ).flatten()
+            # print(1, blockwise_scales.shape)
+            blockwise_scales = blockwise_scales.view(scale_shape)
+            # print(2, blockwise_scales.shape, blockwise_scales)
+            blockwise_scales = to_blocked(blockwise_scales)
+            # print(3, blockwise_scales.shape, blockwise_scales)
+
+            # match shape of data_hp
+            scale_M = ceil_div(data_hp.shape[0], 128) * 128
+            scale_K = ceil_div(data_hp.shape[1] // 16, 4) * 4
+            blockwise_scales = blockwise_scales.view(scale_M, scale_K)
+            # print(4, blockwise_scales.shape, blockwise_scales)
 
         return NVFP4Tensor(
             data_lp,
@@ -220,6 +228,7 @@ def to_dtype(self, target_dtype: torch.dtype) -> torch.Tensor:
         data_unpacked = unpack_uint4(data.contiguous().view(torch.uint8))
         data_f32 = f4_unpacked_to_f32(data_unpacked)
 
+        # next: debug scale shape here
         data_f32 = data_f32.view(M, K // self._block_size, self._block_size)
         scale_e4m3_reshaped = self.get_hp_scales().view(M, K // self._block_size, 1)
         data_scaled = data_f32 * scale_e4m3_reshaped.to(torch.float32)
@@ -237,15 +246,17 @@ def get_hp_scales(self) -> torch.Tensor:
             torch.Tensor: Scales of the NVFP4Tensor
         """
         is_transposed = self.qdata.stride(0) < self.qdata.stride(1)
+        print("is_transposed", is_transposed)
         if is_transposed:
             M, K = self.shape[1], self.shape[0]
+            scale_e4m3 = self._scale_e4m3.t()
         else:
             M, K = self.shape[0], self.shape[1]
+            scale_e4m3 = self._scale_e4m3
 
         if self._is_swizzled_scales:
-            scale_e4m3 = from_blocked(self._scale_e4m3, M, K // self._block_size)
-        else:
-            scale_e4m3 = self._scale_e4m3
+            # import pdb; pdb.set_trace()
+            scale_e4m3 = from_blocked(scale_e4m3, M, K // self._block_size)
 
         return (
             scale_e4m3.to(self._orig_dtype)
@@ -366,6 +377,7 @@ def nvfp4_slice(func, types, args, kwargs):
        raise ValueError("Only support aten.slice with step=1")
 
    assert x.qdata.is_contiguous(), "Only support contiguous data for now"
+   assert x._scale_e4m3.is_contiguous(), "Only support contiguous scale for now"
 
    M, K = x.shape[0], x.shape[1]
 
@@ -407,7 +419,7 @@ def nvfp4_slice(func, types, args, kwargs):
                else None
            )
 
-           sliced_scale = aten.slice.Tensor(x._scale_e4m3, 0, start_idx, end_idx, 1)
+           sliced_scale = aten.slice.Tensor(x._scale_e4m3, 0, start, end, 1)
            sliced_data = aten.slice.Tensor(x.qdata, 0, start, end, step)
 
    elif dim == 1:
@@ -452,20 +464,24 @@ def nvfp4_slice(func, types, args, kwargs):
                # Full width - no slicing needed
                sliced_scale = x._scale_e4m3
            else:
-               # Extract specific column blocks from each row block
-               # Each row block in swizzled format contains n_col_blocks chunks of (32, 16)
-               elements_per_row_block = n_col_blocks * elements_per_block
-
-               # Build list of slices to extract
-               slices_to_extract = []
-               for row_block in range(n_row_blocks):
-                   row_start = row_block * elements_per_row_block
-                   col_start = row_start + start_col_block * elements_per_block
-                   col_end = row_start + end_col_block * elements_per_block
-                   slices_to_extract.append(x._scale_e4m3[col_start:col_end])
-
-               # Concatenate all the slices
-               sliced_scale = torch.cat(slices_to_extract, dim=0)
+               if False:
+                   # Extract specific column blocks from each row block
+                   # Each row block in swizzled format contains n_col_blocks chunks of (32, 16)
+                   elements_per_row_block = n_col_blocks * elements_per_block
+
+                   # Build list of slices to extract
+                   slices_to_extract = []
+                   for row_block in range(n_row_blocks):
+                       row_start = row_block * elements_per_row_block
+                       col_start = row_start + start_col_block * elements_per_block
+                       col_end = row_start + end_col_block * elements_per_block
+                       slices_to_extract.append(x._scale_e4m3[col_start:col_end])
+
+                   # Concatenate all the slices
+                   sliced_scale = torch.cat(slices_to_extract, dim=0)
+               sliced_scale = aten.slice.Tensor(
+                   x._scale_e4m3, dim, start_scale_col, end_scale_col, step
+               ).contiguous()
 
        # Slice the data tensor
        packed_start = None if start is None else start // 2
@@ -537,7 +553,7 @@ def nvfp4_t(func, types, args, kwargs):
    old = args[0]
    new = NVFP4Tensor(
        old.qdata.t(),
-       old._scale_e4m3,
+       old._scale_e4m3.t(),
        old._block_size,
        old._orig_dtype,
        old._per_tensor_scale,
@@ -577,6 +593,8 @@ def _addmm_nvfp4_dispatch(
    """
    assert a.qdata.is_contiguous()
    assert b.qdata.t().is_contiguous()
+   assert a._scale_e4m3.is_contiguous()
+   assert b._scale_e4m3.t().is_contiguous()
    assert a._block_size == 16, f"NVFP4 requires block_size=16, got {a._block_size}"
    assert b._block_size == 16, f"NVFP4 requires block_size=16, got {b._block_size}"
 
@@ -615,7 +633,7 @@ def _addmm_nvfp4_dispatch(
        a.qdata.view(torch.float4_e2m1fn_x2),
        b.qdata.view(torch.float4_e2m1fn_x2),
        a_scale_blocked.view(torch.float8_e4m3fn),
-       b_scale_blocked.view(torch.float8_e4m3fn),
+       b_scale_blocked.t().view(torch.float8_e4m3fn),
        bias=None if should_add_bias_separately else bias,
        out_dtype=a._orig_dtype,
        # scale_result=scale_result,  # Not supported yet
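
A minimal sketch of the transpose bookkeeping these hunks set up, assuming CPU stand-in tensors and example sizes that are exact multiples of the (128, 4) scale tiles; none of this is the real NVFP4Tensor code:

import torch

M, K, block_size = 256, 128, 16
qdata = torch.zeros(M, K // 2, dtype=torch.uint8)  # stand-in for packed fp4 data (2 values per byte)
scales = torch.zeros(M, K // block_size)           # stand-in for the 2D (padded, swizzled) e4m3 scales

# nvfp4_t now transposes the scale together with qdata ...
qdata_t, scales_t = qdata.t(), scales.t()

# ... so the stride check used by get_hp_scales still lines the two up:
is_transposed = qdata_t.stride(0) < qdata_t.stride(1)
assert is_transposed
# get_hp_scales transposes the scale back before un-swizzling, recovering the (M, K // block_size) view
assert scales_t.t().shape == (M, K // block_size)

Keeping the scale 2D with dims that track qdata is also what lets _addmm_nvfp4_dispatch assert b._scale_e4m3.t().is_contiguous() and pass b_scale_blocked.t() to the scaled matmul.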

torchao/prototype/mx_formats/utils.py

Lines changed: 3 additions & 2 deletions
@@ -26,7 +26,7 @@ def ceil_div(a, b):
 
 def to_blocked(input_matrix, use_triton_kernel: bool = False) -> Tensor:
     """
-    Rearrange a large matrix by breaking it into blocks and applying the rearrangement pattern.
+    Rearrange a large matrix by breaking it into blocks, padding, and applying the rearrangement pattern.
 
     See:
         https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout
@@ -37,7 +37,7 @@ def to_blocked(input_matrix, use_triton_kernel: bool = False) -> Tensor:
            torch.compile
 
     Returns:
-        Rearranged tensor of shape (32*ceil_div(H,128), 16*ceil_div(W,4))
+        Rearranged tensor of shape (32*ceil_div(H,128) * 16*ceil_div(W,4))
     """
     if use_triton_kernel:
         return triton_mx_block_rearrange(input_matrix).flatten()
@@ -83,6 +83,7 @@ def from_blocked(
     """
     n_row_blocks = ceil_div(original_rows, 128)
     n_col_blocks = ceil_div(original_cols, 4)
+    # import pdb; pdb.set_trace()
 
     rearranged = blocked_tensor.view(n_row_blocks * n_col_blocks, 32, 16)
 
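
The docstring fix reflects that to_blocked returns a flattened blocked layout. Below is a quick round-trip shape check, as a sketch only: it assumes these helpers accept a CPU float tensor and that from_blocked crops back to the original (rows, cols), which is how get_hp_scales uses it.

import torch
from torchao.prototype.mx_formats.utils import ceil_div, from_blocked, to_blocked

H, W = 144, 5  # e.g. scales for a (144, 80) tensor with block_size=16
scales = torch.randn(H, W)

blocked = to_blocked(scales)
# Padded to whole (128, 4) scale tiles, then flattened, per the corrected docstring:
assert blocked.numel() == 32 * ceil_div(H, 128) * 16 * ceil_div(W, 4)  # 2048

roundtrip = from_blocked(blocked, H, W)
assert roundtrip.shape == (H, W)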
