Add all fbgemm kernel Tensors into Int4WeightOnlyConfig and Float8DynamicActivationInt4WeightConfig #2474

Status: Open. Wants to merge 1 commit into base branch jerryzh168/stack/9.
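
This change routes the fbgemm int4 kernel tensors through the high-level configs in addition to the existing FbgemmConfig. A minimal usage sketch assembled from the test configs in this diff; the toy linear layers are placeholders (not part of the PR) and assume a CUDA device with sm90+ and pytorch 2.8+:

import torch

from torchao.quantization import (
    Float8ActivationInt4WeightConfig,
    Int4WeightOnlyConfig,
    quantize_,
)

# bf16 activation + int4 weight, preshuffled fbgemm kernel (mirrors BF16_ACT_CONFIG in the tests;
# note that _int4_weight_only_quantize_tensor in this diff also expects
# gemm_kernel_choice=GemmKernelChoice.FBGEMM whenever use_preshuffle=True)
m_bf16 = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
quantize_(m_bf16, Int4WeightOnlyConfig(group_size=128, use_preshuffle=True))

# fp8 activation + int4 weight, preshuffled fbgemm kernel (mirrors FP8_ACT_CONFIG in the tests)
m_fp8 = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
quantize_(m_fp8, Float8ActivationInt4WeightConfig(group_size=128, use_preshuffle=True))

The bf16 path produces Int4GroupwisePreshuffleTensor weights (or Int4GroupwiseTensor when use_preshuffle=False with the fbgemm kernel), while the fp8 path always uses the preshuffled tensor.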
1 change: 1 addition & 0 deletions test/integration/test_serialization_bc.py
@@ -18,6 +18,7 @@

_MODEL_NAMES = [
"torchao-testing/opt-125m-float8dq-row-fbgemm",
"torchao-testing/opt-125m-int4wo-preshuffle",
]


(next changed file; file path not shown)
@@ -15,9 +15,9 @@
run_tests,
)

from torchao.float8.config import e4m3_dtype
from torchao.quantization import (
FbgemmConfig,
Float8ActivationInt4WeightConfig,
Int4WeightOnlyConfig,
quantize_,
)
from torchao.quantization.utils import compute_error
@@ -27,44 +27,16 @@
is_sm_at_least_90,
)

if TORCH_VERSION_AT_LEAST_2_8:
BF16_ACT_CONFIG = FbgemmConfig(
input_dtype=torch.bfloat16,
weight_dtype=torch.int4,
output_dtype=torch.bfloat16,
block_size=[1, 128],
preshuffle=True,
)

BF16_ACT_BMM_CONFIG = FbgemmConfig(
input_dtype=torch.bfloat16,
weight_dtype=torch.int4,
output_dtype=torch.bfloat16,
block_size=[1, 1, 128],
preshuffle=True,
)

FP8_ACT_CONFIG = FbgemmConfig(
input_dtype=e4m3_dtype,
weight_dtype=torch.int4,
output_dtype=torch.bfloat16,
block_size=[1, 128],
preshuffle=True,
)

FP8_ACT_BMM_CONFIG = FbgemmConfig(
input_dtype=e4m3_dtype,
weight_dtype=torch.int4,
output_dtype=torch.bfloat16,
block_size=[1, 1, 128],
preshuffle=True,
)

else:
BF16_ACT_CONFIG = None
BF16_ACT_BMM_CONFIG = None
FP8_ACT_CONFIG = None
FP8_ACT_BMM_CONFIG = None
BF16_ACT_CONFIG = Int4WeightOnlyConfig(
group_size=128,
use_preshuffle=True,
)

FP8_ACT_CONFIG = Float8ActivationInt4WeightConfig(
group_size=128,
use_preshuffle=True,
)



@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+")
@@ -90,7 +62,7 @@ def test_linear(self, config):

# Note: this order will error out: `Got bad cuda status: an illegal memory access was encountered at line: 449`
# @parametrize("bmm_config", [BF16_ACT_BMM_CONFIG, FP8_ACT_BMM_CONFIG])
@parametrize("bmm_config", [FP8_ACT_BMM_CONFIG, BF16_ACT_BMM_CONFIG])
@parametrize("bmm_config", [FP8_ACT_CONFIG, BF16_ACT_CONFIG])
def test_bmm(self, bmm_config):
class M(torch.nn.Module):
def __init__(self, weight):
(next changed file; file path not shown)
@@ -13,7 +13,7 @@
)

from torchao.quantization import (
FbgemmConfig,
Int4WeightOnlyConfig,
quantize_,
)
from torchao.quantization.utils import compute_error
@@ -26,19 +26,12 @@
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(not is_sm_at_least_90(), "Need sm90+")
class TestFbgemmInt4Tensor(TestCase):
class TestInt4GroupwiseTensor(TestCase):
def setUp(self):
self.config = FbgemmConfig(
input_dtype=torch.bfloat16,
weight_dtype=torch.int4,
output_dtype=torch.bfloat16,
block_size=[1, 128],
)
self.bmm_config = FbgemmConfig(
input_dtype=torch.bfloat16,
weight_dtype=torch.int4,
output_dtype=torch.bfloat16,
block_size=[1, 1, 128],
self.config = Int4WeightOnlyConfig(
group_size=128,
use_preshuffle=False,
gemm_kernel_choice="fbgemm",
)
self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else []

@@ -135,7 +128,7 @@ def forward(self, x):
original = m(input)
# we need to transpose the weight first for bmm
m.weight = torch.nn.Parameter(m.weight.transpose(1, 2).contiguous())
quantize_(m, self.bmm_config, filter_fn=lambda x, fqn: True)
quantize_(m, self.config, filter_fn=lambda x, fqn: True)
quantized = m(input)
self.assertTrue(compute_error(original, quantized) > 18)

3 changes: 0 additions & 3 deletions torchao/dtypes/__init__.py
@@ -9,7 +9,6 @@
to_affine_quantized_intx_static,
)
from .fbgemm_fp8_tensor import FbgemmFp8Tensor, to_fbgemm_fp8
from .fbgemm_int4_tensor import FbgemmInt4Tensor, to_fbgemm_int4
from .floatx import (
CutlassSemiSparseLayout,
Float8Layout,
@@ -64,8 +63,6 @@
"PackedLinearInt8DynamicActivationIntxWeightLayout",
"to_affine_quantized_packed_linear_int8_dynamic_activation_intx_weight",
"Int4XPULayout",
"to_fbgemm_int4",
"FbgemmInt4Tensor",
"to_fbgemm_fp8",
"FbgemmFp8Tensor",
"Int8DynamicActInt4WeightCPULayout",
2 changes: 2 additions & 0 deletions torchao/quantization/__init__.py
@@ -44,6 +44,7 @@
from .quant_api import (
CutlassInt4PackedLayout,
FbgemmConfig,
Float8ActivationInt4WeightConfig,
Float8DynamicActivationFloat8SemiSparseWeightConfig,
Float8DynamicActivationFloat8WeightConfig,
Float8MMConfig,
@@ -141,6 +142,7 @@
"Int8DynamicActivationInt8WeightConfig",
"Int8DynamicActivationIntxWeightConfig",
"Int4WeightOnlyConfig",
"Float8ActivationInt4WeightConfig",
"Int8WeightOnlyConfig",
"Float8WeightOnlyConfig",
"Float8DynamicActivationFloat8WeightConfig",
64 changes: 61 additions & 3 deletions torchao/quantization/quant_api.py
@@ -50,7 +50,6 @@
to_affine_quantized_floatx_static,
to_affine_quantized_intx,
to_fbgemm_fp8,
to_fbgemm_int4,
to_marlinqqq_quantized_intx,
)
from torchao.dtypes.uintx.packed_linear_int8_dynamic_activation_intx_weight_layout import (
@@ -73,6 +72,7 @@
from torchao.quantization.quantize_ import (
Float8Tensor,
Int4GroupwisePreshuffleTensor,
Int4GroupwiseTensor,
)
from torchao.quantization.transform_module import (
_QUANTIZE_CONFIG_HANDLER,
@@ -1117,6 +1117,8 @@ class Int4WeightOnlyConfig(AOBaseConfig):
zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.NONE
set_inductor_config: bool = True
preserve_zero: Optional[bool] = None
use_preshuffle: bool = False  # whether to use fbgemm's preshuffled int4 packed weight layout
gemm_kernel_choice: GemmKernelChoice = GemmKernelChoice.ATEN  # gemm kernel backend for the quantized weight; preshuffle requires FBGEMM


# for BC
@@ -1134,15 +1136,38 @@ def _int4_weight_only_quantize_tensor(weight, config):
layout = config.layout
use_hqq = config.use_hqq
zero_point_domain = config.zero_point_domain
use_preshuffle = config.use_preshuffle
gemm_kernel_choice = config.gemm_kernel_choice

if weight.shape[-1] % group_size != 0:
logger.info(
f"Skipping quantizing weight with int4 weight only quantization because the shape of weight {weight.shape} is not compatible with group_size {group_size}"
)
return weight

if use_preshuffle and gemm_kernel_choice != GemmKernelChoice.FBGEMM:
raise NotImplementedError(
f"use_preshuffle is only supported for fbgemm kernel, got: {gemm_kernel_choice}"
)

block_size = tuple([1 for _ in range(weight.ndim - 1)] + [group_size])

# fbgemm kernel path: wrap the weight in the fbgemm int4 tensor subclasses instead of the AQT layout below
if gemm_kernel_choice == GemmKernelChoice.FBGEMM:
if use_preshuffle:
new_weight = Int4GroupwisePreshuffleTensor.from_float(
weight,
block_size,
activation_dtype="bf16",
)
return new_weight
else:
new_weight = Int4GroupwiseTensor.from_float(
weight,
block_size,
)
return new_weight

mapping_type = MappingType.ASYMMETRIC
block_size = tuple([1 for _ in range(weight.dim() - 1)] + [group_size])
target_dtype = torch.int32
quant_min = 0
quant_max = 15
@@ -1214,6 +1239,39 @@ def _int4_weight_only_transform(
return module


@dataclass
class Float8ActivationInt4WeightConfig(AOBaseConfig):
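"""Config for float8 activation + int4 weight quantization, using fbgemm's preshuffled int4 kernel.

Args:
group_size: group size for int4 weight groupwise quantization
use_preshuffle: whether to use the preshuffled packed weight layout (only True is currently supported)
kernel: gemm kernel backend; only "fbgemm" is currently supported
"""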
group_size: int = 128
use_preshuffle: bool = False
kernel: str = "fbgemm"


@register_quantize_module_handler(Float8ActivationInt4WeightConfig)
def _(module: torch.nn.Module, config: Float8ActivationInt4WeightConfig) -> torch.nn.Module:
assert hasattr(module, "weight"), (
"applying float8 activation int4 weight quant requires module to have weight attribute"
+ f" but {module} does not have one"
)
group_size = config.group_size
use_preshuffle = config.use_preshuffle
kernel = config.kernel

assert use_preshuffle, (
f"only use_preshuffle == True is supported right now, got: {use_preshuffle}"
)
assert kernel == "fbgemm", f"only fbgemm kernel is supported, got: {kernel}"
weight = module.weight
block_size = tuple([1 for _ in range(weight.ndim - 1)] + [group_size])
new_weight = Int4GroupwisePreshuffleTensor.from_float(
module.weight,
block_size,
activation_dtype="fp8",
)
module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
module.extra_repr = types.MethodType(_linear_extra_repr, module)
return module


@dataclass
class Int8WeightOnlyConfig(AOBaseConfig):
"""
@@ -2077,7 +2135,7 @@ def _(module: torch.nn.Module, config: FbgemmConfig) -> torch.nn.Module:
activation_dtype="bf16",
)
else:
weight = to_fbgemm_int4(
weight = Int4GroupwiseTensor.from_float(
module.weight,
config.block_size,
)
2 changes: 2 additions & 0 deletions torchao/quantization/quantize_/__init__.py
@@ -3,9 +3,11 @@
)
from .int4 import (
Int4GroupwisePreshuffleTensor,
Int4GroupwiseTensor,
)

__all__ = [
"Int4GroupwisePreshuffleTensor",
"Int4GroupwiseTensor",
"Float8Tensor",
]
4 changes: 4 additions & 0 deletions torchao/quantization/quantize_/int4/__init__.py
@@ -1,7 +1,11 @@
from .int4_groupwise_preshuffle_tensor import (
Int4GroupwisePreshuffleTensor,
)
from .int4_groupwise_tensor import (
Int4GroupwiseTensor,
)

__all__ = [
"Int4GroupwisePreshuffleTensor",
"Int4GroupwiseTensor",
]
(next changed file; file path not shown)
@@ -17,8 +17,7 @@
)

__all__ = [
"to_fbgemm_int4",
"FbgemmInt4Tensor",
"Int4GroupwiseTensor",
]

aten = torch.ops.aten
@@ -31,7 +30,7 @@
pack_int4 = None


class FbgemmInt4Tensor(TorchAOBaseTensor):
class Int4GroupwiseTensor(TorchAOBaseTensor):
tensor_data_attrs = ["packed_weight", "scale", "zero_point"]
tensor_attributes = ["group_size", "shape"]

@@ -118,7 +117,7 @@ def from_float(
zero_point = zero_point.to(w.dtype)

del w
return FbgemmInt4Tensor(
return Int4GroupwiseTensor(
packed_weight=wq,
scale=scale,
zero_point=zero_point,
@@ -127,7 +126,7 @@
)


implements = FbgemmInt4Tensor.implements
implements = Int4GroupwiseTensor.implements


@implements([torch.nn.functional.linear, aten.linear.default])
@@ -143,8 +142,8 @@ def _(func, types, args, kwargs):
res = torch.ops.fbgemm.bf16i4bf16_rowwise(
input_tensor,
weight_tensor.packed_weight.contiguous(),
weight_tensor.scale,
weight_tensor.zero_point,
weight_tensor.scale.contiguous(),
weight_tensor.zero_point.contiguous(),
)
res = res.reshape(*orig_act_size[:-1], orig_out_features)
if bias is not None:
@@ -185,10 +184,10 @@ def _(func, types, args, kwargs):
)


def _same_metadata(self: "FbgemmInt4Tensor", src: "FbgemmInt4Tensor") -> bool:
def _same_metadata(self: "Int4GroupwiseTensor", src: "Int4GroupwiseTensor") -> bool:
return (
isinstance(self, FbgemmInt4Tensor)
and isinstance(src, FbgemmInt4Tensor)
isinstance(self, Int4GroupwiseTensor)
and isinstance(src, Int4GroupwiseTensor)
and self.shape == src.shape
and self.packed_weight.shape == src.packed_weight.shape
and self.scale.shape == src.scale.shape
@@ -287,9 +286,6 @@ def _(func, types, args, kwargs):
return return_and_correct_aliasing(func, args, kwargs, new)


to_fbgemm_int4 = FbgemmInt4Tensor.from_float


if TORCH_VERSION_AT_LEAST_2_5:
# Allow a model with FbgemmInt4Tensor weights to be loaded with `weights_only=True`
torch.serialization.add_safe_globals([FbgemmInt4Tensor])
# Allow a model with Int4GroupwiseTensor weights to be loaded with `weights_only=True`
torch.serialization.add_safe_globals([Int4GroupwiseTensor])
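
With Int4GroupwiseTensor registered as a safe global, checkpoints holding these weights can be loaded with weights_only=True. A minimal round-trip sketch under the assumptions of this diff; the layer shape and file path are placeholders:

import torch

from torchao.quantization import Int4WeightOnlyConfig, quantize_

# quantize a placeholder linear layer through the fbgemm (non-preshuffle) int4 path,
# mirroring the config used in TestInt4GroupwiseTensor.setUp earlier in this diff
m = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
quantize_(
    m,
    Int4WeightOnlyConfig(group_size=128, use_preshuffle=False, gemm_kernel_choice="fbgemm"),
)

# the state dict now holds Int4GroupwiseTensor weights; because the class is registered
# via add_safe_globals, torch.load with weights_only=True is allowed
torch.save(m.state_dict(), "/tmp/int4_groupwise_state_dict.pt")
state_dict = torch.load("/tmp/int4_groupwise_state_dict.pt", weights_only=True)
m.load_state_dict(state_dict, assign=True)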