diff --git a/backends/apple/coreml/quantizer/_annotation_config.py b/backends/apple/coreml/quantizer/_annotation_config.py new file mode 100644 index 00000000000..f5c1c7ef938 --- /dev/null +++ b/backends/apple/coreml/quantizer/_annotation_config.py @@ -0,0 +1,130 @@ +# Copyright (c) 2024, Apple Inc. All rights reserved. +# +# Use of this source code is governed by a BSD-3-clause license that can be +# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause + +from typing import Optional as _Optional + +import torch as _torch + +from attr import define as _define + +from coremltools.optimize.torch.quantization.quantization_config import ( + ModuleLinearQuantizerConfig as _ModuleLinearQuantizerConfig, + QuantizationScheme as _QuantizationScheme, +) + +from torchao.quantization.pt2e.fake_quantize import FakeQuantize as _FakeQuantize + +from torchao.quantization.pt2e.observer import ( + MinMaxObserver as _MinMaxObserver, + MovingAverageMinMaxObserver as _MovingAverageMinMaxObserver, + MovingAveragePerChannelMinMaxObserver as _MovingAveragePerChannelMinMaxObserver, + PerChannelMinMaxObserver as _PerChannelMinMaxObserver, +) +from torchao.quantization.pt2e.quantizer import ( + QuantizationSpec as _TorchQuantizationSpec, +) + + +def _get_observer(observer_type, is_per_channel: bool): + _str_to_observer_map = { + "moving_average_min_max": _MovingAverageMinMaxObserver, + "min_max": _MinMaxObserver, + "moving_average_min_max_per_channel": _MovingAveragePerChannelMinMaxObserver, + "min_max_per_channel": _PerChannelMinMaxObserver, + } + observer_name = observer_type.value + if is_per_channel: + observer_name = f"{observer_name}_per_channel" + if observer_name not in _str_to_observer_map: + raise ValueError(f"Unsupported observer type: {observer_name}") + return _str_to_observer_map[observer_name] + + +@_define +class AnnotationConfig: + """ + Module/Operator level configuration class for :py:class:`CoreMLQuantizer`. + + For each module/operator, defines the dtype, quantization scheme and observer type + for input(s), output and weights (if any). + """ + + input_activation: _Optional[_TorchQuantizationSpec] = None + output_activation: _Optional[_TorchQuantizationSpec] = None + weight: _Optional[_TorchQuantizationSpec] = None + + @staticmethod + def _normalize_dtype(dtype: _torch.dtype) -> _torch.dtype: + """ + PyTorch export quantizer only supports uint8 and int8 data types, + so we map the quantized dtypes to the corresponding supported dtype. 
+ """ + dtype_map = { + _torch.quint8: _torch.uint8, + _torch.qint8: _torch.int8, + } + return dtype_map.get(dtype, dtype) + + @classmethod + def from_quantization_config( + cls, + quantization_config: _Optional[_ModuleLinearQuantizerConfig], + ) -> _Optional["AnnotationConfig"]: + """ + Creates a :py:class:`AnnotationConfig` from ``ModuleLinearQuantizerConfig`` + """ + if ( + quantization_config is None + or quantization_config.weight_dtype == _torch.float32 + ): + return None + + # Activation QSpec + if quantization_config.activation_dtype == _torch.float32: + output_activation_qspec = None + else: + activation_qscheme = _QuantizationScheme.get_qscheme( + quantization_config.quantization_scheme, + is_per_channel=False, + ) + activation_dtype = cls._normalize_dtype( + quantization_config.activation_dtype + ) + output_activation_qspec = _TorchQuantizationSpec( + observer_or_fake_quant_ctr=_FakeQuantize.with_args( + observer=_get_observer( + quantization_config.activation_observer, + is_per_channel=False, + ), + dtype=activation_dtype, + qscheme=activation_qscheme, + ), + dtype=activation_dtype, + qscheme=activation_qscheme, + ) + + # Weight QSpec + weight_qscheme = _QuantizationScheme.get_qscheme( + quantization_config.quantization_scheme, + is_per_channel=quantization_config.weight_per_channel, + ) + weight_dtype = cls._normalize_dtype(quantization_config.weight_dtype) + weight_qspec = _TorchQuantizationSpec( + observer_or_fake_quant_ctr=_FakeQuantize.with_args( + observer=_get_observer( + quantization_config.weight_observer, + is_per_channel=quantization_config.weight_per_channel, + ), + dtype=weight_dtype, + qscheme=weight_qscheme, + ), + dtype=weight_dtype, + qscheme=weight_qscheme, + ) + return AnnotationConfig( + input_activation=output_activation_qspec, + output_activation=output_activation_qspec, + weight=weight_qspec, + ) diff --git a/backends/apple/coreml/quantizer/_coreml_quantizer.py b/backends/apple/coreml/quantizer/_coreml_quantizer.py new file mode 100644 index 00000000000..c37e5cca463 --- /dev/null +++ b/backends/apple/coreml/quantizer/_coreml_quantizer.py @@ -0,0 +1,628 @@ +# Copyright (c) 2024, Apple Inc. All rights reserved. +# +# Use of this source code is governed by a BSD-3-clause license that can be +# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause + +import operator as _operator +from typing import Callable as _Callable, List as _List, Optional as _Optional + +import executorch.backends.apple.coreml.quantizer._coreml_quantizer_utils as _annotation_utils + +import torch as _torch + +from coremltools.optimize.torch._utils.python_utils import ( + FunctionRegistryMixin as _FunctionRegistryMixin, +) + +from coremltools.optimize.torch.quantization.quantization_config import ( + LinearQuantizerConfig as _LinearQuantizerConfig, + ModuleLinearQuantizerConfig as _ModuleLinearQuantizerConfig, +) + +from executorch.backends.apple.coreml.quantizer._annotation_config import ( + AnnotationConfig as _AnnotationConfig, +) + +from torch.fx import Node as _Node +from torchao.quantization.pt2e.quantizer.quantizer import Quantizer as _TorchQuantizer + +from torchao.quantization.pt2e.quantizer.utils import get_module_name_filter + + +class _AnnotationPatternRegistry(_FunctionRegistryMixin): + """ + A registry of quantization annotation rules. 
+ """ + + @classmethod + def get_annotators(cls): + return cls.REGISTRY + + +@_AnnotationPatternRegistry.register("conv_act") +def _annotate_conv_act( + model: _torch.fx.GraphModule, + quantization_config: _Optional[_AnnotationConfig], + filter_fn: _Optional[_Callable[[_Node], bool]] = None, +) -> _Optional[_List[_List[_Node]]]: + """ + Annotates input -> conv -> activation -> output + """ + return _annotation_utils.annotate_conv_bn_act_helper( + model, quantization_config, filter_fn, use_bn=False + ) + + +@_AnnotationPatternRegistry.register("conv_bn_act") +def _annotate_conv_bn_act( + model: _torch.fx.GraphModule, + quantization_config: _Optional[_AnnotationConfig], + filter_fn: _Optional[_Callable[[_Node], bool]] = None, +) -> _Optional[_List[_List[_Node]]]: + """ + Annotates input -> conv -> batch_norm -> activation -> output + """ + return _annotation_utils.annotate_conv_bn_act_helper( + model, quantization_config, filter_fn, use_bn=True + ) + + +@_AnnotationPatternRegistry.register("conv_bn") +def _annotate_conv_bn( + model: _torch.fx.GraphModule, + quantization_config: _Optional[_AnnotationConfig], + filter_fn: _Optional[_Callable[[_Node], bool]] = None, +) -> _Optional[_List[_List[_Node]]]: + """ + Annotates input -> conv -> batch_norm -> output + """ + annotated_partitions = [] + + conv_dims = [1, 2, 3] + for conv_dim in conv_dims: + pattern_gm = _annotation_utils.get_conv_bn_pattern( + conv_dim, act_fn=None, act_in_place=False + ) + annotated_partitions.extend( + _annotation_utils.annotate_weighted_mod_pattern( + model, pattern_gm, quantization_config, filter_fn + ) + ) + return annotated_partitions + + +@_AnnotationPatternRegistry.register("conv") +def _annotate_conv( + model: _torch.fx.GraphModule, + quantization_config: _Optional[_AnnotationConfig], + filter_fn: _Optional[_Callable[[_Node], bool]] = None, +) -> _Optional[_List[_List[_Node]]]: + """ + Annotates input -> conv -> output + """ + annotated_partitions = [] + for conv_dim in [1, 2, 3]: + pattern_gm = _annotation_utils.get_conv_pattern(conv_dim=conv_dim, act_fn=None) + annotated_partitions.extend( + _annotation_utils.annotate_weighted_mod_pattern( + model, pattern_gm, quantization_config, filter_fn + ) + ) + + return annotated_partitions + + +@_AnnotationPatternRegistry.register("linear_act") +def _annotate_linear_act( + model: _torch.fx.GraphModule, + quantization_config: _Optional[_AnnotationConfig], + filter_fn: _Optional[_Callable[[_Node], bool]] = None, +) -> _Optional[_List[_List[_Node]]]: + """ + Annotates input -> linear -> activation -> output + """ + return _annotation_utils.annotate_linear_bn_act_helper( + model, quantization_config, filter_fn, use_bn=False + ) + + +@_AnnotationPatternRegistry.register("linear_bn_act") +def _annotate_linear_bn_act( + model: _torch.fx.GraphModule, + quantization_config: _Optional[_AnnotationConfig], + filter_fn: _Optional[_Callable[[_Node], bool]] = None, +) -> _Optional[_List[_List[_Node]]]: + """ + Annotates input -> linear -> batch_norm -> activation -> output + """ + return _annotation_utils.annotate_linear_bn_act_helper( + model, quantization_config, filter_fn, use_bn=True + ) + + +@_AnnotationPatternRegistry.register("linear_bn") +def _annotate_linear_bn( + model: _torch.fx.GraphModule, + quantization_config: _Optional[_AnnotationConfig], + filter_fn: _Optional[_Callable[[_Node], bool]] = None, +) -> _Optional[_List[_List[_Node]]]: + """ + Annotates input -> linear -> batch_norm -> output + """ + pattern_gm = _annotation_utils.get_linear_bn_pattern( + act_fn=None, 
act_in_place=False + ) + return _annotation_utils.annotate_weighted_mod_pattern( + model, pattern_gm, quantization_config, filter_fn + ) + + +@_AnnotationPatternRegistry.register("linear") +def _annotate_linear( + model: _torch.fx.GraphModule, + quantization_config: _Optional[_AnnotationConfig], + filter_fn: _Optional[_Callable[[_Node], bool]] = None, +) -> _Optional[_List[_List[_Node]]]: + """ + Annotates input -> linear -> output + """ + pattern_gm = _annotation_utils.get_linear_pattern(act_fn=None) + return _annotation_utils.annotate_weighted_mod_pattern( + model, pattern_gm, quantization_config, filter_fn + ) + + +@_AnnotationPatternRegistry.register("add_act") +def _annotate_add_act( + model: _torch.fx.GraphModule, + quantization_config: _Optional[_AnnotationConfig], + filter_fn: _Optional[_Callable[[_Node], bool]] = None, +) -> _Optional[_List[_List[_Node]]]: + """ + Annotates + + input_1 --- + \ + --> add -> activation -> output + / + input_2 --- + """ + ops = [_operator.add, _torch.add, _operator.iadd] + return _annotation_utils.annotate_binary_op_helper( + model, ops, quantization_config, filter_fn + ) + + +@_AnnotationPatternRegistry.register("add") +def _annotate_add( + model: _torch.fx.GraphModule, + quantization_config: _Optional[_AnnotationConfig], + filter_fn: _Optional[_Callable[[_Node], bool]] = None, +) -> _Optional[_List[_List[_Node]]]: + """ + Annotates + + input_1 --- + \ + --> add -> output + / + input_2 --- + """ + annotated_partitions = [] + ops = [_operator.add, _torch.add, _operator.iadd] + for binary_op in ops: + pattern_gm = _annotation_utils.get_binary_op_act_pattern(binary_op, None) + annotated_partitions.extend( + _annotation_utils.annotate_binary_op_act_pattern( + model, pattern_gm, quantization_config, filter_fn + ) + ) + return annotated_partitions + + +@_AnnotationPatternRegistry.register("mul_act") +def _annotate_mul_act( + model: _torch.fx.GraphModule, + quantization_config: _Optional[_AnnotationConfig], + filter_fn: _Optional[_Callable[[_Node], bool]] = None, +) -> _Optional[_List[_List[_Node]]]: + """ + Annotates + + input_1 --- + \ + --> mul -> activation -> output + / + input_2 --- + """ + ops = [_operator.mul, _torch.mul, _operator.imul] + return _annotation_utils.annotate_binary_op_helper( + model, ops, quantization_config, filter_fn + ) + + +@_AnnotationPatternRegistry.register("mul") +def _annotate_mul( + model: _torch.fx.GraphModule, + quantization_config: _Optional[_AnnotationConfig], + filter_fn: _Optional[_Callable[[_Node], bool]] = None, +) -> _Optional[_List[_List[_Node]]]: + """ + Annotates + + input_1 --- + \ + --> mul -> output + / + input_2 --- + """ + annotated_partitions = [] + ops = [_operator.mul, _torch.mul, _operator.imul] + for binary_op in ops: + pattern_gm = _annotation_utils.get_binary_op_act_pattern(binary_op, None) + annotated_partitions.extend( + _annotation_utils.annotate_binary_op_act_pattern( + model, pattern_gm, quantization_config, filter_fn + ) + ) + return annotated_partitions + + +@_AnnotationPatternRegistry.register("matmul_act") +def _annotate_matmul_act( + model: _torch.fx.GraphModule, + quantization_config: _Optional[_AnnotationConfig], + filter_fn: _Optional[_Callable[[_Node], bool]] = None, +) -> _Optional[_List[_List[_Node]]]: + """ + Annotates + + input_1 --- + \ + --> matmul -> activation -> output + / + input_2 --- + """ + return _annotation_utils.annotate_binary_op_helper( + model, [_torch.matmul], quantization_config, filter_fn + ) + + +@_AnnotationPatternRegistry.register("matmul") +def 
_annotate_matmul( + model: _torch.fx.GraphModule, + quantization_config: _Optional[_AnnotationConfig], + filter_fn: _Optional[_Callable[[_Node], bool]] = None, +) -> _Optional[_List[_List[_Node]]]: + """ + Annotates + + input_1 --- + \ + --> matmul -> output + / + input_2 --- + """ + pattern_gm = _annotation_utils.get_binary_op_act_pattern(_torch.matmul, None) + return _annotation_utils.annotate_binary_op_act_pattern( + model, pattern_gm, quantization_config, filter_fn + ) + + +@_AnnotationPatternRegistry.register("max_pool1d") +def _annotate_max_pool1d( + model: _torch.fx.GraphModule, + quantization_config: _Optional[_AnnotationConfig], + filter_fn: _Optional[_Callable[[_Node], bool]] = None, +) -> _Optional[_List[_List[_Node]]]: + """ + Annotates + + input -> max_pool1d -> output + """ + return _annotation_utils.annotate_unary_shared_observer_ops( + model, + [_torch.nn.MaxPool1d, _torch.nn.functional.max_pool1d, _torch.max_pool1d], + quantization_config, + filter_fn, + ) + + +@_AnnotationPatternRegistry.register("max_pool2d") +def _annotate_max_pool2d( + model: _torch.fx.GraphModule, + quantization_config: _Optional[_AnnotationConfig], + filter_fn: _Optional[_Callable[[_Node], bool]] = None, +) -> _Optional[_List[_List[_Node]]]: + """ + Annotates + + input -> max_pool2d -> output + """ + return _annotation_utils.annotate_unary_shared_observer_ops( + model, + [_torch.nn.MaxPool2d, _torch.nn.functional.max_pool2d, _torch.max_pool2d], + quantization_config, + filter_fn, + ) + + +@_AnnotationPatternRegistry.register("max_pool3d") +def _annotate_max_pool3d( + model: _torch.fx.GraphModule, + quantization_config: _Optional[_AnnotationConfig], + filter_fn: _Optional[_Callable[[_Node], bool]] = None, +) -> _Optional[_List[_List[_Node]]]: + """ + Annotates + + input -> max_pool3d -> output + """ + return _annotation_utils.annotate_unary_shared_observer_ops( + model, + [_torch.nn.MaxPool3d, _torch.nn.functional.max_pool3d, _torch.max_pool3d], + quantization_config, + filter_fn, + ) + + +@_AnnotationPatternRegistry.register("adaptive_avg_pool1d") +def _annotate_adaptive_avg_pool1d( + model: _torch.fx.GraphModule, + quantization_config: _Optional[_AnnotationConfig], + filter_fn: _Optional[_Callable[[_Node], bool]] = None, +) -> _Optional[_List[_List[_Node]]]: + """ + Annotates + + input -> adaptive_avg_pool1d -> output + """ + return _annotation_utils.annotate_unary_shared_observer_ops( + model, + [ + _torch.nn.AdaptiveAvgPool1d, + _torch.nn.functional.adaptive_avg_pool1d, + _torch.adaptive_avg_pool1d, + ], + quantization_config, + filter_fn, + ) + + +@_AnnotationPatternRegistry.register("adaptive_avg_pool2d") +def _annotate_adaptive_avg_pool2d( + model: _torch.fx.GraphModule, + quantization_config: _Optional[_AnnotationConfig], + filter_fn: _Optional[_Callable[[_Node], bool]] = None, +) -> _Optional[_List[_List[_Node]]]: + """ + Annotates + + input -> adaptive_avg_pool2d -> output + """ + return _annotation_utils.annotate_unary_shared_observer_ops( + model, + [_torch.nn.AdaptiveAvgPool2d, _torch.nn.functional.adaptive_avg_pool2d], + quantization_config, + filter_fn, + ) + + +@_AnnotationPatternRegistry.register("adaptive_avg_pool3d") +def _annotate_adaptive_avg_pool3d( + model: _torch.fx.GraphModule, + quantization_config: _Optional[_AnnotationConfig], + filter_fn: _Optional[_Callable[[_Node], bool]] = None, +) -> _Optional[_List[_List[_Node]]]: + """ + Annotates + + input -> adaptive_avg_pool3d -> output + """ + return _annotation_utils.annotate_unary_shared_observer_ops( + model, + 
[_torch.nn.AdaptiveAvgPool3d, _torch.nn.functional.adaptive_avg_pool3d], + quantization_config, + filter_fn, + ) + + +@_AnnotationPatternRegistry.register("avg_pool1d") +def _annotate_avg_pool1d( + model: _torch.fx.GraphModule, + quantization_config: _Optional[_AnnotationConfig], + filter_fn: _Optional[_Callable[[_Node], bool]] = None, +) -> _Optional[_List[_List[_Node]]]: + """ + Annotates + + input -> avg_pool1d -> output + """ + return _annotation_utils.annotate_unary_shared_observer_ops( + model, + [ + _torch.nn.AvgPool1d, + _torch.nn.functional.avg_pool1d, + _torch.avg_pool1d, + _torch.mean, + ], + quantization_config, + filter_fn, + ) + + +@_AnnotationPatternRegistry.register("avg_pool2d") +def _annotate_avg_pool2d( + model: _torch.fx.GraphModule, + quantization_config: _Optional[_AnnotationConfig], + filter_fn: _Optional[_Callable[[_Node], bool]] = None, +) -> _Optional[_List[_List[_Node]]]: + """ + Annotates + + input -> avg_pool2d -> output + """ + return _annotation_utils.annotate_unary_shared_observer_ops( + model, + [ + _torch.nn.AvgPool2d, + _torch.nn.functional.avg_pool2d, + _torch._C._nn.avg_pool2d, + ], + quantization_config, + filter_fn, + ) + + +@_AnnotationPatternRegistry.register("avg_pool3d") +def _annotate_avg_pool3d( + model: _torch.fx.GraphModule, + quantization_config: _Optional[_AnnotationConfig], + filter_fn: _Optional[_Callable[[_Node], bool]] = None, +) -> _Optional[_List[_List[_Node]]]: + """ + Annotates + + input -> avg_pool3d -> output + """ + return _annotation_utils.annotate_unary_shared_observer_ops( + model, + [ + _torch.nn.AvgPool3d, + _torch.nn.functional.avg_pool3d, + _torch._C._nn.avg_pool3d, + ], + quantization_config, + filter_fn, + ) + + +@_AnnotationPatternRegistry.register("flatten") +def _annotate_flatten( + model: _torch.fx.GraphModule, + quantization_config: _Optional[_AnnotationConfig], + filter_fn: _Optional[_Callable[[_Node], bool]] = None, +) -> _Optional[_List[_List[_Node]]]: + """ + Annotates + + input -> flatten -> output + """ + return _annotation_utils.annotate_unary_shared_observer_ops( + model, + [ + _torch.nn.Flatten, + _torch.flatten, + ], + quantization_config, + filter_fn, + ) + + +class CoreMLQuantizer(_TorchQuantizer): + """ + Annotates all recognized patterns using ``config``. + + Extends py:class:`torch.ao.quantization.quantizer.quantizer.Quantizer` to + add support for quantization patterns supported by Core ML. + + Use it in conjunction with PyTorch 2.0 ``prepare_pt2e`` and ``prepare_qat_pt2e`` APIs + for post training weight and activation quantization using calibration data and + for quantization aware training (QAT). + + Example: + + .. 
code-block:: python

+            import torch
+            import torch.nn as nn
+            from collections import OrderedDict
+
+            from torch.export import export_for_training
+            from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_qat_pt2e
+
+            from coremltools.optimize.torch.quantization.quantization_config import LinearQuantizerConfig
+            from executorch.backends.apple.coreml.quantizer import CoreMLQuantizer
+
+            model = nn.Sequential(
+                OrderedDict(
+                    {
+                        "conv": nn.Conv2d(1, 20, (3, 3)),
+                        "relu1": nn.ReLU(),
+                        "conv2": nn.Conv2d(20, 20, (3, 3)),
+                        "relu2": nn.ReLU(),
+                    }
+                )
+            )
+
+            loss_fn = define_loss()
+
+            # initialize the annotator with quantization config
+            config = LinearQuantizerConfig.from_dict(
+                {
+                    "global_config": {
+                        "quantization_scheme": "symmetric",
+                    }
+                }
+            )
+            quantizer = CoreMLQuantizer(config)
+
+            example_inputs = torch.randn(1, 1, 28, 28)
+
+            # create export graph
+            exported_model = export_for_training(model, (example_inputs,)).module()
+
+            # prepare the model to insert FakeQuantize layers for QAT
+            prepared_model = prepare_qat_pt2e(exported_model, quantizer)
+
+            # use prepared model in your PyTorch training loop
+            for inputs, labels in data:
+                output = prepared_model(inputs)
+                loss = loss_fn(output, labels)
+                loss.backward()
+                optimizer.step()
+                # turn observers/quantizers on/off depending on iteration number
+
+            # convert operations to their quantized counterparts using parameters learnt via QAT
+            model = convert_pt2e(prepared_model)
+    """
+
+    def __init__(self, config: _Optional[_LinearQuantizerConfig]):
+        self._config = config
+
+    def _annotate_all_patterns(
+        self,
+        model: _torch.fx.GraphModule,
+        quantization_config: _Optional[_ModuleLinearQuantizerConfig],
+        filter_fn: _Optional[_Callable[[_Node], bool]] = None,
+    ):
+        annotators = _AnnotationPatternRegistry.get_annotators()
+        for _, annotator in annotators.items():
+            annotation_config = _AnnotationConfig.from_quantization_config(
+                quantization_config
+            )
+            annotator(model, annotation_config, filter_fn)
+
+    def annotate(self, model: _torch.fx.GraphModule) -> _torch.fx.GraphModule:
+        # First annotate all modules/operations which have name based configs
+        module_name_list = list(self._config.module_name_configs.keys())
+        for module_name, config in self._config.module_name_configs.items():
+            self._annotate_all_patterns(
+                model, config, get_module_name_filter(module_name)
+            )
+
+        # Next annotate all modules/operations which have type based configs
+        tp_list = list(self._config.module_type_configs.keys())
+        for module_type, config in self._config.module_type_configs.items():
+            self._annotate_all_patterns(
+                model, config, _annotation_utils.get_object_type_filter(module_type)
+            )
+
+        # Annotate all other modules/operations
+        self._annotate_all_patterns(
+            model,
+            self._config.global_config,
+            _annotation_utils.get_not_object_type_or_name_filter(
+                tp_list, module_name_list
+            ),
+        )
+        return model
+
+    def validate(self, model: _torch.fx.GraphModule) -> None:
+        pass
diff --git a/backends/apple/coreml/quantizer/_coreml_quantizer_utils.py b/backends/apple/coreml/quantizer/_coreml_quantizer_utils.py
new file mode 100644
index 00000000000..e7389c0135f
--- /dev/null
+++ b/backends/apple/coreml/quantizer/_coreml_quantizer_utils.py
@@ -0,0 +1,800 @@
+# Copyright (c) 2024, Apple Inc. All rights reserved.
+# +# Use of this source code is governed by a BSD-3-clause license that can be +# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause + +import itertools as _itertools +from typing import ( + Callable as _Callable, + List as _List, + Optional as _Optional, + Tuple as _Tuple, +) + +import torch as _torch +import torch.nn.functional as _F + +from executorch.backends.apple.coreml.quantizer._annotation_config import ( + AnnotationConfig as _AnnotationConfig, +) + +from torch.fx import Node as _Node +from torch.fx.passes.utils.matcher_with_name_node_map_utils import ( + SubgraphMatcherWithNameNodeMap as _SubgraphMatcherWithNameNodeMap, +) +from torch.fx.passes.utils.source_matcher_utils import ( + get_source_partitions as _get_source_partitions, +) + +from torchao.quantization.pt2e.quantizer.quantizer import ( + FixedQParamsQuantizationSpec as _FixedQParamsQuantizationSpec, + Q_ANNOTATION_KEY, + QuantizationAnnotation as _QuantizationAnnotation, + QuantizationSpec as _TorchQuantizationSpec, + QuantizationSpecBase as _TorchQuantizationSpecBase, + SharedQuantizationSpec as _SharedQuantizationSpec, +) + +from torchao.quantization.pt2e.quantizer.utils import get_module_name_filter + +from torchao.quantization.pt2e.utils import _get_aten_graph_module_for_pattern + + +def _is_annotated(nodes: list[_Node]): + """ + Given a list of nodes (that represents an operator pattern), + check if any of the node is annotated, return True if any of the node + is annotated, otherwise return False + """ + annotated = False + for node in nodes: + annotated = annotated or ( + Q_ANNOTATION_KEY in node.meta and node.meta[Q_ANNOTATION_KEY]._annotated + ) + return annotated + + +def _mark_nodes_as_annotated(nodes: list[_Node]): + for node in nodes: + if node is not None: + if Q_ANNOTATION_KEY not in node.meta: + node.meta[Q_ANNOTATION_KEY] = _QuantizationAnnotation() + node.meta[Q_ANNOTATION_KEY]._annotated = True + + +# All activations recognized for conv-act/conv-bn-act patterns +_supported_activations = ( + _F.relu, + _F.relu6, + _F.leaky_relu, + _F.silu, + _F.elu, + _F.celu, + _F.selu, + _F.mish, + _F.hardtanh, + _F.hardswish, + _F.hardsigmoid, +) + + +# These activation functions don't have an inplace argument +_supported_activations_no_inplace = (_F.gelu, _F.sigmoid, _F.logsigmoid, _F.tanh) + + +# Map of dimension to convolution function +_conv_fn_map = {1: _F.conv1d, 2: _F.conv2d, 3: _F.conv3d} + + +def _get_aten_graph_module( + pattern: _torch.nn.Module, + example_inputs: _Tuple[_torch.Tensor], + is_cuda: bool = False, +): + return _get_aten_graph_module_for_pattern(pattern, example_inputs, is_cuda) + + +def _adjust_activation_qspec( + node: _torch.fx.Node, qspec: _Optional[_TorchQuantizationSpecBase] +) -> _Optional[_TorchQuantizationSpecBase]: + """ + Adjust quantization spec for ops which can use fixed qparams + or ops for which we can use affine quantization mode during + symmetric quantization because their output is always positive. 
+ """ + if qspec is None: + return qspec + + tanh_qspec = _FixedQParamsQuantizationSpec( + dtype=_torch.uint8, + scale=2.0 / 256.0, + zero_point=128, + quant_min=0, + quant_max=255, + qscheme=_torch.per_tensor_symmetric, + ) + + sigmoid_qspec = _FixedQParamsQuantizationSpec( + dtype=_torch.uint8, + scale=1.0 / 256.0, + zero_point=0, + quant_min=0, + quant_max=255, + qscheme=_torch.per_tensor_affine, + ) + + fixed_q_params_ops = { + _torch.ops.aten.tanh.default: tanh_qspec, + _torch.ops.aten.tanh_.default: tanh_qspec, + _torch.ops.aten.sigmoid.default: sigmoid_qspec, + _torch.ops.aten.sigmoid_.default: sigmoid_qspec, + _torch.ops.aten.hardsigmoid.default: sigmoid_qspec, + _torch.ops.aten.hardsigmoid_.default: sigmoid_qspec, + } + + always_affine_ops = ( + _torch.ops.aten.relu.default, + _torch.ops.aten.relu_.default, + _torch.ops.aten.relu6.default, + _torch.ops.aten.relu6_.default, + ) + + # ReLU6 activation maps to _torch.ops.aten.hardtanh.default with + # min_val = 0 and max_val = 6 + is_always_affine_op = node.target in always_affine_ops or ( + node.target + in [_torch.ops.aten.hardtanh.default, _torch.ops.aten.hardtanh_.default] + and node.args[1] == 0 # min_val, corresponding to ReLU6 + and node.args[2] == 6 # max_val, corresponding to ReLU6 + ) + + if node.target in fixed_q_params_ops: + return _TorchQuantizationSpec( + observer_or_fake_quant_ctr=qspec.observer_or_fake_quant_ctr, + dtype=qspec.dtype, + qscheme=fixed_q_params_ops[node.target].qscheme, + ) + # FIXME: Because of a bug in PyTorch in function _create_obs_or_fq_from_qspec + # in module torch/ao/quantization/fx/prepare.py which creates a + # FixedQParamsFakeQuantize partial, instead of an instance, we cannot + # actually create FixedQParamsQuantizationSpec + if is_always_affine_op: + return _TorchQuantizationSpec( + observer_or_fake_quant_ctr=qspec.observer_or_fake_quant_ctr, + dtype=qspec.dtype, + qscheme=_torch.per_tensor_affine, + ) + return qspec + + +def get_object_type_filter(tp: _Callable): + """ + Returns a filter which returns True if a node in the final exported graph + was created because of an object of type ``tp`` + """ + + def object_type_filter(n: _Node) -> bool: + # example: { + # 'add_10': ('add', ) + # } + nn_module_stack = n.meta.get("nn_module_stack", {}) + types = [t for _, t in nn_module_stack.values()] + source_fn_stack = n.meta.get("source_fn_stack", {}) + types.extend([t for _, t in source_fn_stack]) + return tp in types + + return object_type_filter + + +def get_not_object_type_or_name_filter( + tp_list: _List[_Callable], module_name_list: _List[str] +) -> _Callable[[_Node], bool]: + """ + Returns a filter which returns True if a node in the final exported graph + was not created using any modules with names in ``module_name_list`` or objects with + type in ``tp_list``. 
+ """ + object_type_filters = [get_object_type_filter(tp) for tp in tp_list] + module_name_list_filters = [get_module_name_filter(m) for m in module_name_list] + + def not_object_type_or_name_filter(n: _Node) -> bool: + return not any(f(n) for f in object_type_filters + module_name_list_filters) + + return not_object_type_or_name_filter + + +def _get_weighted_mod_pattern( + mod_fn: _Callable, + example_inputs: _Tuple[_torch.Tensor, ...], + act_fn: _Optional[_Callable] = None, + act_in_place: bool = False, +) -> _torch.nn.Module: + """ + Returns an aten graph corresponding to a sequence of these ops: + input -> weighted_mod -> activation -> output + + A weighted mod is a module which has a weight and bias, such as a + convolution module or a linear module. + + No activation is used if ``act_fn`` is ``None``. + ``act_fn`` is an activation function from _supported_activations or + _supported_activations_no_inplace + """ + + class Pattern(_torch.nn.Module): + def forward(self, input, weight, bias): + mod_out = mod_fn(input, weight, bias) + output = mod_out + node_dict = { + "input": input, + "mod": mod_out, + "weight": weight, + "bias": bias, + } + if act_fn is not None: + # Only add output if activation function is applied to model output + output = ( + act_fn(output, inplace=True) if act_in_place else act_fn(output) + ) + node_dict["output"] = output + return output, node_dict + + return _get_aten_graph_module(Pattern(), example_inputs) + + +def _get_weighted_mod_bn_pattern( + mod_fn: _Callable, + example_inputs: _Tuple[_torch.Tensor, ...], + act_fn: _Optional[_Callable] = None, + act_in_place: bool = False, +) -> _torch.nn.Module: + """ + Returns an aten graph corresponding to a sequence of these ops: + input -> weighted_mod -> batch_norm -> activation -> output + + A weighted mod is a module which has a weight and bias, such as a + convolution module or a linear module. + + No activation is used if ``act_fn`` is ``None``. + ``act_fn`` is an activation function from _supported_activations or + _supported_activations_no_inplace + """ + + class Pattern(_torch.nn.Module): + def forward( + self, input, weight, bias, bn_weight, bn_bias, bn_run_mean, bn_run_var + ): + mod_out = mod_fn(input, weight, bias) + output = _F.batch_norm( + mod_out, bn_run_mean, bn_run_var, bn_weight, bn_bias, training=True + ) + if act_fn is not None: + output = ( + act_fn(output, inplace=True) if act_in_place else act_fn(output) + ) + return output, { + "input": input, + "mod": mod_out, + "weight": weight, + "bias": bias, + "output": output, + } + + return _get_aten_graph_module(Pattern(), example_inputs) + + +def get_binary_op_act_pattern( + binary_op: _Callable, + act_fn: _Optional[_Callable] = None, + act_in_place: bool = False, +) -> _torch.nn.Module: + """ + Returns an aten graph corresponding to a sequence of these ops: + input_1 --- + \ + --> binary_op -> activation -> output + / + input_2 --- + + A binary op is any operation which consumes two inputs to create one output, + such as addition or multiplication. + + No activation is used if ``act_fn`` is ``None``. 
+ ``act_fn`` is an activation function from _supported_activations or + _supported_activations_no_inplace + """ + + class Pattern(_torch.nn.Module): + def forward(self, input_1, input_2): + binary_op_out = binary_op(input_1, input_2) + node_dict = { + "binary_op": binary_op_out, + } + output = binary_op_out + if act_fn is not None: + output = ( + act_fn(output, inplace=True) if act_in_place else act_fn(output) + ) + node_dict["output"] = output + return output, node_dict + + example_inputs = (_torch.randn(1), _torch.randn(1)) + return _get_aten_graph_module(Pattern(), example_inputs) + + +def get_conv_pattern( + conv_dim: int, act_fn: _Optional[_Callable] = None, act_in_place: bool = False +) -> _torch.nn.Module: + """ + Returns an aten graph corresponding to a sequence of these ops: + input -> conv -> activation -> output + + No activation is used if ``act_fn`` is ``None``. + ``act_fn`` is an activation function from _supported_activations or + _supported_activations_no_inplace + """ + assert ( + conv_dim in _conv_fn_map + ), f"Dimension {conv_dim} is not supported for Convolution layers." + + example_inputs = ( + _torch.randn(1, 1, *[3] * conv_dim), # input + _torch.randn(1, 1, *[1] * conv_dim), # conv weight + _torch.randn(1), # conv bias + ) + return _get_weighted_mod_pattern( + _conv_fn_map[conv_dim], example_inputs, act_fn, act_in_place + ) + + +def get_conv_bn_pattern( + conv_dim: int, act_fn: _Optional[_Callable] = None, act_in_place: bool = False +) -> _torch.nn.Module: + """ + Returns an aten graph corresponding to a sequence of these ops: + input -> conv -> batch_norm -> activation -> output + + No activation is used if ``act_fn`` is ``None``. + ``act_fn`` is an activation function from _supported_activations or + _supported_activations_no_inplace + """ + assert ( + conv_dim in _conv_fn_map + ), f"Dimension {conv_dim} is not supported for Convolution layers." + + example_inputs = ( + _torch.randn(1, 1, *[3] * conv_dim), # input + _torch.randn(1, 1, *[1] * conv_dim), # conv weight + _torch.randn(1), # conv bias + _torch.randn(1), # bn_weight + _torch.randn(1), # bn_bias + _torch.randn(1), # bn_run_mean + _torch.randn(1), # bn_run_var + ) + return _get_weighted_mod_bn_pattern( + _conv_fn_map[conv_dim], example_inputs, act_fn, act_in_place + ) + + +def get_linear_pattern( + act_fn: _Optional[_Callable] = None, act_in_place: bool = False +) -> _torch.nn.Module: + """ + Returns an aten graph corresponding to a sequence of these ops: + input -> linear -> activation -> output + + No activation is used if ``act_fn`` is ``None``. + ``act_fn`` is an activation function from _supported_activations or + _supported_activations_no_inplace + """ + example_inputs = ( + _torch.randn(1, 1), # input + _torch.randn(3, 1), # linear weight + _torch.randn(3), # linear bias + ) + return _get_weighted_mod_pattern(_F.linear, example_inputs, act_fn, act_in_place) + + +def get_linear_bn_pattern( + act_fn: _Optional[_Callable] = None, act_in_place: bool = False +) -> _torch.nn.Module: + """ + Returns an aten graph corresponding to a sequence of these ops: + input -> linear -> batch_norm -> activation -> output + + No activation is used if ``act_fn`` is ``None``. 
+ ``act_fn`` is an activation function from _supported_activations or + _supported_activations_no_inplace + """ + example_inputs = ( + _torch.randn(2, 1), # input + _torch.randn(3, 1), # linear weight + _torch.randn(3), # linear bias + _torch.randn(3), # bn_weight + _torch.randn(3), # bn_bias + _torch.randn(3), # bn_run_mean + _torch.randn(3), # bn_run_var + ) + return _get_weighted_mod_bn_pattern(_F.linear, example_inputs, act_fn, act_in_place) + + +def annotate_weighted_mod_pattern( # noqa: C901 + model: _torch.fx.GraphModule, + pattern_gm: _torch.fx.GraphModule, + quantization_config: _Optional[_AnnotationConfig], + filter_fn: _Optional[_Callable[[_Node], bool]], +) -> _Optional[_List[_List[_Node]]]: + """ + Annotates all nodes in ``model``, which match the pattern specified by ``pattern_gm`` + using ``quantization_config``. + + ``pattern_gm`` captures patterns of the following type: + + input -> weighted_mod -> batch_norm -> activation -> output + + batch_norm and activation may or may not be applied in the pattern. + + Only annotates those patterns in which all nodes return True when ``filter_fn`` is applied + to them. + """ + model.graph.eliminate_dead_code() + model.recompile() + + matcher = _SubgraphMatcherWithNameNodeMap(pattern_gm, ignore_literals=True) + matches = matcher.match(model.graph) + + annotated_partitions = [] + for match in matches: + name_node_map = match.name_node_map + input_node = name_node_map["input"] + mod_node = name_node_map["mod"] + weight_node = name_node_map["weight"] + bias_node = name_node_map["bias"] + if "output" in name_node_map: + # In this case, an activation is applied to the weighted module output + output_node = name_node_map["output"] + # If the output is same as mod_node, it means we have an inplace activation, + # so we need to correct the mod_node + if mod_node == output_node: + mod_node = mod_node.args[0] + else: + output_node = None + + # Validate mod args + if mod_node.args[0] is not input_node: + raise ValueError( + f"Weighted module arg did not contain input node {input_node}" + ) + if mod_node.args[1] is not weight_node: + raise ValueError( + f"Weighted module arg did not contain weight node {weight_node}" + ) + if len(mod_node.args) > 2 and mod_node.args[2] is not bias_node: + raise ValueError( + f"Weighted module arg did not contain bias node {bias_node}" + ) + + # Skip if the partition is already annotated or is filtered out by the user + partition = [mod_node, weight_node] + if bias_node is not None: + partition.append(bias_node) + if _is_annotated(partition): + continue + if filter_fn and any(not filter_fn(n) for n in partition): + continue + + # Annotate conv inputs and pattern output + input_qspec_map = {} + if not _is_annotated([input_node]): + input_qspec_map[input_node] = ( + quantization_config.input_activation if quantization_config else None + ) + else: + input_qspec_map[input_node] = input_node.meta[Q_ANNOTATION_KEY].output_qspec + + input_qspec_map[weight_node] = ( + quantization_config.weight if quantization_config else None + ) + output_qspec = ( + quantization_config.output_activation if quantization_config else None + ) + if bias_node is not None: + input_qspec_map[bias_node] = None + + if output_node is None: + mod_node.meta[Q_ANNOTATION_KEY] = _QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=output_qspec, + _annotated=True, + ) + else: + mod_node.meta[Q_ANNOTATION_KEY] = _QuantizationAnnotation( + input_qspec_map=input_qspec_map, + _annotated=True, + ) + if not _is_annotated([output_node]): 
+                output_qspec = _adjust_activation_qspec(
+                    node=output_node, qspec=output_qspec
+                )
+                output_node.meta[Q_ANNOTATION_KEY] = _QuantizationAnnotation(
+                    output_qspec=output_qspec,
+                    _annotated=True,
+                )
+
+        _mark_nodes_as_annotated(partition)
+        annotated_partitions.append(partition)
+    return annotated_partitions
+
+
+def annotate_binary_op_act_pattern(
+    model: _torch.fx.GraphModule,
+    pattern_gm: _torch.fx.GraphModule,
+    quantization_config: _Optional[_AnnotationConfig],
+    filter_fn: _Optional[_Callable[[_Node], bool]] = None,
+) -> _Optional[_List[_List[_Node]]]:
+    """
+    Annotates all nodes in ``model``, which match the pattern specified by ``pattern_gm``
+    using ``quantization_config``.
+
+    ``pattern_gm`` captures patterns of the following type:
+
+    input_1 ---
+               \
+                 --> binary_op -> activation -> output
+               /
+    input_2 ---
+
+    activation may or may not be applied in the pattern.
+
+    Only annotates those patterns in which all nodes return True when ``filter_fn`` is applied
+    to them.
+    """
+    model.graph.eliminate_dead_code()
+    model.recompile()
+
+    matcher = _SubgraphMatcherWithNameNodeMap(pattern_gm, ignore_literals=True)
+    matches = matcher.match(model.graph)
+
+    annotated_partitions = []
+    for match in matches:
+        name_node_map = match.name_node_map
+        binary_op_node: _Node = name_node_map["binary_op"]
+        if "output" in name_node_map:
+            # In this case, an activation is applied to the binary op output
+            output_node = name_node_map["output"]
+            # If the output is same as binary_op_node, it means we have an inplace activation,
+            # so we need to correct the binary_op_node
+            if binary_op_node == output_node:
+                binary_op_node = binary_op_node.args[0]
+            partition = [output_node, binary_op_node]
+        else:
+            output_node = None
+            partition = [binary_op_node]
+
+        if output_node is not None and len(binary_op_node.users) > 1:
+            raise ValueError("Binary op with activation has more than one user.")
+
+        if _is_annotated(partition):
+            continue
+
+        if filter_fn and any(not filter_fn(n) for n in partition):
+            continue
+
+        input_act_qspec = (
+            quantization_config.input_activation if quantization_config else None
+        )
+        output_act_qspec = (
+            quantization_config.output_activation if quantization_config else None
+        )
+
+        input_qspec_map = {}
+        input_act0 = binary_op_node.args[0]
+        if isinstance(input_act0, _Node):
+            input_qspec_map[input_act0] = input_act_qspec
+
+        input_act1 = binary_op_node.args[1]
+        if isinstance(input_act1, _Node):
+            input_qspec_map[input_act1] = input_act_qspec
+
+        if output_node is None:
+            binary_op_node.meta[Q_ANNOTATION_KEY] = _QuantizationAnnotation(
+                input_qspec_map=input_qspec_map,
+                output_qspec=output_act_qspec,
+                _annotated=True,
+            )
+        else:
+            binary_op_node.meta[Q_ANNOTATION_KEY] = _QuantizationAnnotation(
+                input_qspec_map=input_qspec_map,
+                _annotated=True,
+            )
+            output_act_qspec = _adjust_activation_qspec(
+                node=output_node, qspec=output_act_qspec
+            )
+            output_node.meta[Q_ANNOTATION_KEY] = _QuantizationAnnotation(
+                output_qspec=output_act_qspec,
+                _annotated=True,
+            )
+        _mark_nodes_as_annotated(partition)
+        annotated_partitions.append(partition)
+    return annotated_partitions
+
+
+def annotate_unary_shared_observer_ops(
+    model: _torch.fx.GraphModule,
+    ops: _List[_Callable],
+    quantization_config: _Optional[_AnnotationConfig],
+    filter_fn: _Optional[_Callable[[_Node], bool]] = None,
+) -> _Optional[_List[_List[_Node]]]:
+    """
+    Annotates all nodes in ``model``, which correspond to unary ops specified in ``ops``.
+ + input --> op --> output + + input and output nodes share the same quantization parameters. + """ + partitions = _get_source_partitions(model.graph, ops, filter_fn) + annotated_partitions = [] + for _, op_partitions in partitions.items(): + for partition in op_partitions: + output_node = partition.output_nodes[0] + op_node = partition.nodes[0] + if _is_annotated([output_node, op_node]): + continue + + input_node = op_node.args[0] + + input_act_qspec = ( + quantization_config.input_activation if quantization_config else None + ) + output_act_qspec = ( + quantization_config.output_activation if quantization_config else None + ) + + if ( + Q_ANNOTATION_KEY not in input_node.meta + or not input_node.meta[Q_ANNOTATION_KEY]._annotated + or input_node.meta[Q_ANNOTATION_KEY].output_qspec is None + or input_act_qspec is None + or output_act_qspec is None + ): + continue + + # input and output of op will share quantization parameter with input of op + act_qspec = _SharedQuantizationSpec(input_node) + op_node.meta[Q_ANNOTATION_KEY] = _QuantizationAnnotation( + input_qspec_map={ + input_node: act_qspec, + }, + _annotated=True, + ) + output_node.meta[Q_ANNOTATION_KEY] = _QuantizationAnnotation( + output_qspec=act_qspec, + _annotated=True, + ) + annotated_partitions.append(partition.nodes) + return annotated_partitions + + +def annotate_conv_bn_act_helper( + model: _torch.fx.GraphModule, + quantization_config: _Optional[_AnnotationConfig], + filter_fn: _Optional[_Callable[[_Node], bool]] = None, + use_bn: bool = False, +) -> _Optional[_List[_List[_Node]]]: + """ + A helper function for annotating all patterns involving convolution operations, i.e., + + input -> conv -> batch_norm -> activation -> output + + conv can be either 1D, 2D or 3D + batch_norm and activation may or may not be applied. + """ + annotated_partitions = [] + + pattern_map = { + True: get_conv_bn_pattern, + False: get_conv_pattern, + } + + conv_dims = [1, 2, 3] + combinations = _itertools.product(conv_dims, _supported_activations, [True, False]) + for conv_dim, act_fn, act_in_place in combinations: + pattern_gm = pattern_map[use_bn](conv_dim, act_fn, act_in_place) + annotated_partitions.extend( + annotate_weighted_mod_pattern( + model, pattern_gm, quantization_config, filter_fn + ) + ) + + combinations = _itertools.product(conv_dims, _supported_activations_no_inplace) + for conv_dim, act_fn in combinations: + pattern_gm = pattern_map[use_bn](conv_dim, act_fn, act_in_place=False) + annotated_partitions.extend( + annotate_weighted_mod_pattern( + model, pattern_gm, quantization_config, filter_fn + ) + ) + + return annotated_partitions + + +def annotate_linear_bn_act_helper( + model: _torch.fx.GraphModule, + quantization_config: _Optional[_AnnotationConfig], + filter_fn: _Optional[_Callable[[_Node], bool]] = None, + use_bn: bool = False, +) -> _Optional[_List[_List[_Node]]]: + """ + A helper function for annotating all patterns involving linear operations, i.e., + + input -> linear -> batch_norm -> activation -> output + + batch_norm and activation may or may not be applied. 
+ """ + annotated_partitions = [] + + pattern_map = { + True: get_linear_bn_pattern, + False: get_linear_pattern, + } + + combinations = _itertools.product(_supported_activations, [True, False]) + for act_fn, act_in_place in combinations: + pattern_gm = pattern_map[use_bn](act_fn, act_in_place) + annotated_partitions.extend( + annotate_weighted_mod_pattern( + model, pattern_gm, quantization_config, filter_fn + ) + ) + + for act_fn in _supported_activations_no_inplace: + pattern_gm = pattern_map[use_bn](act_fn, act_in_place=False) + annotated_partitions.extend( + annotate_weighted_mod_pattern( + model, pattern_gm, quantization_config, filter_fn + ) + ) + + return annotated_partitions + + +def annotate_binary_op_helper( + model: _torch.fx.GraphModule, + binary_ops: _List[_Callable], + quantization_config: _Optional[_AnnotationConfig], + filter_fn: _Optional[_Callable[[_Node], bool]] = None, +) -> _Optional[_List[_List[_Node]]]: + """ + A helper function for annotating all patterns involving binary operations, i.e., + using ``quantization_config``. + + input_1 --- + \ + --> binary_op -> activation -> output + / + input_2 --- + + activation may or may not be applied in the pattern. + """ + annotated_partitions = [] + + combinations = _itertools.product(binary_ops, _supported_activations, [True, False]) + for binary_op, act_fn, act_in_place in combinations: + pattern_gm = get_binary_op_act_pattern(binary_op, act_fn, act_in_place) + annotated_partitions.extend( + annotate_binary_op_act_pattern( + model, pattern_gm, quantization_config, filter_fn + ) + ) + + combinations = _itertools.product(binary_ops, _supported_activations_no_inplace) + for binary_op, act_fn in combinations: + pattern_gm = get_binary_op_act_pattern(binary_op, act_fn, act_in_place=False) + annotated_partitions.extend( + annotate_binary_op_act_pattern( + model, pattern_gm, quantization_config, filter_fn + ) + ) + + return annotated_partitions diff --git a/backends/apple/coreml/quantizer/coreml_quantizer.py b/backends/apple/coreml/quantizer/coreml_quantizer.py index ad596f7dfc9..a62c61967ef 100644 --- a/backends/apple/coreml/quantizer/coreml_quantizer.py +++ b/backends/apple/coreml/quantizer/coreml_quantizer.py @@ -2,6 +2,6 @@ # # Please refer to the license found in the LICENSE file in the root directory of the source tree. 
-from coremltools.optimize.torch.quantization._coreml_quantizer import ( # noqa: FLAKE8 F401 +from executorch.backends.apple.coreml.quantizer._coreml_quantizer import ( # noqa: FLAKE8 F401 CoreMLQuantizer, ) diff --git a/backends/apple/coreml/test/test_coreml_quantizer.py b/backends/apple/coreml/test/test_coreml_quantizer.py index eb8b9471345..97571b6c758 100644 --- a/backends/apple/coreml/test/test_coreml_quantizer.py +++ b/backends/apple/coreml/test/test_coreml_quantizer.py @@ -16,6 +16,8 @@ from executorch.backends.apple.coreml.quantizer import CoreMLQuantizer from torch.export import export + +from torchao.quantization.pt2e.fake_quantize import FakeQuantizeBase from torchao.quantization.pt2e.quantize_pt2e import ( convert_pt2e, prepare_pt2e, @@ -23,6 +25,29 @@ ) +def _get_quantization_config(): + return LinearQuantizerConfig.from_dict( + { + "global_config": { + "quantization_scheme": QuantizationScheme.symmetric, + "milestones": [0, 0, 10, 10], + "activation_dtype": torch.quint8, + "weight_dtype": torch.qint8, + "weight_per_channel": True, + } + } + ) + + +def _count_fake_quantize_modules(model: torch.nn.Module) -> int: + """Count the number of FakeQuantizeBase modules in a model.""" + count = 0 + for module in model.modules(): + if isinstance(module, FakeQuantizeBase): + count += 1 + return count + + class TestCoreMLQuantizer: @staticmethod def quantize_and_compare( @@ -34,17 +59,7 @@ def quantize_and_compare( pre_autograd_aten_dialect = export(model, example_inputs, strict=True).module() - quantization_config = LinearQuantizerConfig.from_dict( - { - "global_config": { - "quantization_scheme": QuantizationScheme.symmetric, - "milestones": [0, 0, 10, 10], - "activation_dtype": torch.quint8, - "weight_dtype": torch.qint8, - "weight_per_channel": True, - } - } - ) + quantization_config = _get_quantization_config() quantizer = CoreMLQuantizer(quantization_config) if quantization_type == "PTQ": @@ -54,13 +69,55 @@ def quantize_and_compare( else: raise ValueError("Invalid quantization type") + print("Prepared graph:", prepared_graph) prepared_graph(*example_inputs) converted_graph = convert_pt2e(prepared_graph) + print("Converted graph:", converted_graph) model_output = model(*example_inputs).detach().numpy() quantized_output = converted_graph(*example_inputs).detach().numpy() np.testing.assert_allclose(quantized_output, model_output, rtol=5e-2, atol=5e-2) + @pytest.mark.parametrize("quantization_type", ("PTQ", "QAT")) + def test_fake_quantize_modules_inserted_after_prepare(self, quantization_type): + """Test that FakeQuantizeBase modules are inserted after prepare step.""" + SHAPE = (1, 3, 256, 256) + + class Model(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv = torch.nn.Conv2d( + in_channels=3, out_channels=16, kernel_size=3, padding=1 + ) + self.relu = torch.nn.ReLU() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + a = self.conv(x) + return self.relu(a) + + model = Model() + example_inputs = (torch.randn(SHAPE),) + + pre_autograd_aten_dialect = export(model, example_inputs, strict=True).module() + + # Verify no FakeQuantize modules before prepare + assert _count_fake_quantize_modules(pre_autograd_aten_dialect) == 0 + + quantization_config = _get_quantization_config() + quantizer = CoreMLQuantizer(quantization_config) + + if quantization_type == "PTQ": + prepared_graph = prepare_pt2e(pre_autograd_aten_dialect, quantizer) + else: + prepared_graph = prepare_qat_pt2e(pre_autograd_aten_dialect, quantizer) + + # Verify FakeQuantize modules are 
present after prepare
+        fake_quant_count = _count_fake_quantize_modules(prepared_graph)
+        assert fake_quant_count > 0, (
+            f"Expected FakeQuantizeBase modules after the {quantization_type} prepare step, "
+            f"but found {fake_quant_count}"
+        )
+
     @pytest.mark.parametrize("quantization_type", ("PTQ", "QAT"))
     def test_conv_relu(self, quantization_type):
         SHAPE = (1, 3, 256, 256)
@@ -112,3 +169,4 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
     test_runner = TestCoreMLQuantizer()
     test_runner.test_conv_relu("PTQ")
     test_runner.test_linear("QAT")
+    test_runner.test_fake_quantize_modules_inserted_after_prepare("PTQ")
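For reference, a minimal post-training quantization sketch that exercises the quantizer introduced above, using the same public entry points and config dictionary as the tests in this diff; the model and calibration inputs are illustrative placeholders, not part of the change itself.

import torch
from torch.export import export

from coremltools.optimize.torch.quantization.quantization_config import (
    LinearQuantizerConfig,
    QuantizationScheme,
)
from executorch.backends.apple.coreml.quantizer import CoreMLQuantizer
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e

# Placeholder model and calibration data; substitute real ones.
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 16, kernel_size=3, padding=1),
    torch.nn.ReLU(),
)
example_inputs = (torch.randn(1, 3, 256, 256),)

# Same global config as the tests above.
config = LinearQuantizerConfig.from_dict(
    {
        "global_config": {
            "quantization_scheme": QuantizationScheme.symmetric,
            "milestones": [0, 0, 10, 10],
            "activation_dtype": torch.quint8,
            "weight_dtype": torch.qint8,
            "weight_per_channel": True,
        }
    }
)
quantizer = CoreMLQuantizer(config)

# Export, annotate and insert observers, calibrate, then fold observers into quantized ops.
exported = export(model, example_inputs, strict=True).module()
prepared = prepare_pt2e(exported, quantizer)
prepared(*example_inputs)  # calibration pass; run over a representative dataset in practice
quantized = convert_pt2e(prepared)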