
Conversation

namgyu-youn
Contributor

@namgyu-youn namgyu-youn commented Sep 21, 2025

Summary

Introduce a new tensor subclass API for int8 quantization with a clearer interface.

The main change can be summarized as follows:

  • Old: Complex affine transform (AffineQuantizedTensor) with separate layout handling
  • New: Direct int8 tensor with qdata, scale, and zero_point attributes

Related Issue/PR: #3012 (comment) #2752

Test plan

test/quantization/quantize_/workflows/int8/test_int8_tensor.py


Future plan:
Implement block-wise quantization using the `block_size` parameter
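
For illustration, a minimal sketch of the per-row symmetric int8 quantization such a tensor wraps; quantize_per_row_int8 is a hypothetical helper written for this example, not the torchao API:

import torch

def quantize_per_row_int8(w: torch.Tensor):
    # per-row symmetric scale: map each row's max magnitude to 127
    scale = w.abs().amax(dim=-1, keepdim=True) / 127.0
    scale = scale.clamp(min=1e-12)  # guard against all-zero rows
    qdata = torch.round(w / scale).clamp(-128, 127).to(torch.int8)
    zero_point = torch.zeros_like(scale, dtype=torch.int8)  # symmetric quantization
    return qdata, scale, zero_point

w = torch.randn(4, 8, dtype=torch.bfloat16)
qdata, scale, zero_point = quantize_per_row_int8(w)
w_dq = qdata.to(w.dtype) * scale.to(w.dtype)  # dequantize for a quick sanity check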

pytorch-bot bot commented Sep 21, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3038

Note: Links to docs will display an error until the docs builds have been completed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Sep 21, 2025


# TODO: Implement block-wise quantization using block_size
class Int8PlainInt8Tensor(TorchAOBaseTensor):
Contributor

nit: we can just use Int8Tensor if it's plain, since that's the default

@jerryzh168
Contributor

can you add a version 2 and expose this tensor through

class Int8DynamicActivationInt8WeightConfig(AOBaseConfig):
? similar to
class Float8DynamicActivationFloat8WeightConfig(AOBaseConfig):
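
For illustration, a usage-level sketch of how the requested version-2 exposure could look, following the float8 convention; the version=2 argument and the resulting weight type are assumptions about how this PR might wire it up, not confirmed API:

import torch
from torchao.quantization import Int8DynamicActivationInt8WeightConfig, quantize_

# assumed: version=2 selects the new Int8Tensor path, mirroring the float8 config
m = torch.nn.Sequential(torch.nn.Linear(256, 512, dtype=torch.bfloat16, device="cuda"))
quantize_(m, Int8DynamicActivationInt8WeightConfig(version=2))
# m[0].weight would then be an Int8Tensor exposing qdata / scale / zero_point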

args[2] if len(args) > 2 else None,
)

if isinstance(input_tensor, Int8PlainInt8Tensor):
Contributor

we also need to quantize input_tensor in this function now, please check

if act_quant_kwargs is not None:
    input_tensor = _choose_quant_func_and_quantize_tensor(
        input_tensor, act_quant_kwargs
    )

@namgyu-youn namgyu-youn changed the title Add Int8PlainInt8Tensor for clearer interface Add Int8Tensor for clearer interface Sep 23, 2025
@namgyu-youn
Contributor Author

namgyu-youn commented Sep 23, 2025

can you add a version 2 and expose this tensor through

class Int8DynamicActivationInt8WeightConfig(AOBaseConfig):

? similar to

class Float8DynamicActivationFloat8WeightConfig(AOBaseConfig):

Since I am not familiar with it, this might be delayed; may I split it into separate PRs? Sure, I will look into it.

Comment on lines 176 to 182
x_int32 = input_tensor.qdata.to(torch.int32)
w_int32 = weight_tensor.qdata.to(torch.int32).t()

result = torch.mm(x_int32.view(-1, x_int32.size(-1)), w_int32)
scale = input_tensor.scale.view(-1, 1) * weight_tensor.scale.unsqueeze(0)
result = result.to(scale.dtype) * scale
result = result.view(*input_tensor.shape[:-1], -1)
Contributor

this is not the same as

def _linear_int8_act_int8_weight_check(input_tensor, weight_tensor, bias):
    return (
        isinstance(input_tensor, AffineQuantizedTensor)
        and _aqt_is_int8_reduced_range(input_tensor)
        and isinstance(weight_tensor, AffineQuantizedTensor)
        and _aqt_is_int8(weight_tensor)
        and input_tensor.dtype == weight_tensor.dtype
        and isinstance(input_tensor._layout, PlainLayout)
        and isinstance(weight_tensor._layout, PlainLayout)
    )

def _linear_int8_act_int8_weight_impl(input_tensor, weight_tensor, bias):
    #
    # 1. do the matrix form of dot(X_i, W_j)
    #
    #
    # 2. rescale the output
    #
    # in cases with large matrices, y_dot_int32 can grow sufficiently
    # large that y_dot_int32 * a float16 scale is greater than the maximum
    # value of a float 16, (which results in a value of inf even if multiplying
    # by the other scale would bring it within the expected range)
    x_vals_int8 = input_tensor.tensor_impl.int_data
    x_scales = input_tensor.tensor_impl.scale
    w_vals_int8_t = weight_tensor.tensor_impl.int_data.contiguous().t()
    w_scales = weight_tensor.tensor_impl.scale
    tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1])
    x_scales_dtype = x_scales.dtype
    # Cast fp16 scale to float to avoid overflow in int_scaled_matmul
    intermediate_dtype = torch.float if x_scales_dtype == torch.half else x_scales_dtype
    y_dot_scaled = int_scaled_matmul(
        tmp, w_vals_int8_t, x_scales.reshape(-1, 1).to(intermediate_dtype)
    )
    y_dot_scaled = y_dot_scaled.to(x_scales_dtype)
    y = (y_dot_scaled * w_scales).reshape(
        *x_vals_int8.shape[:-1], y_dot_scaled.shape[-1]
    )
    # can downcast only at the very end
    output_dtype = input_tensor.dtype
    y = y.to(output_dtype)
    if bias is not None:
        y += bias
    return y
?
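
For context, a tiny standalone illustration of the fp16 overflow that the reference implementation's intermediate-dtype cast guards against (values chosen purely for demonstration):

import torch

y_dot_int32 = torch.tensor([50_000], dtype=torch.int32)  # large accumulated dot product
x_scale = torch.tensor([2.0], dtype=torch.float16)
w_scale = torch.tensor([1e-3], dtype=torch.float16)

# multiplying by the fp16 activation scale overflows fp16's max (~65504) -> inf
bad = y_dot_int32.to(torch.float16) * x_scale * w_scale
# doing the intermediate math in fp32 and downcasting only at the end stays finite (~100)
good = (y_dot_int32.to(torch.float32) * x_scale.float() * w_scale.float()).to(torch.float16)
print(bad, good)  # inf vs roughly 100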

Contributor

can you add a test to check the kernel that's used similar to

def test_expected_gpu_kernel_fbgemm(self):
as well?
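
A hedged sketch of what such a kernel check could look like for this path, to be added inside the test class; the checked substring "_int_mm" and the version=2 config are assumptions that would need to match whatever kernel the int8 path actually lowers to:

import torch
from torch._inductor.utils import run_and_get_code
from torch.testing import FileCheck
from torchao.quantization import Int8DynamicActivationInt8WeightConfig, quantize_

def test_expected_gpu_kernel_int8(self):
    m = torch.nn.Sequential(
        torch.nn.Linear(256, 512, dtype=torch.bfloat16, device="cuda")
    )
    quantize_(m, Int8DynamicActivationInt8WeightConfig(version=2))
    m = torch.compile(m)
    x = torch.randn(32, 256, dtype=torch.bfloat16, device="cuda")
    out, code = run_and_get_code(m, x)
    # assumed kernel name; replace with the op this path really dispatches to
    FileCheck().check("_int_mm").run(code[0])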

Contributor Author

can you add a test to check the kernel that's used similar to

def test_expected_gpu_kernel_fbgemm(self):

as well?

Yes, the linked workflow should be a better way to prevent overhead; I will fix it.

    result = result.to(scale.dtype) * scale
    result = result.view(*input_tensor.shape[:-1], -1)
else:
    # FP × INT8 (static)
Contributor

also this is the code for weight only quant I think:

def _linear_fp_act_int8_weight_impl(input_tensor, weight_tensor, bias):
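
For reference, the weight-only path boils down to dequantizing the int8 weight and running a normal fp linear; a small sketch under that reading, with fp_act_int8_weight_linear as a hypothetical helper name rather than the PR's actual function:

import torch

def fp_act_int8_weight_linear(input_tensor, qdata, scale, zero_point, bias=None):
    # dequantize the int8 weight back to the activation dtype, then use a regular fp matmul
    w_dq = (
        qdata.to(input_tensor.dtype) - zero_point.to(input_tensor.dtype)
    ) * scale.to(input_tensor.dtype)
    return torch.nn.functional.linear(input_tensor, w_dq, bias)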

Contributor Author

Done, thanks for pointing it out.

block_size (Optional[list[int]]): block size for quantization granularity
"""

kernel_preference: KernelPreference = KernelPreference.AUTO
Contributor

seems like there are no multiple kernel preferences right now, right? If so, we can remove this for now

Contributor Author

We can remove this flag, but how about adding a TODO for a real kernel preference? Keeping the current structure might be helpful for that.

Contributor

we don't have different kernel options for this one I think



@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
class TestInt8Tensor(TorchAOIntegrationTestCase):
Contributor

for tests, maybe try to follow https://github.com/pytorch/ao/blob/main/test/quantization/quantize_/workflows/int4/test_int4_marlin_sparse_tensor.py for now, and also add some tests for slicing?

def test_slice(self, granularity):
    config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
    dtype = torch.bfloat16
    device = "cuda"
    dummy = torch.nn.Linear(256, 256, bias=False, dtype=dtype, device=device)
    dummy1 = torch.nn.Linear(256, 64, bias=False, dtype=dtype, device=device)
    dummy1.weight = torch.nn.Parameter(
        dummy.weight.narrow(0, 0, 64), requires_grad=False
    )
    dummy2 = torch.nn.Linear(128, 256, dtype=dtype, device=device)
    dummy2.weight = torch.nn.Parameter(
        dummy.weight.narrow(1, 0, 128), requires_grad=False
    )
    quantize_(dummy, config)
    weight1 = dummy.weight.clone().narrow(0, 0, 64)
    weight2 = dummy.weight.clone().narrow(1, 0, 128)
    self.assertEqual(
        weight1.qdata,
        dummy.weight.qdata.narrow(0, 0, 64),
    )
    self.assertEqual(
        weight2.qdata,
        dummy.weight.qdata.narrow(1, 0, 128),
    )
    if isinstance(granularity, PerRow):
        self.assertEqual(
            weight1.scale,
            dummy.weight.scale.narrow(0, 0, 64),
        )
        self.assertEqual(
            weight2.scale,
            dummy.weight.scale,
        )
    else:
        self.assertEqual(
            weight1.scale,
            dummy.weight.scale,
        )
        self.assertEqual(
            weight2.scale,
            dummy.weight.scale,
        )
    # check for sliced weight, before and after float8 quantization
    # does not differ too much
    input = torch.randn(2, 256, dtype=dtype, device=device)
    res_ref = dummy1(input)
    dummy.weight = torch.nn.Parameter(weight1.contiguous(), requires_grad=False)
    res = dummy(input)
    sqnr = compute_error(res, res_ref)
    self.assertTrue(sqnr > 25, f"sqnr: {sqnr}")
    input = torch.randn(2, 128, dtype=dtype, device=device)
    res_ref = dummy2(input)
    dummy.weight = torch.nn.Parameter(weight2.contiguous(), requires_grad=False)
    res = dummy(input)
    sqnr = compute_error(res, res_ref)
    self.assertTrue(sqnr > 15, f"sqnr: {sqnr}")
and
def test_slice_preserves_aliasing(self, granularity):

Contributor Author

Yes, the linked unit test is helpful for slicing (PerTensor, PerRow) tests, but I haven't implemented granularity in this PR yet to keep the PR size small. Can I address it after this PR?

Contributor

I don't think the slicing tests are specific to a granularity; you should be able to adapt them for the currently supported granularity, I think
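
A sketch of how that adaptation could look for the int8 path under the currently supported per-row granularity, to sit inside TestInt8Tensor; the version=2 config and the qdata/scale attribute layout are assumptions based on this thread:

def test_slice_int8(self):
    dtype, device = torch.bfloat16, "cuda"
    dummy = torch.nn.Linear(256, 256, bias=False, dtype=dtype, device=device)
    quantize_(dummy, Int8DynamicActivationInt8WeightConfig(version=2))
    weight1 = dummy.weight.clone().narrow(0, 0, 64)
    weight2 = dummy.weight.clone().narrow(1, 0, 128)
    # qdata slices along both dims
    self.assertEqual(weight1.qdata, dummy.weight.qdata.narrow(0, 0, 64))
    self.assertEqual(weight2.qdata, dummy.weight.qdata.narrow(1, 0, 128))
    # per-row scale: slicing rows slices the scale, slicing columns leaves it unchanged
    self.assertEqual(weight1.scale, dummy.weight.scale.narrow(0, 0, 64))
    self.assertEqual(weight2.scale, dummy.weight.scale)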

raise ValueError("Expected 2D tensor and block_size length 2")

# Rounding function from high precision dtype
scale = w.abs().max(dim=-1, keepdim=True)[0] / 127.0
Contributor

looks like block_size is not used? why is that?

Contributor

you can checkout

def _linear_fp_act_int8_weight_check(input_tensor, weight_tensor, bias):
for expected granularity

also this should be using these quant primitive ops:

scale, zero_point = choose_qparams_affine(
    input=preprocessed_w,
    mapping_type=MappingType.SYMMETRIC,
    block_size=block_size,
    target_dtype=target_dtype,
    quant_min=quant_min,
    quant_max=quant_max,
    eps=1e-6,
)
wq = quantize_affine(
    input=preprocessed_w,
    block_size=block_size,
    scale=scale,
    zero_point=zero_point,
    output_dtype=target_dtype,
    quant_min=quant_min,
    quant_max=quant_max,
)
, arguments can be found by tracing through the code path for int8 in
new_weight = to_affine_quantized_intx(
and
scale, zero_point = choose_qparams_affine(

this might require a bit too much context, let me know if you would like us to take over
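
For concreteness, a small standalone sketch of building the per-row int8 weight on top of those primitives instead of the hand-rolled max/127 scale; the import path and the quant_min/quant_max values are assumptions to double-check against the current torchao tree:

import torch
from torchao.quantization.quant_primitives import (
    MappingType,
    choose_qparams_affine,
    quantize_affine,
)

w = torch.randn(256, 128, dtype=torch.bfloat16)
block_size = [1, w.shape[1]]  # per-row blocks, matching the current draft's granularity

scale, zero_point = choose_qparams_affine(
    input=w,
    mapping_type=MappingType.SYMMETRIC,
    block_size=block_size,
    target_dtype=torch.int8,
    quant_min=-128,
    quant_max=127,
    eps=1e-6,
)
qdata = quantize_affine(
    input=w,
    block_size=block_size,
    scale=scale,
    zero_point=zero_point,
    output_dtype=torch.int8,
    quant_min=-128,
    quant_max=127,
)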

Contributor Author

@namgyu-youn namgyu-youn Sep 29, 2025

Thanks, I surely want to take this over! I've drafted this PR for those updates, but will look into it today (in about 6 hours).

btw, version 2 is updated at c53dad0 (version 1 is default)

@namgyu-youn namgyu-youn marked this pull request as draft September 28, 2025 13:23
@namgyu-youn namgyu-youn marked this pull request as ready for review September 30, 2025 06:09