diff --git a/src/diffusers/quantizers/quantization_config.py b/src/diffusers/quantizers/quantization_config.py
index 5dd8f56717df..cc5b5cc93d82 100644
--- a/src/diffusers/quantizers/quantization_config.py
+++ b/src/diffusers/quantizers/quantization_config.py
@@ -457,7 +457,7 @@ class TorchAoConfig(QuantizationConfigMixin):
                 - Shorthands: `float8wo`, `float8wo_e5m2`, `float8wo_e4m3`, `float8dq`, `float8dq_e4m3`,
                   `float8_e4m3_tensor`, `float8_e4m3_row`,
 
-            - **Floating point X-bit quantization:**
+            - **Floating point X-bit quantization:** (requires torchao <= 0.14.1; not supported in torchao >= 0.15.0)
                 - Full function names: `fpx_weight_only`
                 - Shorthands: `fpX_eAwB`, where `X` is the number of bits (between `1` to `7`), `A` is the number of
                   exponent bits and `B` is the number of mantissa bits. The constraint of `X == A + B + 1` must
@@ -531,12 +531,18 @@ def post_init(self):
 
         TORCHAO_QUANT_TYPE_METHODS = self._get_torchao_quant_type_to_method()
         if self.quant_type not in TORCHAO_QUANT_TYPE_METHODS.keys():
-            is_floating_quant_type = self.quant_type.startswith("float") or self.quant_type.startswith("fp")
-            if is_floating_quant_type and not self._is_xpu_or_cuda_capability_atleast_8_9():
+            is_floatx_quant_type = self.quant_type.startswith("fp")
+            is_float_quant_type = self.quant_type.startswith("float") or is_floatx_quant_type
+            if is_float_quant_type and not self._is_xpu_or_cuda_capability_atleast_8_9():
                 raise ValueError(
                     f"Requested quantization type: {self.quant_type} is not supported on GPUs with CUDA capability <= 8.9. You "
                     f"can check the CUDA capability of your GPU using `torch.cuda.get_device_capability()`."
                 )
+            elif is_floatx_quant_type and not is_torchao_version("<=", "0.14.1"):
+                raise ValueError(
+                    f"Requested quantization type: {self.quant_type} is only supported in torchao <= 0.14.1. "
+                    f"Please downgrade to torchao <= 0.14.1 to use this quantization type."
+                )
 
             raise ValueError(
                 f"Requested quantization type: {self.quant_type} is not supported or is an incorrect `quant_type` name. If you think the "
@@ -622,7 +628,6 @@ def _get_torchao_quant_type_to_method(cls):
                 float8_dynamic_activation_float8_weight,
                 float8_static_activation_float8_weight,
                 float8_weight_only,
-                fpx_weight_only,
                 int4_weight_only,
                 int8_dynamic_activation_int4_weight,
                 int8_dynamic_activation_int8_weight,
@@ -630,6 +635,8 @@ def _get_torchao_quant_type_to_method(cls):
                 uintx_weight_only,
             )
 
+            if is_torchao_version("<=", "0.14.1"):
+                from torchao.quantization import fpx_weight_only
             # TODO(aryan): Add a note on how to use PerAxis and PerGroup observers
             from torchao.quantization.observer import PerRow, PerTensor
 
@@ -650,18 +657,21 @@ def generate_float8dq_types(dtype: torch.dtype):
                 return types
 
             def generate_fpx_quantization_types(bits: int):
-                types = {}
+                if is_torchao_version("<=", "0.14.1"):
+                    types = {}
 
-                for ebits in range(1, bits):
-                    mbits = bits - ebits - 1
-                    types[f"fp{bits}_e{ebits}m{mbits}"] = partial(fpx_weight_only, ebits=ebits, mbits=mbits)
+                    for ebits in range(1, bits):
+                        mbits = bits - ebits - 1
+                        types[f"fp{bits}_e{ebits}m{mbits}"] = partial(fpx_weight_only, ebits=ebits, mbits=mbits)
 
-                non_sign_bits = bits - 1
-                default_ebits = (non_sign_bits + 1) // 2
-                default_mbits = non_sign_bits - default_ebits
-                types[f"fp{bits}"] = partial(fpx_weight_only, ebits=default_ebits, mbits=default_mbits)
+                    non_sign_bits = bits - 1
+                    default_ebits = (non_sign_bits + 1) // 2
+                    default_mbits = non_sign_bits - default_ebits
+                    types[f"fp{bits}"] = partial(fpx_weight_only, ebits=default_ebits, mbits=default_mbits)
 
-                return types
+                    return types
+                else:
+                    raise ValueError("Floating point X-bit quantization is not supported in torchao >= 0.15.0")
 
             INT4_QUANTIZATION_TYPES = {
                 # int4 weight + bfloat16/float16 activation
@@ -710,15 +720,15 @@ def generate_fpx_quantization_types(bits: int):
                 **generate_float8dq_types(torch.float8_e4m3fn),
                 # float8 weight + float8 activation (static)
                 "float8_static_activation_float8_weight": float8_static_activation_float8_weight,
-                # For fpx, only x <= 8 is supported by default. Other dtypes can be explored by users directly
-                # fpx weight + bfloat16/float16 activation
-                **generate_fpx_quantization_types(3),
-                **generate_fpx_quantization_types(4),
-                **generate_fpx_quantization_types(5),
-                **generate_fpx_quantization_types(6),
-                **generate_fpx_quantization_types(7),
             }
 
+            if is_torchao_version("<=", "0.14.1"):
+                FLOATX_QUANTIZATION_TYPES.update(generate_fpx_quantization_types(3))
+                FLOATX_QUANTIZATION_TYPES.update(generate_fpx_quantization_types(4))
+                FLOATX_QUANTIZATION_TYPES.update(generate_fpx_quantization_types(5))
+                FLOATX_QUANTIZATION_TYPES.update(generate_fpx_quantization_types(6))
+                FLOATX_QUANTIZATION_TYPES.update(generate_fpx_quantization_types(7))
+
             UINTX_QUANTIZATION_DTYPES = {
                 "uintx_weight_only": uintx_weight_only,
                 "uint1wo": partial(uintx_weight_only, dtype=torch.uint1),
diff --git a/tests/quantization/torchao/test_torchao.py b/tests/quantization/torchao/test_torchao.py
index e6bfc2530a5a..7a8e3cc67877 100644
--- a/tests/quantization/torchao/test_torchao.py
+++ b/tests/quantization/torchao/test_torchao.py
@@ -256,9 +256,12 @@ def test_quantization(self):
                 # Cutlass fails to initialize for below
                 # ("float8dq_e4m3_row", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])),
                 # =====
-                ("fp4", np.array([0.4668, 0.5195, 0.5547, 0.4199, 0.4434, 0.6445, 0.4316, 0.4531, 0.5625])),
-                ("fp6", np.array([0.4668, 0.5195, 0.5547, 0.4199, 0.4434, 0.6445, 0.4316, 0.4531, 0.5625])),
             ])
+            if version.parse(importlib.metadata.version("torchao")) <= version.Version("0.14.1"):
+                QUANTIZATION_TYPES_TO_TEST.extend([
+                    ("fp4", np.array([0.4668, 0.5195, 0.5547, 0.4199, 0.4434, 0.6445, 0.4316, 0.4531, 0.5625])),
+                    ("fp6", np.array([0.4668, 0.5195, 0.5547, 0.4199, 0.4434, 0.6445, 0.4316, 0.4531, 0.5625])),
+                ])
         # fmt: on
 
         for quantization_name, expected_slice in QUANTIZATION_TYPES_TO_TEST:
@@ -271,6 +274,34 @@ def test_quantization(self):
             )
             self._test_quant_type(quantization_config, expected_slice, model_id)
 
+    @unittest.skip("Skipping floatx quantization tests")
+    def test_floatx_quantization(self):
+        for model_id in ["hf-internal-testing/tiny-flux-pipe", "hf-internal-testing/tiny-flux-sharded"]:
+            if TorchAoConfig._is_xpu_or_cuda_capability_atleast_8_9():
+                if version.parse(importlib.metadata.version("torchao")) <= version.Version("0.14.1"):
+                    quantization_config = TorchAoConfig(quant_type="fp4", modules_to_not_convert=["x_embedder"])
+                    self._test_quant_type(
+                        quantization_config,
+                        np.array(
+                            [
+                                0.4648,
+                                0.5195,
+                                0.5547,
+                                0.4180,
+                                0.4434,
+                                0.6445,
+                                0.4316,
+                                0.4531,
+                                0.5625,
+                            ]
+                        ),
+                        model_id,
+                    )
+                else:
+                    # Make sure the correct error is thrown
+                    with self.assertRaisesRegex(ValueError, "Please downgrade"):
+                        quantization_config = TorchAoConfig(quant_type="fp4", modules_to_not_convert=["x_embedder"])
+
     def test_int4wo_quant_bfloat16_conversion(self):
         """
         Tests whether the dtype of model will be modified to bfloat16 for int4 weight-only quantization.
@@ -794,8 +825,11 @@ def test_quantization(self):
         if TorchAoConfig._is_xpu_or_cuda_capability_atleast_8_9():
             QUANTIZATION_TYPES_TO_TEST.extend([
                 ("float8wo_e4m3", np.array([0.0546, 0.0722, 0.1328, 0.0468, 0.0585, 0.1367, 0.0605, 0.0703, 0.1328, 0.0625, 0.0703, 0.1445, 0.0585, 0.0703, 0.1406, 0.0605, 0.3496, 0.7109, 0.4843, 0.4042, 0.7226, 0.5000, 0.4160, 0.7031, 0.4824, 0.3886, 0.6757, 0.4667, 0.3710, 0.6679, 0.4902, 0.4238])),
-                ("fp5_e3m1", np.array([0.0527, 0.0762, 0.1309, 0.0449, 0.0645, 0.1328, 0.0566, 0.0723, 0.125, 0.0566, 0.0703, 0.1328, 0.0566, 0.0742, 0.1348, 0.0566, 0.3633, 0.7617, 0.5273, 0.4277, 0.7891, 0.5469, 0.4375, 0.8008, 0.5586, 0.4336, 0.7383, 0.5156, 0.3906, 0.6992, 0.5156, 0.4375])),
             ])
+            if version.parse(importlib.metadata.version("torchao")) <= version.Version("0.14.1"):
+                QUANTIZATION_TYPES_TO_TEST.extend([
+                    ("fp5_e3m1", np.array([0.0527, 0.0762, 0.1309, 0.0449, 0.0645, 0.1328, 0.0566, 0.0723, 0.125, 0.0566, 0.0703, 0.1328, 0.0566, 0.0742, 0.1348, 0.0566, 0.3633, 0.7617, 0.5273, 0.4277, 0.7891, 0.5469, 0.4375, 0.8008, 0.5586, 0.4336, 0.7383, 0.5156, 0.3906, 0.6992, 0.5156, 0.4375])),
+                ])
         # fmt: on
 
         for quantization_name, expected_slice in QUANTIZATION_TYPES_TO_TEST:
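For anyone trying the gated path locally, here is a minimal usage sketch. It is illustrative only: the checkpoint id and dtype are placeholders, the version guard mirrors the check added in `post_init`, the `fp6` shorthand resolves to `fpx_weight_only(ebits=3, mbits=2)` via the defaults in `generate_fpx_quantization_types` above, and a GPU that passes the capability check used in the tests is assumed.

```python
import importlib.metadata

import torch
from packaging import version

from diffusers import FluxTransformer2DModel, TorchAoConfig

# fpX shorthands are only registered when torchao <= 0.14.1 is installed;
# with torchao >= 0.15.0, TorchAoConfig raises the "Please downgrade" ValueError instead.
if version.parse(importlib.metadata.version("torchao")) <= version.Version("0.14.1"):
    quantization_config = TorchAoConfig(quant_type="fp6")  # resolves to fpx_weight_only(ebits=3, mbits=2)
    transformer = FluxTransformer2DModel.from_pretrained(
        "black-forest-labs/FLUX.1-dev",  # illustrative checkpoint
        subfolder="transformer",
        quantization_config=quantization_config,
        torch_dtype=torch.bfloat16,
    )
else:
    print("torchao >= 0.15.0 detected: fpX quantization is unavailable; use the float8/int4/int8 quant types instead.")
```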