Add Int8Tensor for clearer interface #3038
Conversation
Introduce a new tensor subclass API for int8 quantization with a clearer interface. The main change can be summarized as follows:
- Old: Complex affine transform (AffineQuantizedTensor) with separate layout handling
- New: Direct int8 tensor with qdata, scale, and zero_point attributes

Test plan: test/quantization/quantize_/workflows/int8/test_int8_tensor.py
Future plan: Implement block-wise quantization using the `block_size` parameter
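For orientation, here is a minimal sketch of what such a tensor might hold; this is an illustration only, not the PR's implementation (the actual class subclasses TorchAOBaseTensor and implements the full dispatch machinery):

```python
# Illustrative sketch: an int8 weight represented by raw int8 values plus
# per-row scale and zero_point, matching the attribute names described above.
import torch
from dataclasses import dataclass

@dataclass
class Int8TensorSketch:
    qdata: torch.Tensor       # int8 payload, shape (N, K)
    scale: torch.Tensor       # per-row scale, shape (N, 1)
    zero_point: torch.Tensor  # per-row zero point, shape (N, 1)

    def dequantize(self) -> torch.Tensor:
        return (self.qdata.to(self.scale.dtype) - self.zero_point) * self.scale

w = torch.randn(8, 16)
scale = (w.abs().amax(dim=-1, keepdim=True) / 127.0).clamp(min=1e-12)
qdata = torch.clamp(torch.round(w / scale), -127, 127).to(torch.int8)
t = Int8TensorSketch(qdata=qdata, scale=scale, zero_point=torch.zeros_like(scale))
assert torch.allclose(t.dequantize(), w, atol=scale.max().item())
```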
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3038
Note: Links to docs will display an error until the docs builds have been completed. This comment was automatically generated by Dr. CI and updates every 15 minutes.
```python
# TODO: Implement block-wise quantization using block_size
class Int8PlainInt8Tensor(TorchAOBaseTensor):
```
nit: we can just use `Int8Tensor` if it's plain, since that's the default

can you add a version 2 and expose this tensor through torchao/quantization/quant_api.py (lines 1497 and 1752 at 8525185)?
```python
    args[2] if len(args) > 2 else None,
)

if isinstance(input_tensor, Int8PlainInt8Tensor):
```
we also need to quantize input_tensor in this function now, please check torchao/quantization/quantize_/workflows/float8/float8_tensor.py (lines 263 to 266 at 9d88c16):
```python
if act_quant_kwargs is not None:
    input_tensor = _choose_quant_func_and_quantize_tensor(
        input_tensor, act_quant_kwargs
    )
```
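For the int8 case, the activation quantization step itself would look roughly like the standalone sketch below (an assumption about the eventual wiring, not the PR's code; `quantize_activation_int8` is a hypothetical helper):

```python
# Sketch of dynamic per-row symmetric int8 activation quantization, producing
# the qdata/scale pair that the int8 x int8 path below expects.
import torch

def quantize_activation_int8(x: torch.Tensor):
    # per-row absmax scale, symmetric mapping into [-127, 127]
    scale = (x.abs().amax(dim=-1, keepdim=True) / 127.0).clamp(min=1e-12)
    qdata = torch.clamp(torch.round(x / scale), -127, 127).to(torch.int8)
    return qdata, scale

x = torch.randn(4, 16)
qdata, scale = quantize_activation_int8(x)
x_hat = qdata.to(torch.float32) * scale  # dequantized approximation of x
```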
```python
x_int32 = input_tensor.qdata.to(torch.int32)
w_int32 = weight_tensor.qdata.to(torch.int32).t()

result = torch.mm(x_int32.view(-1, x_int32.size(-1)), w_int32)
scale = input_tensor.scale.view(-1, 1) * weight_tensor.scale.unsqueeze(0)
result = result.to(scale.dtype) * scale
result = result.view(*input_tensor.shape[:-1], -1)
```
this is not the same as torchao/dtypes/uintx/plain_layout.py (lines 269 to 315 at 122b307):
```python
def _linear_int8_act_int8_weight_check(input_tensor, weight_tensor, bias):
    return (
        isinstance(input_tensor, AffineQuantizedTensor)
        and _aqt_is_int8_reduced_range(input_tensor)
        and isinstance(weight_tensor, AffineQuantizedTensor)
        and _aqt_is_int8(weight_tensor)
        and input_tensor.dtype == weight_tensor.dtype
        and isinstance(input_tensor._layout, PlainLayout)
        and isinstance(weight_tensor._layout, PlainLayout)
    )


def _linear_int8_act_int8_weight_impl(input_tensor, weight_tensor, bias):
    #
    # 1. do the matrix form of dot(X_i, W_j)
    #
    #
    # 2. rescale the output
    #
    # in cases with large matrices, y_dot_int32 can grow sufficiently
    # large that y_dot_int32 * a float16 scale is greater than the maximum
    # value of a float 16, (which results in a value of inf even if multiplying
    # by the other scale would bring it within the expected range)
    x_vals_int8 = input_tensor.tensor_impl.int_data
    x_scales = input_tensor.tensor_impl.scale
    w_vals_int8_t = weight_tensor.tensor_impl.int_data.contiguous().t()
    w_scales = weight_tensor.tensor_impl.scale
    tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1])
    x_scales_dtype = x_scales.dtype
    # Cast fp16 scale to float to avoid overflow in int_scaled_matmul
    intermediate_dtype = torch.float if x_scales_dtype == torch.half else x_scales_dtype
    y_dot_scaled = int_scaled_matmul(
        tmp, w_vals_int8_t, x_scales.reshape(-1, 1).to(intermediate_dtype)
    )
    y_dot_scaled = y_dot_scaled.to(x_scales_dtype)
    y = (y_dot_scaled * w_scales).reshape(
        *x_vals_int8.shape[:-1], y_dot_scaled.shape[-1]
    )
    # can downcast only at the very end
    output_dtype = input_tensor.dtype
    y = y.to(output_dtype)
    if bias is not None:
        y += bias
    return y
```
can you add a test to check the kernel that's used, similar to `def test_expected_gpu_kernel_fbgemm(self):`?
can you add a test to check the kernel that's used, similar to `def test_expected_gpu_kernel_fbgemm(self):`, as well?
Yes, the linked workflow should be better for preventing overhead; I will fix it.
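Such a test might look roughly like the sketch below; the expected kernel string ("_int_mm") and the exact shapes are assumptions that would need to match the actual lowering:

```python
# Sketch of a "which kernel did torch.compile emit" check, following the
# pattern of test_expected_gpu_kernel_fbgemm: compile the quantized module
# and search the generated code for the expected int8 matmul kernel.
import torch
from torch._inductor.utils import run_and_get_code

def check_int8_gpu_kernel(quantized_linear):
    x = torch.randn(32, quantized_linear.in_features,
                    dtype=torch.bfloat16, device="cuda")
    compiled = torch.compile(quantized_linear)
    _, code = run_and_get_code(compiled, x)
    # "_int_mm" is an assumption about the lowered kernel name; adjust to
    # whatever inductor actually emits for this path.
    assert any("_int_mm" in c for c in code), "expected int8 matmul kernel"
```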
```python
    result = result.to(scale.dtype) * scale
    result = result.view(*input_tensor.shape[:-1], -1)
else:
    # FP × INT8 (static)
```
also this is the code for weight only quant I think: torchao/dtypes/uintx/plain_layout.py line 250 at 122b307, `def _linear_fp_act_int8_weight_impl(input_tensor, weight_tensor, bias):`
Done, thanks for pointing it out.
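For context, the weight-only path essentially dequantizes the int8 weight to the activation dtype and falls back to a regular floating point matmul; a minimal standalone sketch (not the torchao implementation):

```python
# Sketch of an fp-activation / int8-weight linear: upcast and rescale the
# int8 weight, then use a normal fp matmul.
import torch

def fp_act_int8_weight_linear(x, w_qdata, w_scale, bias=None):
    # w_qdata: int8, shape (N, K); w_scale: per-row scale, shape (N, 1)
    w = w_qdata.to(x.dtype) * w_scale.to(x.dtype)
    return torch.nn.functional.linear(x, w, bias)
```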
```python
        block_size (Optional[list[int]]): block size for quantization granularity
    """

    kernel_preference: KernelPreference = KernelPreference.AUTO
```
seems like there are no multiple kernel preferences right now, right? if so, we can remove this for now
We can remove this flag, but how about adding a TODO for a real kernel preference? Keeping the current structure might be helpful for that.
we don't have different kernel options for this one I think
```python
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
class TestInt8Tensor(TorchAOIntegrationTestCase):
```
for tests, maybe try to follow https://github.com/pytorch/ao/blob/main/test/quantization/quantize_/workflows/int4/test_int4_marlin_sparse_tensor.py for now and also add some tests for slicing? See test/quantization/quantize_/workflows/float8/test_float8_tensor.py (lines 158 to 216 at 8e2ca35):
```python
def test_slice(self, granularity):
    config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
    dtype = torch.bfloat16
    device = "cuda"
    dummy = torch.nn.Linear(256, 256, bias=False, dtype=dtype, device=device)
    dummy1 = torch.nn.Linear(256, 64, bias=False, dtype=dtype, device=device)
    dummy1.weight = torch.nn.Parameter(
        dummy.weight.narrow(0, 0, 64), requires_grad=False
    )
    dummy2 = torch.nn.Linear(128, 256, dtype=dtype, device=device)
    dummy2.weight = torch.nn.Parameter(
        dummy.weight.narrow(1, 0, 128), requires_grad=False
    )

    quantize_(dummy, config)
    weight1 = dummy.weight.clone().narrow(0, 0, 64)
    weight2 = dummy.weight.clone().narrow(1, 0, 128)
    self.assertEqual(
        weight1.qdata,
        dummy.weight.qdata.narrow(0, 0, 64),
    )
    self.assertEqual(
        weight2.qdata,
        dummy.weight.qdata.narrow(1, 0, 128),
    )
    if isinstance(granularity, PerRow):
        self.assertEqual(
            weight1.scale,
            dummy.weight.scale.narrow(0, 0, 64),
        )
        self.assertEqual(
            weight2.scale,
            dummy.weight.scale,
        )
    else:
        self.assertEqual(
            weight1.scale,
            dummy.weight.scale,
        )
        self.assertEqual(
            weight2.scale,
            dummy.weight.scale,
        )

    # check for sliced weight, before and after float8 quantization
    # does not differ too much
    input = torch.randn(2, 256, dtype=dtype, device=device)
    res_ref = dummy1(input)
    dummy.weight = torch.nn.Parameter(weight1.contiguous(), requires_grad=False)
    res = dummy(input)
    sqnr = compute_error(res, res_ref)
    self.assertTrue(sqnr > 25, f"sqnr: {sqnr}")

    input = torch.randn(2, 128, dtype=dtype, device=device)
    res_ref = dummy2(input)
    dummy.weight = torch.nn.Parameter(weight2.contiguous(), requires_grad=False)
    res = dummy(input)
    sqnr = compute_error(res, res_ref)
    self.assertTrue(sqnr > 15, f"sqnr: {sqnr}")

def test_slice_preserves_aliasing(self, granularity):
```
Yes, the linked unit test is helpful for slicing (PerTensor, PerRow) tests, but I haven't implemented granularity in this PR yet to keep the PR size smaller. Can I address it after this PR?
I don't think the slicing tests are specific to a granularity; you should be able to adapt them for the currently supported granularity, I think
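An adapted slicing check could look roughly like the sketch below. The config wiring (`Int8WeightOnlyConfig(version=2)`) and the per-row scale layout are assumptions here, pending the version 2 exposure discussed above:

```python
# Sketch: adapt the float8 slicing test to Int8Tensor, assuming the weight
# after quantize_ exposes .qdata and a per-row .scale as in this PR.
import torch
from torchao.quantization import quantize_, Int8WeightOnlyConfig

dtype, device = torch.bfloat16, "cuda"
dummy = torch.nn.Linear(256, 256, bias=False, dtype=dtype, device=device)
quantize_(dummy, Int8WeightOnlyConfig(version=2))  # version=2 is assumed wiring

weight1 = dummy.weight.clone().narrow(0, 0, 64)
weight2 = dummy.weight.clone().narrow(1, 0, 128)

# qdata should slice along both dims; a per-row scale only along dim 0
assert torch.equal(weight1.qdata, dummy.weight.qdata.narrow(0, 0, 64))
assert torch.equal(weight2.qdata, dummy.weight.qdata.narrow(1, 0, 128))
assert torch.equal(weight1.scale, dummy.weight.scale.narrow(0, 0, 64))
assert torch.equal(weight2.scale, dummy.weight.scale)
```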
```python
        raise ValueError("Expected 2D tensor and block_size length 2")

    # Rounding function from high precision dtype
    scale = w.abs().max(dim=-1, keepdim=True)[0] / 127.0
```
looks like block_size is not used? why is that?
you can check out torchao/dtypes/uintx/plain_layout.py line 232 at 8c5c33e, `def _linear_fp_act_int8_weight_check(input_tensor, weight_tensor, bias):`
also this should be using these quant primitive ops: torchao/quantization/quantize_/workflows/int4/int4_marlin_sparse_tensor.py (lines 79 to 97 at 8c5c33e):
```python
scale, zero_point = choose_qparams_affine(
    input=preprocessed_w,
    mapping_type=MappingType.SYMMETRIC,
    block_size=block_size,
    target_dtype=target_dtype,
    quant_min=quant_min,
    quant_max=quant_max,
    eps=1e-6,
)
wq = quantize_affine(
    input=preprocessed_w,
    block_size=block_size,
    scale=scale,
    zero_point=zero_point,
    output_dtype=target_dtype,
    quant_min=quant_min,
    quant_max=quant_max,
)
```
See also torchao/quantization/quant_api.py line 1566 at 8c5c33e (`new_weight = to_affine_quantized_intx(`) and torchao/dtypes/affine_quantized_tensor.py line 325 at 8c5c33e (`scale, zero_point = choose_qparams_affine(`).
this might require a bit too much context, let me know if you would like us to take over
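For the int8 case, using the same primitives might look roughly like this (a sketch assuming per-row symmetric quantization; quant_min/quant_max are the standard int8 range):

```python
# Sketch: per-row symmetric int8 quantization via the shared quant primitives,
# with block_size = (1, K) so that each output row gets its own scale.
import torch
from torchao.quantization.quant_primitives import (
    MappingType,
    choose_qparams_affine,
    quantize_affine,
)

w = torch.randn(128, 256, dtype=torch.bfloat16)
block_size = (1, w.shape[-1])  # one scale per output row

scale, zero_point = choose_qparams_affine(
    input=w,
    mapping_type=MappingType.SYMMETRIC,
    block_size=block_size,
    target_dtype=torch.int8,
    quant_min=-128,
    quant_max=127,
    eps=1e-6,
)
wq = quantize_affine(
    input=w,
    block_size=block_size,
    scale=scale,
    zero_point=zero_point,
    output_dtype=torch.int8,
    quant_min=-128,
    quant_max=127,
)
```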
Thanks, I surely want to take it over! I drafted this PR for those updates, but will look into it today (in about 6 hours)
btw, version 2 is updated at c53dad0 (version 1 is default)
Summary
Introduce a new tensor subclass API for int8 quantization with a clearer interface.
The main change can be summarized as follows:
- Old: Complex affine transform (AffineQuantizedTensor) with separate layout handling
- New: Direct int8 tensor with qdata, scale, and zero_point attributes

Related Issue/PR: #3012 (comment), #2752

Test plan
test/quantization/quantize_/workflows/int8/test_int8_tensor.py