From fa7b6e337b43edb1c7993962217309d65ed6b5bf Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Tue, 1 Jul 2025 18:56:11 -0700
Subject: [PATCH] Add all fbgemm kernel Tensors into Int4WeightOnlyConfig and
 Float8DynamicActivationInt4WeightConfig

Summary:
att, we will deprecate FbgemmConfig since it's a single kernel. We'd like to
categorize things by derived dtype + packed format.

Test Plan:
python test/quantization/quantize_/test_int4_groupwise_preshuffle.py

Reviewers:

Subscribers:

Tasks:

Tags:

stack-info: PR: https://github.com/pytorch/ao/pull/2474, branch: jerryzh168/stack/10
---
 test/integration/test_serialization_bc.py    |  1 +
 .../test_int4_groupwise_preshuffle_tensor.py | 54 ++++------------
 .../int4/test_int4_groupwise_tensor.py}      | 21 ++----
 torchao/dtypes/__init__.py                   |  3 -
 torchao/quantization/__init__.py             |  2 +
 torchao/quantization/quant_api.py            | 64 ++++++++++++++++++-
 torchao/quantization/quantize_/__init__.py   |  2 +
 .../quantization/quantize_/int4/__init__.py  |  4 ++
 .../quantize_/int4/int4_groupwise_tensor.py} | 26 ++++----
 9 files changed, 101 insertions(+), 76 deletions(-)
 rename test/{dtypes/test_fbgemm_int4.py => quantization/quantize_/int4/test_int4_groupwise_tensor.py} (91%)
 rename torchao/{dtypes/fbgemm_int4_tensor.py => quantization/quantize_/int4/int4_groupwise_tensor.py} (93%)

diff --git a/test/integration/test_serialization_bc.py b/test/integration/test_serialization_bc.py
index 3c0082afc0..f665229bcc 100644
--- a/test/integration/test_serialization_bc.py
+++ b/test/integration/test_serialization_bc.py
@@ -18,6 +18,7 @@

 _MODEL_NAMES = [
     "torchao-testing/opt-125m-float8dq-row-fbgemm",
+    "torchao-testing/opt-125m-int4wo-preshuffle",
 ]


diff --git a/test/quantization/quantize_/int4/test_int4_groupwise_preshuffle_tensor.py b/test/quantization/quantize_/int4/test_int4_groupwise_preshuffle_tensor.py
index f120d4500b..5fda5ffc0a 100644
--- a/test/quantization/quantize_/int4/test_int4_groupwise_preshuffle_tensor.py
+++ b/test/quantization/quantize_/int4/test_int4_groupwise_preshuffle_tensor.py
@@ -15,9 +15,9 @@
     run_tests,
 )

-from torchao.float8.config import e4m3_dtype
 from torchao.quantization import (
-    FbgemmConfig,
+    Float8ActivationInt4WeightConfig,
+    Int4WeightOnlyConfig,
     quantize_,
 )
 from torchao.quantization.utils import compute_error
@@ -27,44 +27,16 @@
     is_sm_at_least_90,
 )

-if TORCH_VERSION_AT_LEAST_2_8:
-    BF16_ACT_CONFIG = FbgemmConfig(
-        input_dtype=torch.bfloat16,
-        weight_dtype=torch.int4,
-        output_dtype=torch.bfloat16,
-        block_size=[1, 128],
-        preshuffle=True,
-    )
-
-    BF16_ACT_BMM_CONFIG = FbgemmConfig(
-        input_dtype=torch.bfloat16,
-        weight_dtype=torch.int4,
-        output_dtype=torch.bfloat16,
-        block_size=[1, 1, 128],
-        preshuffle=True,
-    )
-
-    FP8_ACT_CONFIG = FbgemmConfig(
-        input_dtype=e4m3_dtype,
-        weight_dtype=torch.int4,
-        output_dtype=torch.bfloat16,
-        block_size=[1, 128],
-        preshuffle=True,
-    )
-
-    FP8_ACT_BMM_CONFIG = FbgemmConfig(
-        input_dtype=e4m3_dtype,
-        weight_dtype=torch.int4,
-        output_dtype=torch.bfloat16,
-        block_size=[1, 1, 128],
-        preshuffle=True,
-    )
-
-else:
-    BF16_ACT_CONFIG = None
-    BF16_ACT_BMM_CONFIG = None
-    FP8_ACT_CONFIG = None
-    FP8_ACT_BMM_CONFIG = None
+BF16_ACT_CONFIG = Int4WeightOnlyConfig(
+    group_size=128,
+    use_preshuffle=True,
+)
+
+FP8_ACT_CONFIG = Float8ActivationInt4WeightConfig(
+    group_size=128,
+    use_preshuffle=True,
+)
+


 @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+")
@@ -90,7 +62,7 @@ def test_linear(self, config):

     # Note: this order will error out: `Got bad cuda status: an illegal memory access was encountered at line: 449`
     # @parametrize("bmm_config", [BF16_ACT_BMM_CONFIG, FP8_ACT_BMM_CONFIG])
-    @parametrize("bmm_config", [FP8_ACT_BMM_CONFIG, BF16_ACT_BMM_CONFIG])
+    @parametrize("bmm_config", [FP8_ACT_CONFIG, BF16_ACT_CONFIG])
     def test_bmm(self, bmm_config):
         class M(torch.nn.Module):
             def __init__(self, weight):
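For readers migrating call sites, here is a hedged before/after sketch of the config construction shown in the hunks above. It assumes this patch is applied on top of torchao, an fbgemm-gpu build with the int4 preshuffle kernels, and an sm90+ GPU; the model and shapes are illustrative only, not part of the patch.

```python
import torch
from torchao.quantization import (
    Float8ActivationInt4WeightConfig,
    Int4WeightOnlyConfig,
    quantize_,
)

# Old form (slated for deprecation):
#   FbgemmConfig(input_dtype=torch.bfloat16, weight_dtype=torch.int4,
#                output_dtype=torch.bfloat16, block_size=[1, 128], preshuffle=True)
# New form: derived dtype + packed format live on the config itself, and block_size
# is derived from group_size inside quant_api, so 2D linear and 3D bmm weights
# can share one config.
bf16_act_config = Int4WeightOnlyConfig(group_size=128, use_preshuffle=True)
fp8_act_config = Float8ActivationInt4WeightConfig(group_size=128, use_preshuffle=True)

# Applying the fp8-activation config to a hypothetical toy model:
model = torch.nn.Sequential(
    torch.nn.Linear(128, 256, bias=False, dtype=torch.bfloat16, device="cuda")
)
quantize_(model, fp8_act_config)
out = model(torch.randn(2, 128, dtype=torch.bfloat16, device="cuda"))
```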
diff --git a/test/dtypes/test_fbgemm_int4.py b/test/quantization/quantize_/int4/test_int4_groupwise_tensor.py
similarity index 91%
rename from test/dtypes/test_fbgemm_int4.py
rename to test/quantization/quantize_/int4/test_int4_groupwise_tensor.py
index eb1f059775..3a8ccf8932 100644
--- a/test/dtypes/test_fbgemm_int4.py
+++ b/test/quantization/quantize_/int4/test_int4_groupwise_tensor.py
@@ -13,7 +13,7 @@
 )

 from torchao.quantization import (
-    FbgemmConfig,
+    Int4WeightOnlyConfig,
     quantize_,
 )
 from torchao.quantization.utils import compute_error
@@ -26,19 +26,12 @@
 @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+")
 @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
 @unittest.skipIf(not is_sm_at_least_90(), "Nedd sm90+")
-class TestFbgemmInt4Tensor(TestCase):
+class TestInt4GroupwiseTensor(TestCase):
     def setUp(self):
-        self.config = FbgemmConfig(
-            input_dtype=torch.bfloat16,
-            weight_dtype=torch.int4,
-            output_dtype=torch.bfloat16,
-            block_size=[1, 128],
-        )
-        self.bmm_config = FbgemmConfig(
-            input_dtype=torch.bfloat16,
-            weight_dtype=torch.int4,
-            output_dtype=torch.bfloat16,
-            block_size=[1, 1, 128],
+        self.config = Int4WeightOnlyConfig(
+            group_size=128,
+            use_preshuffle=False,
+            gemm_kernel_choice="fbgemm",
         )
         self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else []

@@ -135,7 +128,7 @@ def forward(self, x):
         original = m(input)
         # we need to transpose the weight first for bmm
         m.weight = torch.nn.Parameter(m.weight.transpose(1, 2).contiguous())
-        quantize_(m, self.bmm_config, filter_fn=lambda x, fqn: True)
+        quantize_(m, self.config, filter_fn=lambda x, fqn: True)
         quantized = m(input)
         self.assertTrue(compute_error(original, quantized) > 18)

diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py
index d6b1b9c440..575e154091 100644
--- a/torchao/dtypes/__init__.py
+++ b/torchao/dtypes/__init__.py
@@ -9,7 +9,6 @@
     to_affine_quantized_intx_static,
 )
 from .fbgemm_fp8_tensor import FbgemmFp8Tensor, to_fbgemm_fp8
-from .fbgemm_int4_tensor import FbgemmInt4Tensor, to_fbgemm_int4
 from .floatx import (
     CutlassSemiSparseLayout,
     Float8Layout,
@@ -64,8 +63,6 @@
     "PackedLinearInt8DynamicActivationIntxWeightLayout",
     "to_affine_quantized_packed_linear_int8_dynamic_activation_intx_weight",
     "Int4XPULayout",
-    "to_fbgemm_int4",
-    "FbgemmInt4Tensor",
     "to_fbgemm_fp8",
     "FbgemmFp8Tensor",
     "Int8DynamicActInt4WeightCPULayout",
diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py
index 6357328037..78155db7ee 100644
--- a/torchao/quantization/__init__.py
+++ b/torchao/quantization/__init__.py
@@ -44,6 +44,7 @@
 from .quant_api import (
     CutlassInt4PackedLayout,
     FbgemmConfig,
+    Float8ActivationInt4WeightConfig,
     Float8DynamicActivationFloat8SemiSparseWeightConfig,
     Float8DynamicActivationFloat8WeightConfig,
     Float8MMConfig,
@@ -141,6 +142,7 @@
     "Int8DynamicActivationInt8WeightConfig",
     "Int8DynamicActivationIntxWeightConfig",
     "Int4WeightOnlyConfig",
+    "Float8ActivationInt4WeightConfig",
     "Int8WeightOnlyConfig",
     "Float8WeightOnlyConfig",
     "Float8DynamicActivationFloat8WeightConfig",
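The `filter_fn=lambda x, fqn: True` in the renamed bmm test above is what lets `quantize_` reach a module that is not an `nn.Linear`. A minimal sketch of that mechanism, assuming a stock torchao install on an sm80+ GPU; the model, its child names, and the config here are hypothetical stand-ins rather than anything defined in this patch:

```python
import torch
from torchao.quantization import Int4WeightOnlyConfig, quantize_

model = torch.nn.Sequential(
    torch.nn.Linear(256, 512, bias=False, dtype=torch.bfloat16, device="cuda"),
    torch.nn.Linear(512, 256, bias=False, dtype=torch.bfloat16, device="cuda"),
)
config = Int4WeightOnlyConfig(group_size=128)

# Default filter: every nn.Linear child is quantized in place.
quantize_(model, config)

# Custom filter: quantize_ calls filter_fn(module, fqn) on each submodule and only
# transforms the ones that return True, e.g. only the first linear ("0") here.
model2 = torch.nn.Sequential(
    torch.nn.Linear(256, 512, bias=False, dtype=torch.bfloat16, device="cuda"),
    torch.nn.Linear(512, 256, bias=False, dtype=torch.bfloat16, device="cuda"),
)
quantize_(model2, config, filter_fn=lambda mod, fqn: fqn == "0")
```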
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index 70890311fd..de796374eb 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -50,7 +50,6 @@
     to_affine_quantized_floatx_static,
     to_affine_quantized_intx,
     to_fbgemm_fp8,
-    to_fbgemm_int4,
     to_marlinqqq_quantized_intx,
 )
 from torchao.dtypes.uintx.packed_linear_int8_dynamic_activation_intx_weight_layout import (
@@ -73,6 +72,7 @@
 from torchao.quantization.quantize_ import (
     Float8Tensor,
     Int4GroupwisePreshuffleTensor,
+    Int4GroupwiseTensor,
 )
 from torchao.quantization.transform_module import (
     _QUANTIZE_CONFIG_HANDLER,
@@ -1117,6 +1117,8 @@ class Int4WeightOnlyConfig(AOBaseConfig):
     zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.NONE
     set_inductor_config: bool = True
     preserve_zero: Optional[bool] = None
+    use_preshuffle: bool = False
+    gemm_kernel_choice: GemmKernelChoice = GemmKernelChoice.ATEN


 # for BC
@@ -1134,6 +1136,8 @@ def _int4_weight_only_quantize_tensor(weight, config):
     layout = config.layout
     use_hqq = config.use_hqq
     zero_point_domain = config.zero_point_domain
+    use_preshuffle = config.use_preshuffle
+    gemm_kernel_choice = config.gemm_kernel_choice

     if weight.shape[-1] % group_size != 0:
         logger.info(
@@ -1141,8 +1145,29 @@
         )
         return weight

+    if use_preshuffle and gemm_kernel_choice != GemmKernelChoice.FBGEMM:
+        raise NotImplementedError(
+            f"use_preshuffle is only supported for fbgemm kernel, got: {gemm_kernel_choice}"
+        )
+
+    block_size = tuple([1 for _ in range(weight.ndim - 1)] + [group_size])
+
+    if gemm_kernel_choice == GemmKernelChoice.FBGEMM:
+        if use_preshuffle:
+            new_weight = Int4GroupwisePreshuffleTensor.from_float(
+                weight,
+                block_size,
+                activation_dtype="bf16",
+            )
+            return new_weight
+        else:
+            new_weight = Int4GroupwiseTensor.from_float(
+                weight,
+                block_size,
+            )
+            return new_weight
+
     mapping_type = MappingType.ASYMMETRIC
-    block_size = tuple([1 for _ in range(weight.dim() - 1)] + [group_size])
     target_dtype = torch.int32
     quant_min = 0
     quant_max = 15
@@ -1214,6 +1239,39 @@ def _int4_weight_only_transform(
     return module


+@dataclass
+class Float8ActivationInt4WeightConfig(AOBaseConfig):
+    group_size: int = 128
+    use_preshuffle: bool = False
+    kernel: str = "fbgemm"
+
+
+@register_quantize_module_handler(Float8ActivationInt4WeightConfig)
+def _(module: torch.nn.Module, config: Float8ActivationInt4WeightConfig) -> torch.nn.Module:
+    assert hasattr(module, "weight"), (
+        "applying float8 activation int4 weight quant requires module to have weight attribute"
+        + f" but {module} does not have one"
+    )
+    group_size = config.group_size
+    use_preshuffle = config.use_preshuffle
+    kernel = config.kernel
+
+    assert use_preshuffle, (
+        f"only use_preshuffle == True is supported right now, got: {use_preshuffle}"
+    )
+    assert kernel == "fbgemm", f"only fbgemm kernel is supported, got: {kernel}"
+    weight = module.weight
+    block_size = tuple([1 for _ in range(weight.ndim - 1)] + [group_size])
+    new_weight = Int4GroupwisePreshuffleTensor.from_float(
+        module.weight,
+        block_size,
+        activation_dtype="fp8",
+    )
+    module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
+    module.extra_repr = types.MethodType(_linear_extra_repr, module)
+    return module
+
+
 @dataclass
 class Int8WeightOnlyConfig(AOBaseConfig):
     """
@@ -2077,7 +2135,7 @@ def _(module: torch.nn.Module, config: FbgemmConfig) -> torch.nn.Module:
             activation_dtype="bf16",
         )
     else:
-        weight = to_fbgemm_int4(
+        weight = Int4GroupwiseTensor.from_float(
            module.weight,
            config.block_size,
        )
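The `block_size` expression added in `_int4_weight_only_quantize_tensor` above collapses every leading weight dimension to 1 and keeps `group_size` on the last dimension, which is what lets one config serve both 2D linear weights and 3D bmm weights. A standalone sketch of just that computation (plain PyTorch, no torchao dependency; the helper name `per_group_block_size` is ours, not part of the patch):

```python
import torch


def per_group_block_size(weight: torch.Tensor, group_size: int) -> tuple[int, ...]:
    # Mirrors the expression in the hunk above: all leading dims become 1,
    # only the last (in_features) dim is grouped.
    return tuple([1 for _ in range(weight.ndim - 1)] + [group_size])


linear_weight = torch.randn(256, 1024)   # [out_features, in_features]
bmm_weight = torch.randn(8, 256, 1024)   # [batch, out_features, in_features]

print(per_group_block_size(linear_weight, 128))  # (1, 128)
print(per_group_block_size(bmm_weight, 128))     # (1, 1, 128)
```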
diff --git a/torchao/quantization/quantize_/__init__.py b/torchao/quantization/quantize_/__init__.py
index 3fcbc12013..cc6ec8cfaa 100644
--- a/torchao/quantization/quantize_/__init__.py
+++ b/torchao/quantization/quantize_/__init__.py
@@ -3,9 +3,11 @@
 )
 from .int4 import (
     Int4GroupwisePreshuffleTensor,
+    Int4GroupwiseTensor,
 )

 __all__ = [
     "Int4GroupwisePreshuffleTensor",
+    "Int4GroupwiseTensor",
     "Float8Tensor",
 ]
diff --git a/torchao/quantization/quantize_/int4/__init__.py b/torchao/quantization/quantize_/int4/__init__.py
index 6ebbb55f0f..c5862ce403 100644
--- a/torchao/quantization/quantize_/int4/__init__.py
+++ b/torchao/quantization/quantize_/int4/__init__.py
@@ -1,7 +1,11 @@
 from .int4_groupwise_preshuffle_tensor import (
     Int4GroupwisePreshuffleTensor,
 )
+from .int4_groupwise_tensor import (
+    Int4GroupwiseTensor,
+)

 __all__ = [
     "Int4GroupwisePreshuffleTensor",
+    "Int4GroupwiseTensor",
 ]
diff --git a/torchao/dtypes/fbgemm_int4_tensor.py b/torchao/quantization/quantize_/int4/int4_groupwise_tensor.py
similarity index 93%
rename from torchao/dtypes/fbgemm_int4_tensor.py
rename to torchao/quantization/quantize_/int4/int4_groupwise_tensor.py
index 385f70e3bb..049fa14ec5 100644
--- a/torchao/dtypes/fbgemm_int4_tensor.py
+++ b/torchao/quantization/quantize_/int4/int4_groupwise_tensor.py
@@ -17,8 +17,7 @@
 )

 __all__ = [
-    "to_fbgemm_int4",
-    "FbgemmInt4Tensor",
+    "Int4GroupwiseTensor",
 ]

 aten = torch.ops.aten
@@ -31,7 +30,7 @@
     pack_int4 = None


-class FbgemmInt4Tensor(TorchAOBaseTensor):
+class Int4GroupwiseTensor(TorchAOBaseTensor):
     tensor_data_attrs = ["packed_weight", "scale", "zero_point"]
     tensor_attributes = ["group_size", "shape"]

@@ -118,7 +117,7 @@ def from_float(
         zero_point = zero_point.to(w.dtype)

         del w
-        return FbgemmInt4Tensor(
+        return Int4GroupwiseTensor(
             packed_weight=wq,
             scale=scale,
             zero_point=zero_point,
@@ -127,7 +126,7 @@ def from_float(
         )


-implements = FbgemmInt4Tensor.implements
+implements = Int4GroupwiseTensor.implements


 @implements([torch.nn.functional.linear, aten.linear.default])
@@ -143,8 +142,8 @@ def _(func, types, args, kwargs):
     res = torch.ops.fbgemm.bf16i4bf16_rowwise(
         input_tensor,
         weight_tensor.packed_weight.contiguous(),
-        weight_tensor.scale,
-        weight_tensor.zero_point,
+        weight_tensor.scale.contiguous(),
+        weight_tensor.zero_point.contiguous(),
     )
     res = res.reshape(*orig_act_size[:-1], orig_out_features)
     if bias is not None:
@@ -185,10 +184,10 @@ def _(func, types, args, kwargs):
     )


-def _same_metadata(self: "FbgemmInt4Tensor", src: "FbgemmInt4Tensor") -> bool:
+def _same_metadata(self: "Int4GroupwiseTensor", src: "Int4GroupwiseTensor") -> bool:
     return (
-        isinstance(self, FbgemmInt4Tensor)
-        and isinstance(src, FbgemmInt4Tensor)
+        isinstance(self, Int4GroupwiseTensor)
+        and isinstance(src, Int4GroupwiseTensor)
         and self.shape == src.shape
         and self.packed_weight.shape == src.packed_weight.shape
         and self.scale.shape == src.scale.shape
@@ -287,9 +286,6 @@ def _(func, types, args, kwargs):
     return return_and_correct_aliasing(func, args, kwargs, new)


-to_fbgemm_int4 = FbgemmInt4Tensor.from_float
-
-
 if TORCH_VERSION_AT_LEAST_2_5:
-    # Allow a model with FbgemmInt4Tensor weights to be loaded with `weights_only=True`
-    torch.serialization.add_safe_globals([FbgemmInt4Tensor])
+    # Allow a model with Int4GroupwiseTensor weights to be loaded with `weights_only=True`
+    torch.serialization.add_safe_globals([Int4GroupwiseTensor])
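The `add_safe_globals` registration at the end of the renamed tensor file is what makes a `weights_only=True` round trip possible. A hedged sketch of that flow, assuming this patch is applied, an fbgemm-gpu build with the int4 kernels, an sm90+ GPU, and PyTorch 2.5+ (the file path and shapes are illustrative; `FbgemmConfig` is used because this patch keeps its handler working while marking it for deprecation):

```python
import torch
from torchao.quantization import FbgemmConfig, quantize_

model = torch.nn.Sequential(
    torch.nn.Linear(1024, 1024, bias=False, dtype=torch.bfloat16, device="cuda")
)
config = FbgemmConfig(
    input_dtype=torch.bfloat16,
    weight_dtype=torch.int4,
    output_dtype=torch.bfloat16,
    block_size=[1, 128],
)
quantize_(model, config)  # weights become Int4GroupwiseTensor instances per the hunk above

torch.save(model.state_dict(), "/tmp/int4_groupwise_state_dict.pt")

# Because Int4GroupwiseTensor is registered via torch.serialization.add_safe_globals,
# the state dict can be loaded back with the safe weights_only=True default.
state_dict = torch.load("/tmp/int4_groupwise_state_dict.pt", weights_only=True)
```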