diff --git a/backends/arm/quantizer/TARGETS b/backends/arm/quantizer/TARGETS index 28bfe15b528..52c3a1b1158 100644 --- a/backends/arm/quantizer/TARGETS +++ b/backends/arm/quantizer/TARGETS @@ -10,16 +10,17 @@ runtime.python_library( ], ) -# Exposed through __init__.py runtime.python_library( name = "arm_quantizer", srcs = ["arm_quantizer.py"], deps = [ ":arm_quantizer_utils", ":quantization_annotator", + ":quantizer_support", "//executorch/backends/arm:constants", "//executorch/backends/arm:ethosu", "//executorch/backends/arm:vgf", + "//executorch/backends/cortex_m/quantizer:quantizer", "//executorch/backends/arm/tosa:specification", "//executorch/backends/arm:arm_compile_spec", "//caffe2:torch", @@ -43,11 +44,24 @@ runtime.python_library( name = "arm_quantizer_utils", srcs = ["arm_quantizer_utils.py"], deps = [ + "//caffe2:torch", + "//executorch/backends/arm:common", + "//executorch/backends/arm:constants", ":quantization_config", "//pytorch/ao:torchao", ], ) +runtime.python_library( + name = "quantizer_support", + srcs = ["quantizer_support.py"], + deps = [ + ":quantization_annotator", + "//caffe2:torch", + "//executorch/backends/cortex_m/quantizer:quantizer", + ], +) + runtime.python_library( name = "lib", srcs = ["__init__.py"], diff --git a/backends/arm/quantizer/arm_quantizer.py b/backends/arm/quantizer/arm_quantizer.py index 4a7a0caaabc..e33555ced7c 100644 --- a/backends/arm/quantizer/arm_quantizer.py +++ b/backends/arm/quantizer/arm_quantizer.py @@ -9,18 +9,50 @@ # # Quantizer for Arm backend # - from __future__ import annotations import functools +import logging from typing import Any, Callable, Dict, List, Optional import torch +from executorch.backends.arm._passes import ArmPassManager +from executorch.backends.arm.common.annotation_meta import ArmAnnotationInfo from executorch.backends.arm.constants import DISALLOW_TFA_META_KEY from executorch.backends.arm.ethosu import EthosUCompileSpec - -from executorch.backends.arm.quantizer import 
QuantizationConfig +from executorch.backends.arm.quantizer.quantization_config import ( + QuantizationConfig, + TOSAQuantizationConfig, +) +from executorch.backends.arm.quantizer.quantizer_support import ( + TOSA_QUANTIZER_SUPPORT_DICT, +) from executorch.backends.arm.tosa import TosaSpecification +from executorch.backends.cortex_m.quantizer.node_finders import ( + GlobalNodeFinder, + InputNodeFinder, + ModuleNameNodeFinder, + ModuleTypeNodeFinder, + NodeNameNodeFinder, + NodeTargetNodeFinder, + OutputNodeFinder, +) +from executorch.backends.cortex_m.quantizer.pattern_matcher import PatternMatcher + +from executorch.backends.cortex_m.quantizer.quantizer_reporter import ( + QuantizerReporter, + SUPPORTED_QCONFIGS, + SUPPORTED_QSPECS, +) + +from torch._ops import OpOverload + +from torchao.quantization.pt2e.quantizer import ( + ComposableQuantizer, + QuantizationAnnotation, + Quantizer, +) +from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY from executorch.backends.arm.common.arm_compile_spec import ( ArmCompileSpec, ) # isort: skip @@ -28,8 +60,17 @@ get_cond_while_submodules_nested, is_submodule_node, ) -from executorch.backends.arm.vgf import VgfCompileSpec +from executorch.backends.arm.quantizer.arm_quantizer_utils import ( + _get_int32_bias_qspec, + _get_int32_per_channel_bias_qspec, + is_annotated, + mark_node_as_annotated, + NodeFinder, + PatternQuantizer, + SharedQspecQuantizer, +) +from executorch.backends.arm.vgf import VgfCompileSpec from torch.fx import GraphModule, Node from torchao.quantization.pt2e import ( FakeQuantize, @@ -52,12 +93,11 @@ annotate_output_qspec, get_module_name_filter, QuantizationSpec, - Quantizer, ) -from .arm_quantizer_utils import is_annotated, mark_node_as_annotated from .quantization_annotator import annotate_graph + __all__ = [ "TOSAQuantizer", "EthosUQuantizer", @@ -66,6 +106,8 @@ "get_symmetric_quantization_config", ] +logger = logging.getLogger(__name__) + @functools.lru_cache def 
get_symmetric_quantization_config( @@ -80,6 +122,9 @@ def get_symmetric_quantization_config( ) -> QuantizationConfig: """Create symmetric quantization config for activations and weights. + Activations use an affine qscheme; "symmetric" refers to the weight + quantization qscheme. + Args: is_per_channel (bool): Whether to use per-channel quantization for weights. @@ -166,16 +211,20 @@ def get_symmetric_quantization_config( observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr, ) - bias_quantization_spec = None + if is_per_channel: + bias_quantization_spec = _get_int32_per_channel_bias_qspec + else: + bias_quantization_spec = _get_int32_bias_qspec + if is_dynamic: - quantization_config = QuantizationConfig( + quantization_config = TOSAQuantizationConfig( act_quantization_spec, None, weight_quantization_spec, bias_quantization_spec, ) else: - quantization_config = QuantizationConfig( + quantization_config = TOSAQuantizationConfig( act_quantization_spec, act_quantization_spec, weight_quantization_spec, @@ -261,22 +310,58 @@ def get_symmetric_a16w8_quantization_config( ) # Replace activation quantization spec with 16-bit version if is_dynamic: - quantization_config = QuantizationConfig( + quantization_config = TOSAQuantizationConfig( act_quantization_spec, # 16-bit input activations None, base_config.weight, # 8-bit weights from base config - None, + base_config.bias, # bias from base config ) else: - quantization_config = QuantizationConfig( + quantization_config = TOSAQuantizationConfig( act_quantization_spec, # 16-bit input activations act_quantization_spec, # 16-bit output activations base_config.weight, # 8-bit weights from base config - None, + base_config.bias, # bias from base config ) return quantization_config +# Register supported quantization configs and qspecs in the reporter for human-readable reporting +# MLETORCH-1854: Temporary solution, refactor to automatically register these instead +_symmetric_a8w4_config_per_channel = 
get_symmetric_a8w4_quantization_config() +_symmetric_a8w8_config_per_channel = get_symmetric_quantization_config() +_symmetric_a16w8_config_per_channel = get_symmetric_a16w8_quantization_config() +_symmetric_a8w4_config_per_tensor = get_symmetric_a8w4_quantization_config( + is_per_channel=False +) +_symmetric_a8w8_config_per_tensor = get_symmetric_quantization_config( + is_per_channel=False +) +_symmetric_a16w8_config_per_tensor = get_symmetric_a16w8_quantization_config( + is_per_channel=False +) +SUPPORTED_QCONFIGS.update( + { + _symmetric_a8w8_config_per_channel: f"{__name__}.get_symmetric_quantization_config(is_per_channel=True)", + _symmetric_a16w8_config_per_channel: f"{__name__}.get_symmetric_a16w8_quantization_config(is_per_channel=True)", + _symmetric_a8w4_config_per_channel: f"{__name__}.get_symmetric_a8w4_quantization_config(is_per_channel=True)", + _symmetric_a8w8_config_per_tensor: f"{__name__}.get_symmetric_quantization_config(is_per_channel=False)", + _symmetric_a16w8_config_per_tensor: f"{__name__}.get_symmetric_a16w8_quantization_config(is_per_channel=False)", + _symmetric_a8w4_config_per_tensor: f"{__name__}.get_symmetric_a8w4_quantization_config(is_per_channel=False)", + } +) + +SUPPORTED_QSPECS.update( + { + _symmetric_a8w4_config_per_channel.get_weight_qspec(): "INT4_PER_CHANNEL_QSPEC", + _symmetric_a8w8_config_per_channel.get_weight_qspec(): "INT8_PER_CHANNEL_QSPEC", + _symmetric_a8w8_config_per_tensor.get_weight_qspec(): "INT8_PER_TENSOR_QSPEC", + _symmetric_a8w4_config_per_tensor.get_weight_qspec(): "INT4_PER_TENSOR_QSPEC", + _symmetric_a8w8_config_per_tensor.get_input_act_qspec(): "INT8_PER_TENSOR_QSPEC", + _symmetric_a16w8_config_per_tensor.get_input_act_qspec(): "INT16_PER_TENSOR_QSPEC", + } +) + NodeFilterType = Callable[[Node], bool] """Type for a Node Filter used by annotators. 
@@ -359,41 +444,115 @@ class TOSAQuantizer(Quantizer): """Manage quantization annotations for TOSA-compatible backends.""" def __init__( - self, compile_spec_or_tosa_spec: TosaSpecification | ArmCompileSpec + self, + compile_spec_or_tosa_spec, + use_composable_quantizer: bool = False, ) -> None: - super().__init__() - self.compile_spec: ArmCompileSpec - if isinstance(compile_spec_or_tosa_spec, TosaSpecification): - from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec + """Create a TOSA quantizer from a TOSA spec or Arm compile spec.""" + self.use_composable_quantizer = use_composable_quantizer + self.quantizer: _TOSAQuantizerV1 | _TOSAQuantizerV2 + if use_composable_quantizer: + logger.info( + "Using composable quantizer implementation in the arm backend. See https://github.com/pytorch/executorch/issues/17701" + ) + self.quantizer = _TOSAQuantizerV2(compile_spec_or_tosa_spec) + else: + logger.info( + "Using default quantizer in the arm backend. This quantizer is planned to be replaced by the composable quantizer implementation in the future, see https://github.com/pytorch/executorch/issues/17701" + ) + self.quantizer = _TOSAQuantizerV1(compile_spec_or_tosa_spec) - self.compile_spec = TosaCompileSpec(compile_spec_or_tosa_spec) - self.tosa_spec = self.compile_spec.tosa_spec - elif isinstance(compile_spec_or_tosa_spec, ArmCompileSpec): - self.compile_spec = compile_spec_or_tosa_spec - self.tosa_spec = self.compile_spec.tosa_spec + @property + def tosa_spec(self): + return self.quantizer.tosa_spec + + @property + def compile_spec(self): + return self.quantizer.compile_spec + + @property + def global_config(self): + return self.quantizer.global_config + + @global_config.setter + def global_config(self, value: Optional[QuantizationConfig]) -> None: + if isinstance(self.quantizer, _TOSAQuantizerV1): + self.quantizer.global_config = value else: - raise TypeError( - f"TOSAQuantizer constructor expects " - f"a TosaSpecification or compile_spec list, " - 
f"got {type(compile_spec_or_tosa_spec)}" + raise NotImplementedError( + "Composable quantizer does not allow setting global_config directly. Please use set_global() instead." ) - self.global_config: Optional[QuantizationConfig] = None - self.io_config: Optional[QuantizationConfig] = None - self.module_type_config: Dict[Callable, Optional[QuantizationConfig]] = {} - self.module_name_config: Dict[str, Optional[QuantizationConfig]] = {} + @property + def io_config(self): + if isinstance(self.quantizer, _TOSAQuantizerV1): + return self.quantizer.io_config + else: + raise NotImplementedError( + "Composable quantizer does not allow accessing io_config." + ) + + @io_config.setter + def io_config(self, value: Optional[QuantizationConfig]) -> None: + if isinstance(self.quantizer, _TOSAQuantizerV1): + self.quantizer.io_config = value + else: + raise NotImplementedError( + "Composable quantizer does not allow setting io_config directly. Please use set_io() instead." + ) + + @property + def module_type_config(self): + if isinstance(self.quantizer, _TOSAQuantizerV1): + return self.quantizer.module_type_config + else: + raise NotImplementedError( + "Composable quantizer does not allow accessing module_type_config." + ) + + @module_type_config.setter + def module_type_config( + self, value: Dict[Callable, Optional[QuantizationConfig]] + ) -> None: + if isinstance(self.quantizer, _TOSAQuantizerV1): + self.quantizer.module_type_config = value + else: + raise NotImplementedError( + "Composable quantizer does not allow setting module_type_config directly. Please use set_module_type() instead." + ) + + @property + def module_name_config(self): + if isinstance(self.quantizer, _TOSAQuantizerV1): + return getattr(self.quantizer, "module_name_config", {}) + else: + raise NotImplementedError( + "Composable quantizer does not allow accessing module_name_config." 
+ ) + + @module_name_config.setter + def module_name_config( + self, value: Dict[str, Optional[QuantizationConfig]] + ) -> None: + if isinstance(self.quantizer, _TOSAQuantizerV1): + self.quantizer.module_name_config = value + else: + raise NotImplementedError( + "Composable quantizer does not allow setting module_name_config directly. Please use set_module_name() instead." + ) def set_global( - self, quantization_config: QuantizationConfig | None + self, quantization_config: Optional[QuantizationConfig] ) -> TOSAQuantizer: """Set quantization_config for submodules not matched by other filters. Args: - quantization_config (QuantizationConfig): Configuration to apply to - modules that are not captured by name or type filters. + quantization_config (Optional[QuantizationConfig]): Configuration to + apply to modules that are not captured by name or type filters. + ``None`` indicates no quantization. """ - self.global_config = quantization_config + self.quantizer.set_global(quantization_config) return self def set_module_type( @@ -401,17 +560,18 @@ def set_module_type( ) -> TOSAQuantizer: """Set quantization_config for submodules with a given module type. - For example, calling set_module_type(Sub) quantizes supported patterns - in each Sub instance with the provided quantization_config. + For example, calling set_module_type(Softmax) quantizes supported + patterns in each Softmax instance with the provided quantization_config. Args: module_type (Callable): Type whose submodules should use the provided quantization configuration. - quantization_config (QuantizationConfig): Configuration to apply to - submodules of the given type. + quantization_config (Optional[QuantizationConfig]): Configuration to + apply to submodules of the given type. ``None`` indicates no + quantization. 
""" - self.module_type_config[module_type] = quantization_config + self.quantizer.set_module_type(module_type, quantization_config) return self def set_module_name( @@ -424,22 +584,266 @@ def set_module_name( Args: module_name (str): Fully qualified module name to configure. - quantization_config (QuantizationConfig): Configuration applied to - the named submodule. + quantization_config (Optional[QuantizationConfig]): Configuration + applied to the named submodule. ``None`` indicates no + quantization. """ - # Validate that quantization_config is provided - self.module_name_config[module_name] = quantization_config + self.quantizer.set_module_name(module_name, quantization_config) return self - def set_io(self, quantization_config: QuantizationConfig) -> TOSAQuantizer: + def set_io( + self, quantization_config: Optional[QuantizationConfig] + ) -> TOSAQuantizer: """Set quantization_config for input and output nodes. Args: - quantization_config (QuantizationConfig): Configuration describing - activation quantization for model inputs and outputs. + quantization_config (Optional[QuantizationConfig]): Configuration + describing activation quantization for model inputs and outputs. + ``None`` indicates no quantization. + + """ + self.quantizer.set_io(quantization_config) + return self + + def add_quantizer(self, quantizer: Quantizer) -> TOSAQuantizer: + """Insert a quantizer with highest precedence.""" + if self.use_composable_quantizer: + return self.quantizer.add_quantizer(quantizer) # type: ignore[union-attr,return-value] + raise NotImplementedError( + "add_quantizer is only supported in the composable quantizer implementation." + ) + + def set_node_finder( + self, quantization_config: Optional[QuantizationConfig], node_finder: NodeFinder + ) -> TOSAQuantizer: + """Set quantization_config for nodes matched by a custom NodeFinder. 
+ + Args: + quantization_config (Optional[QuantizationConfig]): Configuration + describing quantization settings for nodes matched by the provided + NodeFinder. ``None`` indicates no quantization. """ + if self.use_composable_quantizer: + return self.quantizer.set_node_finder(quantization_config, node_finder) # type: ignore[union-attr,return-value] + raise NotImplementedError( + "set_node_finder is only supported in the composable quantizer implementation." + ) + + def set_node_target( + self, node_target: OpOverload, quantization_config: Optional[QuantizationConfig] + ) -> TOSAQuantizer: + """Set quantization config for a specific operator target.""" + if self.use_composable_quantizer: + return self.quantizer.set_node_target(node_target, quantization_config) # type: ignore[union-attr,return-value] + raise NotImplementedError( + "set_node_target is only supported in the composable quantizer implementation." + ) + + def set_node_name( + self, node_name: str, quantization_config: Optional[QuantizationConfig] + ) -> TOSAQuantizer: + """Set quantization config for a specific node name.""" + if self.use_composable_quantizer: + return self.quantizer.set_node_name(node_name, quantization_config) # type: ignore[union-attr,return-value] + raise NotImplementedError( + "set_node_name is only supported in the composable quantizer implementation." + ) + + def transform_for_annotation(self, model: GraphModule) -> GraphModule: + """Transform the graph to prepare it for quantization annotation. + + Decomposes all operators where required to get correct quantization parameters. + + Args: + model (GraphModule): Model whose graph will be transformed. + + Returns: + GraphModule: Transformed model prepared for annotation. + + """ + return self.quantizer.transform_for_annotation(model) + + def annotate(self, model: GraphModule) -> GraphModule: + """Annotate the graph with the configured quantization settings. + + Currently only does static quantization annotation. 
+ + Args: + model (GraphModule): Model to annotate statically. + + Returns: + GraphModule: Annotated model ready for export. + + """ + return self.quantizer.annotate(model) + + def validate(self, model: GraphModule) -> None: + """Validate the quantization results. Currently, this includes: + - Ensure tensor inputs to each operator live on the same device. + + Args: + model (GraphModule): GraphModule being validated. + Raises: + ValueError: If tensor inputs for any operator span more than one + device. + """ + for node in model.graph.nodes: + if node.op != "call_function": + continue + + devices = set() + for arg_node in node.all_input_nodes: + meta_val = arg_node.meta.get("val", None) + if meta_val is None: + continue + if isinstance(meta_val, (tuple, list)): + for tensor in meta_val: + devices.add( + str( + getattr( + tensor, + "device", + f"Could not get device from {tensor}", + ) + ) + ) + else: + devices.add( + str( + getattr( + meta_val, + "device", + f"Could not get device from {meta_val}", + ) + ) + ) + + if len(devices) > 1: + raise ValueError( + f"Quantizer detected operator {node.name} with different device inputs: {devices}." + ) + + def quantize_with_submodules( + self, + model: GraphModule, + calibration_samples: list[tuple], + is_qat: bool = False, + fold_quantize: bool = True, + ): + """Quantizes a GraphModule in a way such that conditional submodules are + handled properly. + + Note: torchao's prepare_pt2e and convert_pt2e natively handle + while_loop body_fn submodules, so we only manually process cond + branches and while_loop cond_fn here. + + Args: + model (GraphModule): The model to quantize. + calibration_samples (list[tuple]): A list of inputs to used to + calibrate the model during quantization. To properly calibrate a + model with submodules, at least one sample per code path is + needed. + is_qat (bool): Whether to do quantization aware training or not. 
+ fold_quantize (bool): Enables or disables constant folding when quantization + is completed. + + Returns: + GraphModule: The quantized model. + + """ + prepare_fn = prepare_qat_pt2e if is_qat else prepare_pt2e + + prepared = prepare_fn(model, self) + # Prepare conditional submodules (e.g., if/while bodies) + # prepare only cond branches and while_loop cond_fn + for name, submodule, _ in get_cond_while_submodules_nested( + prepared, apply_quantization=True + ): + prepared.set_submodule(name, prepare_fn(submodule, self), strict=True) + for submodule_node in submodule.graph.nodes: + if is_submodule_node(submodule_node): + for nested_name, nested_sub, _ in get_cond_while_submodules_nested( + submodule, apply_quantization=True + ): + prepared.set_submodule( + nested_name, prepare_fn(nested_sub, self), strict=True + ) + + for inp in calibration_samples: + prepared(*inp) + + # Prepare conditional submodules (e.g., if/while bodies) + # convert only cond branches and while_loop cond_fn + for _, submodule, _ in get_cond_while_submodules_nested( + prepared, apply_quantization=True + ): + converted = convert_pt2e(submodule) + for submodule_node in submodule.graph.nodes: + if is_submodule_node(submodule_node): + for nested_name, nested_sub, _ in get_cond_while_submodules_nested( + submodule, apply_quantization=True + ): + converted.set_submodule( + nested_name, convert_pt2e(nested_sub), strict=True + ) + + return convert_pt2e(prepared) + + +class _TOSAQuantizerV1(Quantizer): + + def __init__( + self, compile_spec_or_tosa_spec: TosaSpecification | ArmCompileSpec + ) -> None: + super().__init__() + self.compile_spec: ArmCompileSpec + if isinstance(compile_spec_or_tosa_spec, TosaSpecification): + from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec + + self.compile_spec = TosaCompileSpec(compile_spec_or_tosa_spec) + self.tosa_spec = self.compile_spec.tosa_spec + elif isinstance(compile_spec_or_tosa_spec, ArmCompileSpec): + self.compile_spec = 
compile_spec_or_tosa_spec + self.tosa_spec = self.compile_spec.tosa_spec + else: + raise TypeError( + f"TOSAQuantizer constructor expects " + f"a TosaSpecification or compile_spec list, " + f"got {type(compile_spec_or_tosa_spec)}" + ) + + self.global_config: Optional[QuantizationConfig] = None + self.io_config: Optional[QuantizationConfig] = None + self.module_type_config: Dict[Callable, Optional[QuantizationConfig]] = {} + self.module_name_config: Dict[str, Optional[QuantizationConfig]] = {} + + def set_global( + self, quantization_config: Optional[QuantizationConfig] + ) -> _TOSAQuantizerV1: + + self.global_config = quantization_config + return self + + def set_module_type( + self, module_type: Callable, quantization_config: Optional[QuantizationConfig] + ) -> _TOSAQuantizerV1: + + self.module_type_config[module_type] = quantization_config + return self + + def set_module_name( + self, module_name: str, quantization_config: Optional[QuantizationConfig] + ) -> _TOSAQuantizerV1: + + # Validate that quantization_config is provided + self.module_name_config[module_name] = quantization_config + return self + + def set_io( + self, quantization_config: Optional[QuantizationConfig] + ) -> _TOSAQuantizerV1: self.io_config = quantization_config return self @@ -470,38 +874,12 @@ def _set_disallow_tfa_for_nodes(self, model: GraphModule) -> None: node.meta[DISALLOW_TFA_META_KEY] = config is None def transform_for_annotation(self, model: GraphModule) -> GraphModule: - """Transform the graph to prepare it for quantization annotation. - - Currently transforms scalar values to tensor attributes. - - Args: - model (GraphModule): Model whose graph will be transformed. - - Returns: - GraphModule: Transformed model prepared for annotation. - - """ - self._set_disallow_tfa_for_nodes(model) - # TODO: Fix the need to lazily import this. 
- from executorch.backends.arm._passes import ArmPassManager - pass_manager = ArmPassManager(self.compile_spec) return pass_manager.transform_for_annotation_pipeline(graph_module=model) def annotate(self, model: GraphModule) -> GraphModule: - """Annotate the graph with the configured quantization settings. - - Currently only does static quantization annotation. - - Args: - model (GraphModule): Model to annotate statically. - - Returns: - GraphModule: Annotated model ready for export. - - """ model = self._annotate_for_static_quantization_config(model) return model @@ -598,116 +976,192 @@ def _annotate_io( mark_node_as_annotated(node) def validate(self, model: GraphModule) -> None: - """Validate the quantization results. Currently, this includes: - - Ensure tensor inputs to each operator live on the same device. + # Validation is handled by TOSAQuantizer.validate; keep no-op for + # Quantizer interface compatibility. + return None + + +class _TOSAQuantizerV2(ComposableQuantizer): + + def __init__( + self, compile_spec_or_tosa_spec: TosaSpecification | ArmCompileSpec + ) -> None: + self.compile_spec: ArmCompileSpec + if isinstance(compile_spec_or_tosa_spec, TosaSpecification): + from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec + + self.compile_spec = TosaCompileSpec(compile_spec_or_tosa_spec) + self.tosa_spec = self.compile_spec.tosa_spec + elif isinstance(compile_spec_or_tosa_spec, ArmCompileSpec): + self.compile_spec = compile_spec_or_tosa_spec + self.tosa_spec = self.compile_spec.tosa_spec + else: + raise TypeError( + f"TOSAQuantizer constructor expects " + f"a TosaSpecification or compile_spec list, " + f"got {type(compile_spec_or_tosa_spec)}" + ) + + self.pattern_matcher = PatternMatcher(TOSA_QUANTIZER_SUPPORT_DICT) + self.shared_qspec_quantizer = SharedQspecQuantizer() + self.global_quantizer: Quantizer | None = None + self.global_config: Optional[QuantizationConfig] = None + self._quantizers: List[Quantizer] = [] + self._graph_annotations: 
dict[Node, QuantizationAnnotation] = {} + + @property + def quantizers(self) -> List[Quantizer]: + """Returns the configured quantizers in order of precedence, ensuring + the global config and shared_qspec_quantizer are applied last. + + The returned list is a shallow copy; quantizer instances are shared. - Args: - model (GraphModule): GraphModule being validated. - Raises: - ValueError: If tensor inputs for any operator span more than one - device. """ + quantizers = self._quantizers.copy() + if self.global_quantizer is not None: + quantizers.append(self.global_quantizer) + quantizers.append(self.shared_qspec_quantizer) + + return quantizers + + @quantizers.setter + def quantizers(self, value: List[Quantizer]) -> None: + """Override of quantizers setter to allow for dynamic updating of + quantizers without accessing self._quantizers. + """ + self._quantizers = value + + def annotate(self, model): + reporter = QuantizerReporter(self.quantizers, "FINAL QUANTIZATION REPORT") + model = super().annotate(model) + reporter.log_quantizer_report(model) + return model + + def _remove_annotations(self, model: GraphModule) -> GraphModule: for node in model.graph.nodes: - if node.op != "call_function": - continue + if Q_ANNOTATION_KEY in node.meta: + del node.meta[Q_ANNOTATION_KEY] + if ArmAnnotationInfo.CUSTOM_META_KEY in node.meta: + del node.meta[ArmAnnotationInfo.CUSTOM_META_KEY] + if DISALLOW_TFA_META_KEY in node.meta: + del node.meta[DISALLOW_TFA_META_KEY] + if PatternMatcher.Q_PATTERN_MATCHED_KEY in node.meta: + del node.meta[PatternMatcher.Q_PATTERN_MATCHED_KEY] + + # Clear quantizer internal annotation tracking + self._graph_annotations.clear() - devices = set() - for arg_node in node.all_input_nodes: - meta_val = arg_node.meta.get("val", None) - if meta_val is None: - continue - if isinstance(meta_val, (tuple, list)): - for tensor in meta_val: - devices.add( - str( - getattr( - tensor, - "device", - f"Could not get device from {tensor}", - ) - ) - ) - else: - 
devices.add( - str( - getattr( - meta_val, - "device", - f"Could not get device from {meta_val}", - ) - ) - ) + return model - if len(devices) > 1: - raise ValueError( - f"Quantizer detected operator {node.name} with different device inputs: {devices}." - ) + def transform_for_annotation(self, model: GraphModule) -> GraphModule: + # Transform_for_annotation should only decompose ops if quantized, which is + # indicated either by node.meta['DISALLOW_TFA_META_KEY']==False or no such key + # existing in the dict. This means that ops are assumed to be quantized by + # default and we need to explicitly annotate all non-quantized nodes with + # DISALLOW_TFA_META_KEY=True before calling the pass manager. + + # For _TOSAQuantizerV2 there is no simple filter which directly finds unquantized + # nodes since nodes can be annotated by any quantizer. Instead, self.annotate is + # run to set DISALLOW_TFA_META_KEY for quantized nodes and all nodes missing + # this key afterwards are set to DISALLOW_TFA_META_KEY=True. + + reporter = QuantizerReporter( + self.quantizers, "PRE-TRANSFORM_FOR_ANNOTATION QUANTIZATION REPORT" # type: ignore[arg-type] + ) + model = super().annotate(model) + reporter.log_quantizer_report(model) + for node in model.graph.nodes: + if DISALLOW_TFA_META_KEY not in node.meta: + node.meta[DISALLOW_TFA_META_KEY] = True - def quantize_with_submodules( - self, - model: GraphModule, - calibration_samples: list[tuple], - is_qat: bool = False, - fold_quantize: bool = True, - ): - """Quantizes a GraphModule in a way such that conditional submodules are - handled properly. + pass_manager = ArmPassManager(self.compile_spec) + transformed_model = pass_manager.transform_for_annotation_pipeline(model) - Note: torchao's prepare_pt2e and convert_pt2e natively handle - while_loop body_fn submodules, so we only manually process cond - branches and while_loop cond_fn here. 
+ # Remove the temporary annotations + return self._remove_annotations(transformed_model) - Args: - model (GraphModule): The model to quantize. - calibration_samples (list[tuple]): A list of inputs to used to - calibrate the model during quantization. To properly calibrate a - model with submodules, at least one sample per code path is - needed. - is_qat (bool): Whether to do quantization aware training or not. - fold_quantize (bool): Enables or disables constant folding when quantization - is completed. + def add_quantizer(self, quantizer: Quantizer) -> _TOSAQuantizerV2: + """Insert a quantizer with highest precedence.""" + self._quantizers.insert(0, quantizer) + return self - Returns: - GraphModule: The quantized model. + def set_node_finder( + self, quantization_config: Optional[QuantizationConfig], node_finder: NodeFinder + ) -> _TOSAQuantizerV2: + """Add a quantizer targeting nodes found by the provided finder. + + ``None`` indicates no quantization for matched nodes. """ - prepare_fn = prepare_qat_pt2e if is_qat else prepare_pt2e + quantizer = PatternQuantizer( + quantization_config, node_finder, self.pattern_matcher + ) + self.add_quantizer(quantizer) + return self - prepared = prepare_fn(model, self) - # Prepare conditional submodules (e.g., if/while bodies) - # prepare only cond branches and while_loop cond_fn - for name, submodule, _ in get_cond_while_submodules_nested( - prepared, apply_quantization=True - ): - prepared.set_submodule(name, prepare_fn(submodule, self), strict=True) - for submodule_node in submodule.graph.nodes: - if is_submodule_node(submodule_node): - for nested_name, nested_sub, _ in get_cond_while_submodules_nested( - submodule, apply_quantization=True - ): - prepared.set_submodule( - nested_name, prepare_fn(nested_sub, self), strict=True - ) + def set_global( + self, quantization_config: Optional[QuantizationConfig] + ) -> _TOSAQuantizerV2: + """Set the default quantization config for all nodes. 
- for inp in calibration_samples: - prepared(*inp) + ``None`` indicates no quantization. - # Prepare conditional submodules (e.g., if/while bodies) - # convert only cond branches and while_loop cond_fn - for _, submodule, _ in get_cond_while_submodules_nested( - prepared, apply_quantization=True - ): - converted = convert_pt2e(submodule) - for submodule_node in submodule.graph.nodes: - if is_submodule_node(submodule_node): - for nested_name, nested_sub, _ in get_cond_while_submodules_nested( - submodule, apply_quantization=True - ): - converted.set_submodule( - nested_name, convert_pt2e(nested_sub), strict=True - ) + """ + node_finder = GlobalNodeFinder() + self.global_quantizer = PatternQuantizer( + quantization_config, node_finder, self.pattern_matcher + ) + self.global_config = quantization_config + return self - return convert_pt2e(prepared) + def set_node_target( + self, node_target: OpOverload, quantization_config: Optional[QuantizationConfig] + ) -> _TOSAQuantizerV2: + """Set quantization config for a specific operator target.""" + node_finder = NodeTargetNodeFinder(node_target) + self.set_node_finder(quantization_config, node_finder) + return self + + def set_node_name( + self, node_name: str, quantization_config: Optional[QuantizationConfig] + ) -> _TOSAQuantizerV2: + """Set quantization config for a specific node name.""" + node_finder = NodeNameNodeFinder(node_name) + self.set_node_finder(quantization_config, node_finder) + return self + + def set_module_type( + self, module_type: Callable, quantization_config: Optional[QuantizationConfig] + ) -> _TOSAQuantizerV2: + """Set quantization config for nodes originating from a module type.""" + node_finder = ModuleTypeNodeFinder(module_type) + self.set_node_finder(quantization_config, node_finder) + return self + + def set_module_name( + self, module_name: str, quantization_config: Optional[QuantizationConfig] + ) -> _TOSAQuantizerV2: + """Set quantization config for nodes originating from a module name.""" + 
node_finder = ModuleNameNodeFinder(module_name) + self.set_node_finder(quantization_config, node_finder) + return self + + def set_io( + self, quantization_config: Optional[QuantizationConfig] + ) -> _TOSAQuantizerV2: + """Set quantization_config for input and output nodes. + + Args: + quantization_config (Optional[QuantizationConfig]): Configuration + describing activation quantization for model inputs and outputs. + ``None`` indicates no quantization. + + """ + input_finder = InputNodeFinder() + output_finder = OutputNodeFinder() + self.set_node_finder(quantization_config, input_finder) + self.set_node_finder(quantization_config, output_finder) + return self class EthosUQuantizer(TOSAQuantizer): @@ -716,11 +1170,16 @@ class EthosUQuantizer(TOSAQuantizer): Args: compile_spec (EthosUCompileSpec): Backend compile specification for Ethos-U targets. + use_composable_quantizer (bool): Whether to use the composable quantizer implementation. See https://github.com/pytorch/executorch/issues/17701" for details. """ - def __init__(self, compile_spec: EthosUCompileSpec) -> None: - super().__init__(compile_spec) + def __init__( + self, + compile_spec: EthosUCompileSpec, + use_composable_quantizer: bool = False, + ) -> None: + super().__init__(compile_spec, use_composable_quantizer) class VgfQuantizer(TOSAQuantizer): @@ -729,8 +1188,13 @@ class VgfQuantizer(TOSAQuantizer): Args: compile_spec (VgfCompileSpec): Backend compile specification for Vgf targets. + use_composable_quantizer (bool): Whether to use the composable quantizer implementation. See https://github.com/pytorch/executorch/issues/17701" for details. 
""" - def __init__(self, compile_spec: VgfCompileSpec) -> None: - super().__init__(compile_spec) + def __init__( + self, + compile_spec: VgfCompileSpec, + use_composable_quantizer: bool = False, + ) -> None: + super().__init__(compile_spec, use_composable_quantizer) diff --git a/backends/arm/quantizer/arm_quantizer_utils.py b/backends/arm/quantizer/arm_quantizer_utils.py index 68b855c5607..3deb9d00741 100644 --- a/backends/arm/quantizer/arm_quantizer_utils.py +++ b/backends/arm/quantizer/arm_quantizer_utils.py @@ -11,15 +11,32 @@ """ -from typing import cast +import logging +import operator +from abc import ABC, abstractmethod +from typing import Any, Callable, cast, Iterator, Optional, TYPE_CHECKING -from executorch.backends.arm.common.annotation_meta import ArmAnnotationInfo +import torch +from executorch.backends.arm.common.annotation_meta import ArmAnnotationInfo +from executorch.backends.arm.constants import DISALLOW_TFA_META_KEY +from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig from torch.fx import Node -from torchao.quantization.pt2e.quantizer import QuantizationAnnotation +from torchao.quantization.pt2e.quantizer import ( + DerivedQuantizationSpec, + QuantizationAnnotation, + QuantizationSpec, + Quantizer, + SharedQuantizationSpec, +) from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY +logger = logging.getLogger(__name__) + +if TYPE_CHECKING: + from executorch.backends.cortex_m.quantizer.pattern_matcher import PatternMatcher + def is_annotated(node: Node) -> bool: """Return True if the node is annotated. 
@@ -73,3 +90,556 @@ def mark_node_as_annotated(node: Node) -> None: meta_custom = node.meta.get("custom", {}) meta_custom[ArmAnnotationInfo.CUSTOM_META_KEY] = dict(annotation_info) node.meta["custom"] = meta_custom + + +def has_float_output(node: Node) -> bool: + meta_val = node.meta.get("val", None) + if isinstance(meta_val, torch.Tensor): + return meta_val.dtype.is_floating_point + return False + + +def _mark_node_as_quantized( + node: Node, + input_qspec_map, + output_qspec, + is_quantized, +) -> None: + """Fill metadata fields used for quantization, partitioning, and + lowering. + """ + node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation( + input_qspec_map, output_qspec, _annotated=True + ) + + if node.op == "call_function": + meta_custom = node.meta.get("custom", {}) + meta_custom[ArmAnnotationInfo.CUSTOM_META_KEY] = ArmAnnotationInfo( + quantized=is_quantized + ) + node.meta["custom"] = meta_custom + + node.meta[DISALLOW_TFA_META_KEY] = not is_quantized + + +def _derive_bias_qparams_fn( + obs_or_fqs, +) -> tuple[torch.Tensor, torch.Tensor]: + if len(obs_or_fqs) != 2: + raise ValueError( + f"Expecting two obs/fqs, one for activation and one for weight, got: {len(obs_or_fqs)}" + ) + act_obs_or_fq = obs_or_fqs[0] + weight_obs_or_fq = obs_or_fqs[1] + act_scale, _ = act_obs_or_fq.calculate_qparams() + weight_scale, _ = weight_obs_or_fq.calculate_qparams() + return act_scale * weight_scale, torch.full_like( + weight_scale, fill_value=0, dtype=torch.int32 + ) + + +def _get_int32_bias_qspec(node): + return DerivedQuantizationSpec( + derived_from=((node.args[0], node), (node.args[1], node)), # type: ignore[list-item] + derive_qparams_fn=_derive_bias_qparams_fn, + dtype=torch.int32, + quant_min=torch.iinfo(torch.int32).min, + quant_max=torch.iinfo(torch.int32).max - 1, + ) + + +def _get_int32_per_channel_bias_qspec(node): + return DerivedQuantizationSpec( + derived_from=((node.args[0], node), (node.args[1], node)), # type: ignore[list-item] + 
derive_qparams_fn=_derive_bias_qparams_fn, + dtype=torch.int32, + quant_min=torch.iinfo(torch.int32).min, + quant_max=torch.iinfo(torch.int32).max - 1, + qscheme=torch.per_channel_symmetric, + ch_axis=0, + ) + + +class _QuantizerReporterUserMixin: + def __init__(self): + self.reporter = None + + def register_reporter(self, reporter) -> None: + self.reporter = reporter + + def report_reject(self, pattern: list[Node], reason: str) -> None: + if self.reporter is not None: + self.reporter.report_reject(self, pattern, reason) + + def report_accept(self, pattern: list[Node]) -> None: + if self.reporter is not None: + self.reporter.report_accept(self, pattern) + + def get_quantizer_info(self): + raise NotImplementedError("Quantizer must implement get_quantizer_info method.") + + +class PatternCheck: + """Base class for pattern checks. + + PatternChecks are used to define which patterns are supported for + quantization and to validate quantization configuration constraints. + + """ + + @classmethod + def is_per_tensor(cls, qspec) -> bool: + from torchao.quantization.pt2e.quantizer import QuantizationSpecBase + + if not isinstance(qspec, QuantizationSpecBase): + return False + return qspec.qscheme in ( # type: ignore[attr-defined] + torch.per_tensor_affine, + torch.per_tensor_symmetric, + ) + + @classmethod + def is_per_channel(cls, qspec) -> bool: + from torchao.quantization.pt2e.quantizer import QuantizationSpecBase + + if not isinstance(qspec, QuantizationSpecBase): + return False + return qspec.qscheme in ( # type: ignore[attr-defined] + torch.per_channel_affine, + torch.per_channel_symmetric, + ) + + @classmethod + def is_int8_activations( + cls, qconfig: QuantizationConfig, output_node: Node | None = None + ) -> bool: + input_qspec = qconfig.get_input_act_qspec() + output_qspec = qconfig.get_output_act_qspec(output_node) + from torchao.quantization.pt2e.quantizer import QuantizationSpecBase + + if not isinstance(input_qspec, QuantizationSpecBase) or not isinstance( + 
output_qspec, QuantizationSpecBase + ): + return False + return ( + input_qspec.dtype == torch.int8 and output_qspec.dtype == torch.int8 # type: ignore[attr-defined] + ) + + @classmethod + def check_pattern(cls, pattern: list[Node]) -> bool: + return True + + @classmethod + def check_quantization_config( + cls, pattern: list[Node], quantization_config: QuantizationConfig + ) -> bool: + return True + + +class NodeFinder(ABC): + @abstractmethod + def find_nodes(self, model: torch.fx.GraphModule) -> Iterator[Node]: + """Return nodes of the graph module depending on NodeFinder type. + + Args: + model (GraphModule): The graph module to search for matching nodes. + + """ + pass + + +class PatternQuantizer(Quantizer, _QuantizerReporterUserMixin): + """Quantizes a graph according to an OperatorConfig. + + Args: + quantization_config (QuantizationConfig): The quantization config to use for annotation. + node_finder (NodeFinder): The node finder to use for finding nodes to match patterns. + pattern_matcher (PatternMatcher): The pattern matcher to use for finding patterns in the nodes. 
+ + """ + + def __init__( + self, + quantization_config: QuantizationConfig | None, + node_finder: "NodeFinder", + pattern_matcher: "PatternMatcher", + ) -> None: + super().__init__() + _QuantizerReporterUserMixin.__init__(self) + self.quantization_config: QuantizationConfig | None = quantization_config + self.node_finder: "NodeFinder" = node_finder + self.pattern_matcher: "PatternMatcher" = pattern_matcher + + def get_quantizer_info(self): + from executorch.backends.cortex_m.quantizer.quantizer_reporter import ( + QuantizerInfo, + SUPPORTED_QCONFIGS, + ) + + name = self.__class__.__name__ + targeted_nodes_description = str(self.node_finder) + quantization_config_path = SUPPORTED_QCONFIGS.get( + self.quantization_config, "UNREGISTERED_QCONFIG" + ) + support_config_path = self.pattern_matcher.support_dict_name + + return QuantizerInfo( + name, + targeted_nodes_description, + quantization_config_path, + support_config_path, + ) + + def is_parameter(self, node: Node, model: torch.fx.GraphModule) -> bool: + """Returns True if the given node is a parameter of the model.""" + try: + _ = model.get_parameter(node.target) # type: ignore[arg-type] + return True + except Exception: + return False + + def is_weight( + self, node: Node, params: list[Node], model: torch.fx.GraphModule + ) -> bool: + """Returns True if node is the first parameter of the given + parameters. + """ + return len(params) > 0 and node == params[0] + + def is_bias( + self, node: Node, params: list[Node], model: torch.fx.GraphModule + ) -> bool: + """Returns True if node is the second parameter of the given + parameters. + """ + return len(params) == 2 and node == params[1] + + def annotate_match( + self, + match: list[Node], + config: QuantizationConfig | None, + model: torch.fx.GraphModule, + ) -> None: + """Annotates a matched pattern according to the given quantization + config. 
+ """ + parameter_targets = { + torch.ops.aten.linear.default, + torch.ops.aten.convolution.default, + torch.ops.aten.conv1d.default, + torch.ops.aten.conv1d.padding, + torch.ops.aten.conv2d.default, + torch.ops.aten.conv2d.padding, + torch.ops.aten.conv3d.default, + torch.ops.aten.conv3d.padding, + torch.ops.aten.conv_transpose2d.input, + } + + for node in match: + input_qspec_map = {} + output_qspec = None + + params = [n for n in node.all_input_nodes if self.is_parameter(n, model)] + if node.target in parameter_targets: + if len(params) == 0 or len(params) > 2: + logger.warning( + f"{node.name} is expected to have parameter tensors for weight/bias but no such inputs found, which may cause unexpected quantization annotations. This is likely caused by incorrect tensor instantiations or non-constant weight/biases." + ) + else: + if len(params) > 0: + logger.warning( + f"{node.name} is not expected to not have parameter tensors but found {[n.name for n in params]}, which may cause unexpected quantization annotations." 
+ ) + + for input_node in node.all_input_nodes: + if not has_float_output(input_node): + continue + if self.is_weight(input_node, params, model): + input_qspec_map[input_node] = ( + config.get_weight_qspec(node) if config else None + ) + elif self.is_bias(input_node, params, model): + input_qspec_map[input_node] = ( + config.get_bias_qspec(node) if config else None # type: ignore[assignment] + ) + elif input_node not in match: + input_qspec_map[input_node] = ( + config.get_input_act_qspec(node, input_node) if config else None + ) + + if all(node not in match for node in node.users) and output_qspec is None: + if has_float_output(node): + output_qspec = config.get_output_act_qspec(node) if config else None + + _mark_node_as_quantized( + node, + input_qspec_map, + output_qspec, + config is not None, + ) + + def annotate(self, model: torch.fx.GraphModule) -> None: # type: ignore[override] + nodes = self.node_finder.find_nodes(model) + matches = self.pattern_matcher.find_pattern_matches( + nodes, self.quantization_config # type: ignore[arg-type] + ) + for result in matches: + if result.accepted: + self.annotate_match(result.pattern, self.quantization_config, model) + self.report_accept(result.pattern) + else: + self.report_reject( + result.pattern, + result.message or "Pattern rejected.", + ) + + def validate(self, model: torch.fx.GraphModule) -> bool: # type: ignore[override] + return True + + +class SharedQspecQuantizer(Quantizer, _QuantizerReporterUserMixin): + """Assures that specific ops share quantization parameters on all + inputs/outputs. 
+ """ + + SHARED_QSPEC_OPS_DEFAULT: list[Callable[..., object]] = [ + torch.ops.aten.clone.default, + torch.ops.aten.lift_fresh_copy.default, + torch.ops.aten.detach_.default, + torch.ops.aten.alias.default, + torch.ops.aten.alias_copy.default, + torch.ops.aten.copy_.default, + torch.ops.aten.detach_copy.default, + torch.ops.aten.unfold_copy.default, + torch.ops.aten.unbind.int, + torch.ops.aten.minimum.default, + torch.ops.aten.maximum.default, + torch.ops.aten.min.dim, + torch.ops.aten.max.dim, + torch.ops.aten.amin.default, + torch.ops.aten.amax.default, + torch.ops.aten.permute.default, + torch.ops.aten.permute_copy.default, + torch.ops.aten.transpose.int, + torch.ops.aten.transpose_copy.int, + torch.ops.aten.t_copy.default, + torch.ops.aten.t.default, + torch.ops.aten.repeat.default, + torch.ops.aten.repeat_interleave.self_int, + torch.ops.aten.expand_copy.default, + torch.ops.aten.expand.default, + torch.ops.aten.select.int, + torch.ops.aten.select_copy.int, + torch.ops.aten.slice.Tensor, + torch.ops.aten.slice_copy.Tensor, + torch.ops.aten.split.Tensor, + torch.ops.aten.split_with_sizes.default, + torch.ops.aten.split_copy.Tensor, + torch.ops.aten.tile.default, + torch.ops.aten.flip.default, + torch.ops.aten.index_select.default, + torch.ops.aten.index_put.default, + torch.ops.aten.contiguous.default, + torch.ops.aten.as_strided_copy.default, + torch.ops.aten.pixel_shuffle.default, + torch.ops.aten.pixel_unshuffle.default, + torch.ops.aten.cat.default, + torch.ops.aten.concatenate.default, + torch.ops.aten.stack.default, + torch.ops.aten.dropout.default, + torch.ops.aten.dropout_.default, + torch.ops.aten.chunk.default, + torch.ops.aten.index.Tensor, + torch.ops.aten.gather.default, + operator.getitem, + torch.ops.aten.squeeze.default, + torch.ops.aten.squeeze_copy.default, + torch.ops.aten.squeeze_copy.dim, + torch.ops.aten.squeeze.dim, + torch.ops.aten.squeeze.dims, + torch.ops.aten.squeeze_.dim, + torch.ops.aten.unsqueeze.default, + 
torch.ops.aten.unsqueeze_copy.default, + torch.ops.aten.reshape.default, + torch.ops.aten.view.default, + torch.ops.aten.view_as.default, + torch.ops.aten.view_copy.default, + torch.ops.aten._unsafe_view.default, + torch.ops.aten.unflatten.int, + torch.ops.aten.flatten.using_ints, + torch.ops.aten.pad.default, + torch.ops.aten.constant_pad_nd.default, + torch.ops.aten.clamp.default, + torch.ops.aten.clamp.Tensor, + torch.ops.aten.hardtanh.default, + torch.ops.aten.hardtanh_.default, + torch.ops.aten.relu.default, + torch.ops.aten.relu_.default, + torch.ops.aten.eq.Tensor, + torch.ops.aten.eq.Scalar, + torch.ops.aten.ne.Tensor, + torch.ops.aten.ne.Scalar, + torch.ops.aten.ge.Tensor, + torch.ops.aten.ge.Scalar, + torch.ops.aten.gt.Tensor, + torch.ops.aten.gt.Scalar, + torch.ops.aten.le.Tensor, + torch.ops.aten.le.Scalar, + torch.ops.aten.lt.Tensor, + torch.ops.aten.lt.Scalar, + torch.ops.aten.where.self, + torch.ops.aten.where.default, + torch.ops.higher_order.while_loop, + torch.ops.higher_order.cond, + ] + + def __init__(self, targets: Optional[list[Callable[..., object]]] = None) -> None: + super().__init__() + _QuantizerReporterUserMixin.__init__(self) + if targets is None: + self.targets = self.SHARED_QSPEC_OPS_DEFAULT + self.support_config_path = ( + __name__ + f".{self.__class__.__name__}.SHARED_QSPEC_OPS_DEFAULT" + ) + else: + self.targets = targets + self.support_config_path = ( + f"CUSTOM TARGETS: {', '.join([str(target) for target in targets])}" + ) + + def get_quantizer_info(self): + from executorch.backends.cortex_m.quantizer.quantizer_reporter import ( + QuantizerInfo, + ) + + name = self.__class__.__name__ + targeted_nodes_description = "" + quantization_config_path = "SHARED_QCONFIG" + support_config_path = self.support_config_path + return QuantizerInfo( + name, + targeted_nodes_description, + quantization_config_path, + support_config_path, + ) + + def _is_annotated(self, node: Node) -> bool: + return Q_ANNOTATION_KEY in node.meta + + def 
_get_input_nodes_with_float_output(self, node: Node) -> list[Node]: + return [n for n in node.all_input_nodes if has_float_output(n)] + + def _get_user_nodes_with_float_input(self, node: Node) -> list[Node]: + return [n for n in node.users.keys() if has_float_output(node)] + + def _get_shared_clique(self, root_node: Node) -> tuple[set[Node], list[Any]]: + shared_nodes = set() + bfs_queue = [root_node] + adjacent_qspecs = [] + + while bfs_queue: + node = bfs_queue.pop(0) + shared_nodes.add(node) + + for input_node in node.all_input_nodes: + if input_node.target in self.targets and input_node not in shared_nodes: + if not self._is_annotated(input_node): + bfs_queue.append(input_node) + if self._is_annotated(input_node): + output_qspec = input_node.meta.get( # type: ignore[union-attr] + Q_ANNOTATION_KEY + ).output_qspec + if output_qspec is not None: + adjacent_qspecs.append(output_qspec) + + for output_node in node.users.keys(): + if ( + output_node.target in self.targets + and output_node not in shared_nodes + ): + if not self._is_annotated(output_node): + bfs_queue.append(output_node) + if ( + self._is_annotated(output_node) + and node + in output_node.meta.get( # type: ignore[union-attr] + Q_ANNOTATION_KEY + ).input_qspec_map + ): + input_qspec = output_node.meta.get( # type: ignore[union-attr] + Q_ANNOTATION_KEY + ).input_qspec_map[node] + if input_qspec is not None: + adjacent_qspecs.append(input_qspec) + + return shared_nodes, adjacent_qspecs + + def _annotate_shared_cluster(self, root_node: Node) -> None: + if ( + len(self._get_input_nodes_with_float_output(root_node)) == 0 + and len(self._get_user_nodes_with_float_input(root_node)) == 0 + ): + self.report_reject( + [root_node], + "No float inputs nor outputs to annotate", + ) + _mark_node_as_quantized( + root_node, + {}, + None, + is_quantized=True, + ) + return + + shared_nodes, adjacent_qspecs = self._get_shared_clique(root_node) + node_order = {node: index for index, node in 
enumerate(root_node.graph.nodes)} + ordered_nodes = sorted(shared_nodes, key=lambda node: node_order.get(node, 0)) + + if len(adjacent_qspecs) > 0: + if len(adjacent_qspecs) > 1: + logger.warning( + f"Multiple adjacent quantization specs found for {', '.join([n.name for n in ordered_nodes])}, all nodes will share the input quantization spec of {root_node.name}." + ) + + root_node_float_inputs = self._get_input_nodes_with_float_output(root_node) + if len(root_node_float_inputs) == 0: + self.report_reject( + ordered_nodes, + "Couldn't find any floating point input to base shared quantization spec on.", + ) + return + root_node_first_input = root_node_float_inputs[0] + + shared_qspec = SharedQuantizationSpec((root_node_first_input, root_node)) + for node in shared_nodes: + input_qspec_map: dict[Node, Optional[QuantizationSpec]] = { + n: shared_qspec # type: ignore[misc] + for n in self._get_input_nodes_with_float_output(node) + } + if len(self._get_user_nodes_with_float_input(node)) == 0: + output_qspec = None + else: + output_qspec = shared_qspec + _mark_node_as_quantized( + node, input_qspec_map, output_qspec, is_quantized=True + ) + + root_node.meta[Q_ANNOTATION_KEY].input_qspec_map[root_node_first_input] = ( + adjacent_qspecs[0] + ) + self.report_accept(ordered_nodes) + + else: + self.report_reject( + ordered_nodes, + "Couldn't find any adjacent quantization spec to base shared quantization spec on. 
You may however quantize these nodes manually if required.", + ) + return + + def annotate(self, model: torch.fx.GraphModule) -> None: # type: ignore[override] + for node in model.graph.nodes: + if node.target in self.targets and not self._is_annotated(node): + self._annotate_shared_cluster(node) + + def validate(self, model: torch.fx.GraphModule) -> bool: # type: ignore[override] + return True diff --git a/backends/arm/quantizer/quantization_config.py b/backends/arm/quantizer/quantization_config.py index 7e201644262..e6c53ebf966 100644 --- a/backends/arm/quantizer/quantization_config.py +++ b/backends/arm/quantizer/quantization_config.py @@ -13,13 +13,17 @@ from dataclasses import dataclass +from typing import Any, Callable, cast, Optional import torch +from torch.fx import Node from torchao.quantization.pt2e import ObserverOrFakeQuantize from torchao.quantization.pt2e.quantizer import ( DerivedQuantizationSpec, QuantizationSpec, + QuantizationSpecBase, + SharedQuantizationSpec, ) @@ -31,27 +35,29 @@ class QuantizationConfig: expose validated accessors. Attributes: - input_activation (QuantizationSpec | None): Spec for input activations. - output_activation (QuantizationSpec | None): Spec for output activations. - weight (QuantizationSpec | None): Spec for weights. - bias (QuantizationSpec | None): Spec for bias values. + input_activation (Optional[QuantizationSpec]): Spec for input activations. + output_activation (Optional[QuantizationSpec]): Spec for output activations. + weight (Optional[QuantizationSpec]): Spec for weights. + bias (Optional[QuantizationSpec]): Spec for bias values. 
""" - input_activation: QuantizationSpec | None - output_activation: QuantizationSpec | None - weight: QuantizationSpec | None - bias: QuantizationSpec | None + input_activation: Optional[QuantizationSpecBase] + output_activation: Optional[QuantizationSpecBase] + weight: Optional[QuantizationSpecBase] + bias: Optional[QuantizationSpecBase] | Callable[[Any], Any] - def get_input_act_qspec(self) -> QuantizationSpec | None: + def get_input_act_qspec( + self, node: Optional[Node] = None, input_node: Optional[Node] = None + ) -> Optional[QuantizationSpecBase]: """Get the validated input activation spec. Validate that the input activation qscheme is supported before returning the spec. Returns: - QuantizationSpec | None: Input activation spec, or ``None`` when - unset. + Optional[QuantizationSpecBase]: Input activation spec, or ``None`` when + unset. The ``node`` and ``input_node`` arguments are used by subclasses. Raises: ValueError: If the qscheme is not per-tensor affine or symmetric. @@ -60,7 +66,9 @@ def get_input_act_qspec(self) -> QuantizationSpec | None: if self.input_activation is None: return None # Validate that input_activation uses a supported qscheme - if self.input_activation.qscheme not in [ + if not hasattr( + self.input_activation, "qscheme" + ) or self.input_activation.qscheme not in [ torch.per_tensor_affine, torch.per_tensor_symmetric, ]: @@ -69,15 +77,18 @@ def get_input_act_qspec(self) -> QuantizationSpec | None: ) return self.input_activation - def get_output_act_qspec(self) -> QuantizationSpec | None: + def get_output_act_qspec( + self, node: Optional[Node] = None + ) -> Optional[QuantizationSpecBase]: """Get the validated output activation spec. Validate that the output activation qscheme is supported before returning the spec. Returns: - QuantizationSpec | None: Output activation spec, or ``None`` when - unset. + Optional[QuantizationSpecBase]: Output activation spec, or ``None`` when + unset. 
The ``node`` argument is currently unused and kept for + API parity. Raises: ValueError: If the qscheme is not per-tensor affine or symmetric. @@ -86,7 +97,9 @@ def get_output_act_qspec(self) -> QuantizationSpec | None: if self.output_activation is None: return None # Validate that output_activation uses a supported qscheme - if self.output_activation.qscheme not in [ + if not hasattr( + self.output_activation, "qscheme" + ) or self.output_activation.qscheme not in [ torch.per_tensor_affine, torch.per_tensor_symmetric, ]: @@ -95,14 +108,16 @@ def get_output_act_qspec(self) -> QuantizationSpec | None: ) return self.output_activation - def get_weight_qspec(self) -> QuantizationSpec | None: + def get_weight_qspec( + self, node: Optional[Node] = None + ) -> Optional[QuantizationSpecBase]: """Get the validated weight spec. Validate that the weight qscheme is supported (per-tensor or per-channel symmetric) before returning the spec. Returns: - QuantizationSpec | None: Weight spec, or ``None`` when unset. + Optional[QuantizationSpecBase]: Weight spec, or ``None`` when unset. Raises: ValueError: If the qscheme is not a supported symmetric scheme. @@ -111,25 +126,27 @@ def get_weight_qspec(self) -> QuantizationSpec | None: if self.weight is None: return None # Validate that weight uses a supported qscheme - if self.weight.qscheme not in [ + if not hasattr(self.weight, "qscheme") or self.weight.qscheme not in [ torch.per_tensor_symmetric, torch.per_channel_symmetric, ]: raise ValueError(f"Unsupported quantization_spec {self.weight} for weight") return self.weight - def get_bias_qspec(self, node: torch.fx.Node) -> QuantizationSpec | None: + def get_bias_qspec( + self, node: Optional[Node] = None + ) -> Optional[QuantizationSpecBase] | Callable[[Any], Any]: """Get the derived or validated bias spec. For conv/linear ops, derive bias qparams from the input/weight observers. Otherwise, validate a user-provided floating-point bias spec. 
Args: - node (torch.fx.Node): Node whose bias spec is requested. + node (Optional[Node]): Node whose bias spec is requested. Returns: - QuantizationSpec | None: Derived or provided bias spec, or ``None`` - when unset. + Optional[QuantizationSpecBase]: Derived or provided bias spec, or + ``None`` when unset. Raises: ValueError: If deriving qparams sees an unexpected number of @@ -138,6 +155,9 @@ def get_bias_qspec(self, node: torch.fx.Node) -> QuantizationSpec | None: """ + if self.bias is None or node is None: + return None + def _derive_qparams_fn( obs_or_fqs: list[ObserverOrFakeQuantize], ) -> tuple[torch.Tensor, torch.Tensor]: @@ -185,6 +205,12 @@ def _derive_qparams_fn( raise ValueError( "Input activation and weight QuantizationConfig must be specified." ) + if not isinstance( + self.input_activation, QuantizationSpec + ) or not isinstance(self.weight, QuantizationSpec): + raise ValueError( + "QuantizationConfig input_activation and weight must be instances of QuantizationSpec." + ) if (self.input_activation.dtype == self.weight.dtype == torch.int8) or ( self.input_activation.dtype == torch.int16 @@ -211,7 +237,7 @@ def _derive_qparams_fn( ch_axis = 0 quantization_spec = DerivedQuantizationSpec( - derived_from=[(input_act, node), (weight, node)], # type: ignore[list-item] + derived_from=((input_act, node), (weight, node)), # type: ignore[arg-type] derive_qparams_fn=_derive_qparams_fn, dtype=torch.int32, quant_min=torch.iinfo(torch.int32).min + 1, @@ -225,11 +251,95 @@ def _derive_qparams_fn( f"Bias quantization of types: i:{self.input_activation.dtype}, w:{self.weight.dtype} not implemented" ) - if self.bias is None: - return None - # Validate that bias dtype is floating-point - if self.bias.dtype != torch.float: - raise ValueError( - "Only float dtype for bias is supported for bias right now" - ) return self.bias + + +class TOSAQuantizationConfig(QuantizationConfig): + """Configures quantization, while enforcing TOSA specific constraints.""" + + 
SHARED_OUTPUT_ACT_QSPEC_PATTERNS = { + torch.ops.aten.adaptive_avg_pool2d.default, + torch.ops.aten.upsample_bilinear2d.vec, + torch.ops.aten.upsample_nearest2d.vec, + torch.ops.aten.avg_pool2d.default, + torch.ops.aten.max_pool2d.default, + torch.ops.aten.mean.default, + torch.ops.aten.mean.dim, + torch.ops.aten.silu.default, + torch.ops.aten.silu_.default, + } + + SHARED_INPUT_ACT_QSPEC_PATTERNS = { + torch.ops.aten.lt.Tensor, + torch.ops.aten.le.Tensor, + torch.ops.aten.gt.Tensor, + torch.ops.aten.ge.Tensor, + torch.ops.aten.eq.Tensor, + torch.ops.aten.ne.Tensor, + } + + def get_input_act_qspec(self, node=None, input_node=None): + """Return the configured input quantization spec. + + For comparison operators, make sure that both inputs share the same + quantization spec, by returning a SharedQuantizationSpec that ties the + quantization of both inputs together. For other operators, return the + default input activation spec. + + """ + if node is None or input_node is None: + return super().get_input_act_qspec(node, input_node) + + if node.target in self.SHARED_INPUT_ACT_QSPEC_PATTERNS: + if input_node == node.args[0]: + return super().get_input_act_qspec(node, input_node) + else: + return SharedQuantizationSpec((node.args[0], node)) + + return super().get_input_act_qspec(node, input_node) + + def get_weight_qspec( + self, node: Optional[Node] = None + ) -> Optional[QuantizationSpecBase]: + """Return the configured weight quantization spec. + + For conv transpose, return the per-channel quantization spec with + `ch_axis=1` to match the IOHW weight format used by TOSA, instead of + the default `ch_axis=0`. If no weight spec is configured, return + ``None``. 
+ + """ + weight_qspec = super().get_weight_qspec() + if ( + node is not None + and weight_qspec is not None + and isinstance(weight_qspec, QuantizationSpec) + and weight_qspec.qscheme == torch.per_channel_symmetric + and node.target == torch.ops.aten.conv_transpose2d.input + ): + # MLETORCH-1853: Fix lazy import when moving files around + from executorch.backends.arm.quantizer.quantization_annotator import ( + _adjust_weight_qspec_for_conv_transpose, + ) + + weight_qspec = _adjust_weight_qspec_for_conv_transpose(node, weight_qspec) + + return weight_qspec + + def get_output_act_qspec( + self, node: Optional[Node] = None + ) -> Optional[QuantizationSpecBase]: + """Return the configured output activation quantization spec. + + If node is a pooling or upsample operator, returns a shared quantization spec. + If no weight spec is configured, return ``None``. + + """ + + if node is None: + return super().get_output_act_qspec() + if node.target not in self.SHARED_OUTPUT_ACT_QSPEC_PATTERNS: + return super().get_output_act_qspec() + if len(node.args) == 0: + return super().get_output_act_qspec() + return SharedQuantizationSpec((cast(Node, node.args[0]), node)) diff --git a/backends/arm/quantizer/quantizer_support.py b/backends/arm/quantizer/quantizer_support.py new file mode 100644 index 00000000000..bb3ea158fba --- /dev/null +++ b/backends/arm/quantizer/quantizer_support.py @@ -0,0 +1,179 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +from itertools import product + +import torch +from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor +from executorch.backends.arm.quantizer.arm_quantizer_utils import PatternCheck +from executorch.backends.arm.quantizer.quantization_annotator import ( + _conv_ops, + _one_to_one, +) +from torch._ops import OpOverload + + +def combo_pattern(*pattern_lists): + "Returns the cartesian product of the given pattern lists." + return [tuple(pattern) for pattern in product(*pattern_lists)] + + +class ReluFusedPatternCheck(PatternCheck): + @classmethod + def check_quantization_config(cls, pattern, quantization_config): + if quantization_config is None: + return True + + output_node = pattern[-1] if pattern else None + output_qspec = quantization_config.get_output_act_qspec(output_node) + if output_qspec is None: + return False + + return output_qspec.qscheme not in ( + torch.per_tensor_symmetric, + torch.per_channel_symmetric, + ) + + +class ArithmeticFloatInputsCheck(PatternCheck): + @classmethod + def check_pattern(cls, pattern): + """For arithmetic ops all inputs must be quantizeable for quantization + to make sense. 
+ """ + for node in pattern: + for input_node in node.all_input_nodes: + try: + tensor = get_first_fake_tensor(input_node) + except Exception: + return False + if not tensor.dtype.is_floating_point: + return False + + return True + + +BINARY_OP_PATTERNS = [ + (torch.ops.aten.add.Tensor,), + (torch.ops.aten.add_.Tensor,), + (torch.ops.aten.sub.Tensor,), + (torch.ops.aten.sub_.Tensor,), + (torch.ops.aten.matmul.default,), + (torch.ops.aten.mm.default,), + (torch.ops.aten.bmm.default,), + (torch.ops.aten.mul.Tensor,), + (torch.ops.aten.mul_.Tensor,), +] +ACTIVATION_FUNCTION_PATTERNS = [ + (torch.ops.aten.hardswish.default,), + (torch.ops.aten.hardswish_.default,), +] + +LINEAR_OPS = [torch.ops.aten.linear.default] +FUSED_ACTIVATION_OPS = [ + torch.ops.aten.relu.default, + torch.ops.aten.relu_.default, + torch.ops.aten.hardtanh.default, + torch.ops.aten.hardtanh_.default, + torch.ops.aten.hardsigmoid.default, + torch.ops.aten.hardsigmoid_.default, + torch.ops.aten.clamp.default, + torch.ops.aten.clamp_.default, +] +BATCH_NORM_OPS = [torch.ops.aten.batch_norm.default] +LINEAR_OP_PATTERNS = ( + combo_pattern(LINEAR_OPS) + + combo_pattern(LINEAR_OPS, FUSED_ACTIVATION_OPS) + + combo_pattern(LINEAR_OPS, BATCH_NORM_OPS) + + combo_pattern(LINEAR_OPS, BATCH_NORM_OPS, FUSED_ACTIVATION_OPS) +) +CONV_OP_PATTERNS = ( + combo_pattern(_conv_ops) + + combo_pattern(_conv_ops, FUSED_ACTIVATION_OPS) + + combo_pattern(_conv_ops, BATCH_NORM_OPS) + + combo_pattern(_conv_ops, BATCH_NORM_OPS, FUSED_ACTIVATION_OPS) +) +FUSED_RELU_OP_PATTERNS = ( + combo_pattern(LINEAR_OPS, FUSED_ACTIVATION_OPS) + + combo_pattern(LINEAR_OPS, BATCH_NORM_OPS, FUSED_ACTIVATION_OPS) + + combo_pattern(_conv_ops, FUSED_ACTIVATION_OPS) + + combo_pattern(_conv_ops, BATCH_NORM_OPS, FUSED_ACTIVATION_OPS) +) + +ALL_QPARAM_OP_PATTERNS = ( + [(target,) for target in _one_to_one] + + ACTIVATION_FUNCTION_PATTERNS + + CONV_OP_PATTERNS + + LINEAR_OP_PATTERNS + + BINARY_OP_PATTERNS + + [ + (torch.ops.aten.full.default,), + 
(torch.ops.aten.full,), + (torch.ops.aten.zeros.default,), + (torch.ops.aten.ones.default,), + (torch.ops.aten.fill_.Scalar,), + (torch.ops.aten.scalar_tensor.default,), + (torch.ops.aten.zeros_like.default,), + (torch.ops.aten._softmax.default,), + (torch.ops.aten.softmax.int,), + (torch.ops.aten.div.Tensor,), + (torch.ops.aten.div_.Tensor,), + (torch.ops.aten.div.Tensor_mode,), + (torch.ops.aten.floor,), + (torch.ops.aten.floor_divide.default,), + (torch.ops.aten.logit.default,), + (torch.ops.aten.glu.default,), + (torch.ops.aten.addmm.default,), + (torch.ops.aten.layer_norm.default,), + (torch.ops.aten.group_norm.default,), + (torch.ops.aten.sqrt.default,), + (torch.ops.aten.silu.default,), + (torch.ops.aten.silu_.default,), + (torch.ops.aten.var.dim,), + (torch.ops.aten.var.correction,), + (torch.ops.aten.leaky_relu.default,), + (torch.ops.aten.leaky_relu_.default,), + (torch.ops.aten.linalg_vector_norm.default,), + (torch.ops.aten.log_softmax.int,), + (torch.ops.aten.round.default,), + (torch.ops.aten.arange.start_step,), + (torch.ops.aten.embedding.default,), + (torch.ops.aten.adaptive_avg_pool2d.default,), + (torch.ops.aten.upsample_bilinear2d.vec,), + (torch.ops.aten.upsample_nearest2d.vec,), + (torch.ops.aten.avg_pool2d.default,), + (torch.ops.aten.max_pool2d.default,), + (torch.ops.aten.cosine_similarity.default,), + (torch.ops.aten.sigmoid.default,), + (torch.ops.aten.remainder.Tensor,), + (torch.ops.aten.remainder.Scalar,), + (torch.ops.aten.mean.dim,), + (torch.ops.aten.mean.default,), + (torch.ops.aten.neg.default,), + (torch.ops.aten.scaled_dot_product_attention.default,), + (torch.ops.aten.abs.default,), + (torch.ops.aten.minimum.default,), + (torch.ops.aten.maximum.default,), + (torch.ops.aten.lt.Tensor,), + (torch.ops.aten.le.Tensor,), + (torch.ops.aten.gt.Tensor,), + (torch.ops.aten.ge.Tensor,), + (torch.ops.aten.eq.Tensor,), + (torch.ops.aten.ne.Tensor,), + (torch.ops.aten.lt.Scalar,), + (torch.ops.aten.le.Scalar,), + 
(torch.ops.aten.gt.Scalar,), + (torch.ops.aten.ge.Scalar,), + (torch.ops.aten.eq.Scalar,), + (torch.ops.aten.ne.Scalar,), + ] +) +TOSA_QUANTIZER_SUPPORT_DICT: dict[tuple[OpOverload, ...], type[PatternCheck] | None] = { + pattern: None for pattern in ALL_QPARAM_OP_PATTERNS +} +for pattern in FUSED_RELU_OP_PATTERNS: + TOSA_QUANTIZER_SUPPORT_DICT[pattern] = ReluFusedPatternCheck +for pattern in BINARY_OP_PATTERNS: + TOSA_QUANTIZER_SUPPORT_DICT[pattern] = ArithmeticFloatInputsCheck diff --git a/backends/arm/test/models/test_mobilenet_v2_arm.py b/backends/arm/test/models/test_mobilenet_v2_arm.py index d5f6515d820..fb624f3eb3f 100644 --- a/backends/arm/test/models/test_mobilenet_v2_arm.py +++ b/backends/arm/test/models/test_mobilenet_v2_arm.py @@ -11,6 +11,7 @@ import torch from executorch.backends.arm.quantizer import get_symmetric_quantization_config +from executorch.backends.arm.quantizer.arm_quantizer import _TOSAQuantizerV2 from executorch.backends.arm.test import common from executorch.backends.arm.test.tester.test_pipeline import ( EthosU55PipelineINT, @@ -119,6 +120,29 @@ def test_mv2_u55_INT(per_channel_quantization): pipeline.run() +@pytest.mark.slow +@common.XfailIfNoCorstone300 +@common.parametrize("per_channel_quantization", quant_test_data) +def test_mv2_u55_INT_composable_quantizer(per_channel_quantization): + pipeline = EthosU55PipelineINT[input_t]( + mv2, + model_inputs, + aten_ops=[], + exir_ops=[], + use_to_edge_transform_and_lower=True, + per_channel_quantization=per_channel_quantization, + atol=0.25, + qtol=1, + ) + + # Create composable_quantizer and force the pipeline to use it instead of the default quantizer + composable_quantizer = _TOSAQuantizerV2(pipeline.tester.compile_spec) + qconfig = get_symmetric_quantization_config(is_per_channel=per_channel_quantization) + composable_quantizer.set_global(qconfig) + pipeline.quantizer.quantizer = composable_quantizer + pipeline.run() + + @pytest.mark.slow @common.XfailIfNoCorstone320 
@common.parametrize("per_channel_quantization", quant_test_data) diff --git a/backends/arm/test/quantizer/test_set_module_name.py b/backends/arm/test/quantizer/test_set_module_name.py index d0ca781256f..6ca7e3f970c 100644 --- a/backends/arm/test/quantizer/test_set_module_name.py +++ b/backends/arm/test/quantizer/test_set_module_name.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -113,13 +113,13 @@ def validate_node( if len(node.all_input_nodes) == 3: input_node, weight_node, bias_node = node.all_input_nodes bias_qspec = quantization_config.get_bias_qspec(node) - validate_input(bias_node, bias_qspec) + validate_input(bias_node, bias_qspec) # type: ignore[arg-type] else: input_node, weight_node = node.all_input_nodes - validate_input(input_node, input_qspec) - validate_input(weight_node, weight_qspec) - validate_output(node, output_qspec) + validate_input(input_node, input_qspec) # type: ignore[arg-type] + validate_input(weight_node, weight_qspec) # type: ignore[arg-type] + validate_output(node, output_qspec) # type: ignore[arg-type] def test_set_module_name_tosa_INT() -> None: diff --git a/backends/cortex_m/quantizer/TARGETS b/backends/cortex_m/quantizer/TARGETS index 0af105efef0..7a4ef5cd78a 100644 --- a/backends/cortex_m/quantizer/TARGETS +++ b/backends/cortex_m/quantizer/TARGETS @@ -12,17 +12,27 @@ python_library( name = "quantizer", srcs = [ "__init__.py", - "operator_configs.py", + "node_finders.py", + "pattern_checkers.py", + "pattern_matcher.py", "quantization_configs.py", "quantizer.py", + "quantizer_reporter.py", + "quantizer_support.py", ], deps = [ "//caffe2:torch", + "//executorch/backends/arm:common", + "//executorch/backends/arm:constants", + "//executorch/backends/arm/quantizer:arm_quantizer_utils", + 
"//executorch/backends/arm/quantizer:quantization_annotator", "//executorch/backends/arm/quantizer:quantization_config", "//pytorch/ao:torchao", + "fbsource//third-party/pypi/tabulate:tabulate", ], ) + python_library( name = "quantization_configs", srcs = [ @@ -30,7 +40,21 @@ python_library( ], deps = [ "//caffe2:torch", + "//executorch/backends/arm/quantizer:arm_quantizer_utils", "//executorch/backends/arm/quantizer:quantization_config", "//pytorch/ao:torchao", + ":quantizer_reporter", + ], +) + +python_library( + name = "quantizer_reporter", + srcs = [ + "quantizer_reporter.py", + ], + deps = [ + "//caffe2:torch", + "//pytorch/ao:torchao", + "fbsource//third-party/pypi/tabulate:tabulate", ], ) diff --git a/backends/cortex_m/quantizer/node_finders.py b/backends/cortex_m/quantizer/node_finders.py index 427468a0ad4..63b08da68d8 100644 --- a/backends/cortex_m/quantizer/node_finders.py +++ b/backends/cortex_m/quantizer/node_finders.py @@ -3,9 +3,9 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from abc import ABC, abstractmethod from typing import Callable, Iterator, List +from executorch.backends.arm.quantizer.arm_quantizer_utils import NodeFinder from torch._ops import OpOverload from torch.fx import GraphModule, Node from torchao.quantization.pt2e.quantizer.utils import get_module_name_filter @@ -23,17 +23,6 @@ def format_items(items) -> str: return ", ".join(str(item) for item in items) -class NodeFinder(ABC): - @abstractmethod - def find_nodes(self, model: GraphModule) -> Iterator[Node]: - """Return nodes of the graph module depending on NodeFinder type. - - Args: - model (GraphModule): The graph module to search for matching nodes. - """ - pass - - class GlobalNodeFinder(NodeFinder): """ Finds all nodes of the graph. 
diff --git a/backends/cortex_m/quantizer/pattern_checkers.py b/backends/cortex_m/quantizer/pattern_checkers.py index 0d4d5f3f101..04b46e159d1 100644 --- a/backends/cortex_m/quantizer/pattern_checkers.py +++ b/backends/cortex_m/quantizer/pattern_checkers.py @@ -7,6 +7,7 @@ import torch from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor +from executorch.backends.arm.quantizer.arm_quantizer_utils import PatternCheck from executorch.backends.cortex_m.passes.passes_utils import ( coerce_int_pair, is_channel_broadcast, @@ -18,73 +19,7 @@ CortexMQuantizationConfig, ) from torch.fx import Node -from torchao.quantization.pt2e.quantizer import ( - QuantizationSpecBase, - SharedQuantizationSpec, -) - - -class PatternCheck: - """ - Base class for pattern checks. - - PatternChecks are used to define which which patterns are supported for quantization. - For example, ADD in the Cortex-M backend does not support general broadcasting, so - a PatternCheck can be used to filter out such patterns. They also only support per - tensor quantization, so the PatternCheck filters out quantization configs that use - per channel quantization. - """ - - @classmethod - def is_per_tensor(cls, qspec: QuantizationSpecBase | None) -> bool: - """ - Returns true if the given quantization spec is per-tensor, otherwise false. - """ - if not isinstance(qspec, QuantizationSpecBase): - return False - return qspec.qscheme in (torch.per_tensor_affine, torch.per_tensor_symmetric) - - @classmethod - def is_per_channel(cls, qspec: QuantizationSpecBase | None) -> bool: - """ - Returns true if the given quantization spec is per-channel, otherwise false. 
- """ - if not isinstance(qspec, QuantizationSpecBase): - return False - return qspec.qscheme in (torch.per_channel_affine, torch.per_channel_symmetric) - - @classmethod - def is_int8_activations( - cls, qconfig: CortexMQuantizationConfig, output_node: Node | None = None - ) -> bool: - """ - Returns true if the given quantization spec uses int8 quantization, otherwise false. - - Output node is required for determining output quantization spec for some ops, otherwise it can be left as None. - """ - input_qspec = qconfig.get_input_act_qspec() - output_qspec = qconfig.get_output_act_qspec(output_node) - if not isinstance(input_qspec, QuantizationSpecBase) or not isinstance( - output_qspec, QuantizationSpecBase - ): - return False - return input_qspec.dtype == torch.int8 and output_qspec.dtype == torch.int8 - - @classmethod - def check_pattern(cls, pattern: list[Node]) -> bool: - """ - Returns true if the given pattern is supported, otherwise false. - """ - return True - - @classmethod - def check_quantization_config( - cls, pattern: list[Node], quantization_config: CortexMQuantizationConfig - ) -> bool: - """ - Returns true if the given quantization config is supported for a given node pattern, otherwise false. 
- """ - return True +from torchao.quantization.pt2e.quantizer import SharedQuantizationSpec class CortexMAddMulCheck(PatternCheck): diff --git a/backends/cortex_m/quantizer/pattern_matcher.py b/backends/cortex_m/quantizer/pattern_matcher.py index 7ed848d5e83..1123636bad9 100644 --- a/backends/cortex_m/quantizer/pattern_matcher.py +++ b/backends/cortex_m/quantizer/pattern_matcher.py @@ -7,13 +7,10 @@ from dataclasses import dataclass from typing import Iterator, List, Optional -from executorch.backends.arm.quantizer.quantization_annotator import _is_large_scalar +from executorch.backends.arm.quantizer.arm_quantizer_utils import PatternCheck -from executorch.backends.cortex_m.quantizer.pattern_checkers import PatternCheck -from executorch.backends.cortex_m.quantizer.quantization_configs import ( - CortexMQuantizationConfig, - QuantizationConfig, -) +from executorch.backends.arm.quantizer.quantization_annotator import _is_large_scalar +from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig from torch._ops import OpOverload from torch.fx import Node @@ -26,13 +23,13 @@ class PatternMatchResult: class PatternMatcher: - """ - Find supported patterns in a sequence of nodes. + """Find supported patterns in a sequence of nodes. Attributes: support_dict: A dictionary mapping patterns (tuples of operator overloads) to PatternCheck instances that validate the patterns. support_dict_name: An optional name for the support dict, used for logging. 
+ """ Q_PATTERN_MATCHED_KEY = "quantizer_matched" @@ -45,7 +42,7 @@ class PatternMatcher: def __init__( self, - support_dict: dict[tuple[OpOverload, ...], PatternCheck], + support_dict: dict[tuple[OpOverload, ...], Optional[type[PatternCheck]]], support_dict_name: str | None = None, ): self.support_dict = support_dict @@ -58,11 +55,13 @@ def __init__( def _validate_match( self, match: List[Node], - quantization_config: CortexMQuantizationConfig, + quantization_config: QuantizationConfig, ) -> Optional[PatternMatchResult]: - """ - Returns a PatternMatchResult when the pattern structurally matches, with - status indicating accept/reject. Returns None if there is no match. + """Returns a PatternMatchResult when the pattern structurally matches, + with status indicating accept/reject. + + Returns None if there is no match. + """ # Reject match if it contains a node that has already been matched as part of another pattern. @@ -98,8 +97,8 @@ def _validate_match( return PatternMatchResult(match, True) def _get_match(self, node_queue: List[Node]) -> List[Node]: - """ - Returns the longest pattern match starting at the front of the queue. + """Returns the longest pattern match starting at the front of the + queue. """ if node_queue[0].op in ("placeholder", "output"): return [node_queue[0]] @@ -116,8 +115,8 @@ def _get_match(self, node_queue: List[Node]) -> List[Node]: def _get_matches( self, node_queue: List[Node], quantization_config: QuantizationConfig ) -> List[PatternMatchResult]: - """ - Returns the longest accepted match starting at the first node of the queue as well as longer rejected matches. + """Returns the longest accepted match starting at the first node of the + queue as well as longer rejected matches. 
""" matches = [] accepted = False @@ -139,8 +138,13 @@ def _get_matches( def _dequeue_and_get_matches( self, node_queue: List[Node], quantization_config: QuantizationConfig ) -> List[PatternMatchResult]: - """ - Dequeues the longest accepted match starting at the first node of the queue, and returns all potential matches that were checked (rejected ones). If no match is found, simply dequeues the first node and returns an empty list. + """Dequeues the longest accepted match starting at the first node of the + queue, and returns all potential matches that were checked (rejected + ones). + + If no match is found, simply dequeues the first node and returns an + empty list. + """ potential_matches = self._get_matches(node_queue, quantization_config) accepted_matches = [m for m in potential_matches if m.accepted] @@ -156,16 +160,16 @@ def _dequeue_and_get_matches( return potential_matches def find_pattern_matches( - self, nodes: Iterator[Node], quantization_config: CortexMQuantizationConfig + self, nodes: Iterator[Node], quantization_config: QuantizationConfig ) -> Iterator[PatternMatchResult]: - """ - Match all given patterns in the graph and return match results with - acceptance/rejection status. - Each node can only be part of one match, larger patterns are prioritized. - Currently only linear patterns (single chain) are supported. + """Match all given patterns in the graph and return match results with + acceptance/rejection status. Each node can only be part of one match, + larger patterns are prioritized. Currently only linear patterns (single + chain) are supported. + + Q_PATTERN_MATCHED_KEY is set to True in node.meta to track which nodes + have already been matched. - Q_PATTERN_MATCHED_KEY is set to True in node.meta to track which nodes have - already been matched. 
""" node = next(nodes, None) diff --git a/backends/cortex_m/quantizer/quantization_configs.py b/backends/cortex_m/quantizer/quantization_configs.py index 030c0fa4d93..5261a5c2f88 100644 --- a/backends/cortex_m/quantizer/quantization_configs.py +++ b/backends/cortex_m/quantizer/quantization_configs.py @@ -5,7 +5,15 @@ import torch +from executorch.backends.arm.quantizer.arm_quantizer_utils import ( + _get_int32_bias_qspec, + _get_int32_per_channel_bias_qspec, +) from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig +from executorch.backends.cortex_m.quantizer.quantizer_reporter import ( + SUPPORTED_QCONFIGS, + SUPPORTED_QSPECS, +) from torch.fx import Node from torchao.quantization.pt2e import ( HistogramObserver, @@ -13,7 +21,6 @@ PerChannelMinMaxObserver, ) from torchao.quantization.pt2e.quantizer import ( - DerivedQuantizationSpec, FixedQParamsQuantizationSpec, QuantizationSpec, SharedQuantizationSpec, @@ -86,7 +93,9 @@ class CortexMQuantizationConfig(QuantizationConfig): """Configures quantization, while enforcing cortex-m specific constraints.""" - def get_input_act_qspec(self, node: Node | None = None) -> QuantizationSpec | None: + def get_input_act_qspec( + self, node: Node | None = None, input_node: Node | None = None + ) -> QuantizationSpec | None: """ Returns the configured input activation spec, no specific adjustments. 
""" @@ -132,44 +141,6 @@ def get_bias_qspec(self, node: Node) -> QuantizationSpec | None: return super().get_bias_qspec(node) -def _derive_bias_qparams_fn( - obs_or_fqs, -) -> tuple[torch.Tensor, torch.Tensor]: - if len(obs_or_fqs) != 2: - raise ValueError( - f"Expecting two obs/fqs, one for activation and one for weight, got: {len(obs_or_fqs)}" - ) - act_obs_or_fq = obs_or_fqs[0] - weight_obs_or_fq = obs_or_fqs[1] - act_scale, _ = act_obs_or_fq.calculate_qparams() - weight_scale, _ = weight_obs_or_fq.calculate_qparams() - return act_scale * weight_scale, torch.full_like( - weight_scale, fill_value=0, dtype=torch.int32 - ) - - -def _get_int32_bias_qspec(node): - return DerivedQuantizationSpec( - derived_from=((node.args[0], node), (node.args[1], node)), # type: ignore[list-item] - derive_qparams_fn=_derive_bias_qparams_fn, - dtype=torch.int32, - quant_min=torch.iinfo(torch.int32).min, - quant_max=torch.iinfo(torch.int32).max - 1, - ) - - -def _get_int32_per_channel_bias_qspec(node): - return DerivedQuantizationSpec( - derived_from=((node.args[0], node), (node.args[1], node)), # type: ignore[list-item] - derive_qparams_fn=_derive_bias_qparams_fn, - dtype=torch.int32, - quant_min=torch.iinfo(torch.int32).min, - quant_max=torch.iinfo(torch.int32).max - 1, - qscheme=torch.per_channel_symmetric, - ch_axis=0, - ) - - # ----------------- QUANTIZATION CONFIG PRESETS ----------------- INT8_PER_TENSOR_CONFIG = CortexMQuantizationConfig( INT8_ACTIVATION_PER_TENSOR_QSPEC, @@ -185,3 +156,24 @@ def _get_int32_per_channel_bias_qspec(node): INT8_WEIGHT_PER_CHANNEL_QSPEC, _get_int32_per_channel_bias_qspec, ) + + +# Register supported quantization configs and qspecs in the reporter for human-readable reporting +# MLETORCH-1854: Temporary solution, refactor to automatically register these instead +SUPPORTED_QCONFIGS.update( + { + INT8_PER_CHANNEL_CONFIG: f"{__name__}.INT8_PER_CHANNEL_QCONFIG", + INT8_PER_TENSOR_CONFIG: f"{__name__}.INT8_PER_TENSOR_QCONFIG", + } +) + 
+SUPPORTED_QSPECS.update( + { + INT8_ACTIVATION_PER_TENSOR_QSPEC: "INT8_ACTIVATION_PER_TENSOR_QSPEC", + INT8_ACTIVATION_PER_CHANNEL_QSPEC: "INT8_ACTIVATION_PER_CHANNEL_QSPEC", + INT8_WEIGHT_PER_TENSOR_QSPEC: "INT8_WEIGHT_PER_TENSOR_QSPEC", + INT8_WEIGHT_PER_CHANNEL_QSPEC: "INT8_WEIGHT_PER_CHANNEL_QSPEC", + INT8_WEIGHT_PER_CHANNEL_TRANSPOSE_QSPEC: "INT8_WEIGHT_PER_CHANNEL_TRANSPOSE_QSPEC", + SOFTMAX_OUTPUT_FIXED_QSPEC: "SOFTMAX_OUTPUT_FIXED_QSPEC", + } +) diff --git a/backends/cortex_m/quantizer/quantizer.py b/backends/cortex_m/quantizer/quantizer.py index 1e8aa1da47d..a024bd035f4 100644 --- a/backends/cortex_m/quantizer/quantizer.py +++ b/backends/cortex_m/quantizer/quantizer.py @@ -4,73 +4,41 @@ # LICENSE file in the root directory of this source tree. -import logging -import operator -from typing import List, Optional +from typing import List -import torch -from executorch.backends.arm.common.annotation_meta import ArmAnnotationInfo - -from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig +from executorch.backends.arm.quantizer.arm_quantizer_utils import ( + _mark_node_as_quantized, + PatternQuantizer, + SharedQspecQuantizer, +) from executorch.backends.cortex_m.passes.cortex_m_pass_manager import CortexMPassManager from executorch.backends.cortex_m.quantizer.node_finders import ( GlobalNodeFinder, - NodeFinder, NodeTargetNodeFinder, ) from executorch.backends.cortex_m.quantizer.pattern_matcher import PatternMatcher from executorch.backends.cortex_m.quantizer.quantization_configs import ( INT8_PER_CHANNEL_CONFIG, INT8_PER_TENSOR_CONFIG, - QuantizationSpec, -) -from executorch.backends.cortex_m.quantizer.quantizer_reporter import ( - QuantizerInfo, - QuantizerReporter, - QuantizerReporterUser, - SUPPORTED_QCONFIGS, ) +from executorch.backends.cortex_m.quantizer.quantizer_reporter import QuantizerReporter from executorch.backends.cortex_m.quantizer.quantizer_support import ( __name__ as cortex_m_quantizer_support_module, 
CONV_OP_PATTERNS, CONV_TRANSPOSE_OP_PATTERNS, CORTEX_M_QUANTIZER_SUPPORT_DICT, ) -from torch._ops import OpOverload -from torch.fx import GraphModule, Node -from torchao.quantization.pt2e.quantizer import ( - ComposableQuantizer, - QuantizationAnnotation, - Quantizer, - SharedQuantizationSpec, -) -from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY - -logger = logging.getLogger(__name__) - - -def has_float_output(node: Node) -> bool: - meta_val = node.meta.get("val", None) - if isinstance(meta_val, torch.Tensor): - return meta_val.dtype.is_floating_point - - return False +from torch.fx import GraphModule +from torchao.quantization.pt2e.quantizer import ComposableQuantizer, Quantizer def mark_node_as_annotated( - node: Node, - input_qspec_map: dict[Node, Optional[QuantizationSpec]], - output_qspec: Optional[QuantizationSpec], - reporter: Optional[QuantizerReporter] = None, - quantizer: Optional[Quantizer] = None, + node, + input_qspec_map, + output_qspec, + is_quantized, ) -> None: - node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(input_qspec_map, output_qspec) - annotation_info = ArmAnnotationInfo( - quantized=True, - ) - meta_custom = node.meta.get("custom", {}) - meta_custom[ArmAnnotationInfo.CUSTOM_META_KEY] = dict(annotation_info) - node.meta["custom"] = meta_custom + _mark_node_as_quantized(node, input_qspec_map, output_qspec, is_quantized) class CortexMQuantizer(ComposableQuantizer): @@ -114,399 +82,3 @@ def validate(self, model: GraphModule) -> bool: def transform_for_annotation(self, model: GraphModule) -> GraphModule: pass_manager = CortexMPassManager(None) return pass_manager.transform_for_annotation(model) - - -class PatternQuantizer(Quantizer, QuantizerReporterUser): - """ - Quantizes a graph according to an OperatorConfig. - - Args: - quantization_config (QuantizationConfig): The quantization config to use for annotation. - node_finder (NodeFinder): The node finder to use for finding nodes to match patterns. 
- pattern_matcher (PatternMatcher): The pattern matcher to use for finding patterns in the nodes. - """ - - def __init__( - self, - quantization_config: QuantizationConfig, - node_finder: NodeFinder, - pattern_matcher: PatternMatcher, - ) -> None: - super().__init__() - self.quantization_config: QuantizationConfig = quantization_config - self.node_finder: NodeFinder = node_finder - self.pattern_matcher: PatternMatcher = pattern_matcher - - def get_quantizer_info(self): - name = self.__class__.__name__ - targeted_nodes_description = str(self.node_finder) - quantization_config_path = SUPPORTED_QCONFIGS.get( - self.quantization_config, "CUSTOM_QCONFIG" - ) - support_config_path = self.pattern_matcher.support_dict_name - - return QuantizerInfo( - name, - targeted_nodes_description, - quantization_config_path, - support_config_path, - ) - - def is_parameter(self, node: Node, model: GraphModule) -> bool: - """Returns True if the given node is a parameter of the model.""" - try: - _ = model.get_parameter(node.target) - return True - except Exception: - return False - - def is_weight(self, node: Node, params: List[Node], model: GraphModule) -> bool: - """Returns True if node is the first parameter of the given parameters""" - return len(params) > 0 and node == params[0] - - def is_bias(self, node: Node, params: List[Node], model: GraphModule) -> bool: - """Returns True if node is the second parameter of the given parameters""" - return len(params) == 2 and node == params[1] - - def annotate_match( - self, match: List[Node], config: QuantizationConfig | None, model: GraphModule - ) -> None: - """ - Annotates a matched pattern according to the given quantization config. 
The - following assumptions are made: - - - All operators have either no parameters, only weights, or weights and biases - - Tensors which are the first parameter of an operator are annotated as weights - - Tensors which are the second parameter of an operator are annotated as biases - - All other tensors going into the matched pattern are annotated as input activations. - - All other outputs coming out of the matched pattern are annotated as output activations. - - """ - for node in match: - input_qspec_map = {} - output_qspec = None - - params = [n for n in node.all_input_nodes if self.is_parameter(n, model)] - # Check that the assumptions on number of parameters hold to avoid silent errors - assert ( - 0 <= len(params) <= 2 - ), f"{self.__class__.__name__} expected 0 params, 1 params (weight) or 2 params (weight, bias), but got {len(params)} for node {node}." - - for input_node in node.all_input_nodes: - # Observers only work on floating point tensors, so make sure to skip other dtypes - if not has_float_output(input_node): - continue - if self.is_weight(input_node, params, model): - input_qspec_map[input_node] = ( - config.get_weight_qspec(node) if config else None - ) - elif self.is_bias(input_node, params, model): - # Bias qspec is derived from input + weight qspecs - input_qspec_map[input_node] = ( - config.get_bias_qspec(node) if config else None - ) - elif input_node not in match: - input_qspec_map[input_node] = ( - config.get_input_act_qspec() if config else None - ) - - if all(node not in match for node in node.users) and output_qspec is None: - output_qspec = config.get_output_act_qspec(node) if config else None - - mark_node_as_annotated( - node, input_qspec_map, output_qspec, self.reporter, self - ) - - def annotate(self, model: GraphModule) -> None: - nodes = self.node_finder.find_nodes(model) - matches = self.pattern_matcher.find_pattern_matches( - nodes, self.quantization_config - ) - for result in matches: - if result.accepted: - 
self.annotate_match(result.pattern, self.quantization_config, model) - self.report_accept(result.pattern) - else: - self.report_reject( - result.pattern, - result.message or "Pattern rejected.", - ) - - def validate(self, model: GraphModule) -> bool: - return True - - -class SharedQspecQuantizer(Quantizer, QuantizerReporterUser): - """ - Special quantizer for assuring that given ops share the same quantization parameters on all input and outputs, - i.e. ops which does not change the scale such as clone, min/max, transposes and so on. - - Args: - targets (Optional[List[OpOverload]]): List of operator overloads to apply shared quantization spec to. - If None, a default list of supported ops is used. - """ - - SHARED_QSPEC_OPS_DEFAULT: List[OpOverload] = [ - # Clone - torch.ops.aten.clone.default, - torch.ops.aten.lift_fresh_copy.default, - torch.ops.aten.detach_.default, - torch.ops.aten.alias.default, - torch.ops.aten.alias_copy.default, - torch.ops.aten.copy_.default, - torch.ops.aten.detach_copy.default, - torch.ops.aten.unfold_copy.default, - torch.ops.aten.unbind.int, - # Min/Max/Mean - torch.ops.aten.minimum.default, - torch.ops.aten.maximum.default, - torch.ops.aten.min.dim, - torch.ops.aten.max.dim, - torch.ops.aten.amin.default, - torch.ops.aten.amax.default, - # Data shuffling - torch.ops.aten.permute.default, - torch.ops.aten.permute_copy.default, - torch.ops.aten.transpose.int, - torch.ops.aten.transpose_copy.int, - torch.ops.aten.t_copy.default, - torch.ops.aten.t.default, - torch.ops.aten.repeat.default, - torch.ops.aten.repeat_interleave.self_int, - torch.ops.aten.expand_copy.default, - torch.ops.aten.expand.default, - torch.ops.aten.select.int, - torch.ops.aten.select_copy.int, - torch.ops.aten.slice.Tensor, - torch.ops.aten.slice_copy.Tensor, - torch.ops.aten.split.Tensor, - torch.ops.aten.split_with_sizes.default, - torch.ops.aten.split_copy.Tensor, - torch.ops.aten.tile.default, - torch.ops.aten.flip.default, - torch.ops.aten.index_select.default, 
- torch.ops.aten.index_put.default, - torch.ops.aten.contiguous.default, - torch.ops.aten.as_strided_copy.default, - torch.ops.aten.pixel_shuffle.default, - torch.ops.aten.pixel_unshuffle.default, - torch.ops.aten.cat.default, - torch.ops.aten.concatenate.default, - torch.ops.aten.stack.default, - torch.ops.aten.dropout.default, - torch.ops.aten.dropout_.default, - torch.ops.aten.chunk.default, - torch.ops.aten.index.Tensor, - torch.ops.aten.gather.default, - operator.getitem, - # Change shape - torch.ops.aten.squeeze.default, - torch.ops.aten.squeeze_copy.default, - torch.ops.aten.squeeze_copy.dim, - torch.ops.aten.squeeze.dim, - torch.ops.aten.squeeze.dims, - torch.ops.aten.squeeze_.dim, - torch.ops.aten.unsqueeze.default, - torch.ops.aten.unsqueeze_copy.default, - torch.ops.aten.reshape.default, - torch.ops.aten.view.default, - torch.ops.aten.view_as.default, - torch.ops.aten.view_copy.default, - torch.ops.aten._unsafe_view.default, - torch.ops.aten.unflatten.int, - torch.ops.aten.flatten.using_ints, - # Padding - torch.ops.aten.pad.default, - torch.ops.aten.constant_pad_nd.default, - # Ativation functions - torch.ops.aten.clamp.default, - torch.ops.aten.clamp.Tensor, - torch.ops.aten.hardtanh.default, - torch.ops.aten.hardtanh_.default, - torch.ops.aten.relu.default, - torch.ops.aten.relu_.default, - # Logic ops - torch.ops.aten.eq.Tensor, - torch.ops.aten.eq.Scalar, - torch.ops.aten.ne.Tensor, - torch.ops.aten.ne.Scalar, - torch.ops.aten.ge.Tensor, - torch.ops.aten.ge.Scalar, - torch.ops.aten.gt.Tensor, - torch.ops.aten.gt.Scalar, - torch.ops.aten.le.Tensor, - torch.ops.aten.le.Scalar, - torch.ops.aten.lt.Tensor, - torch.ops.aten.lt.Scalar, - torch.ops.aten.where.self, - torch.ops.aten.where.default, - torch.ops.higher_order.while_loop, - torch.ops.higher_order.cond, - ] - - def __init__(self, targets: Optional[List[OpOverload]] = None) -> None: - super().__init__() - if targets is None: - self.targets = self.SHARED_QSPEC_OPS_DEFAULT - self.support_config_path 
= ( - __name__ + f".{self.__class__.__name__}.SHARED_QSPEC_OPS_DEFAULT" - ) - else: - self.targets = targets - self.support_config_path = ( - f"CUSTOM TARGETS: {', '.join([str(target) for target in targets])}" - ) - - def get_quantizer_info(self): - name = self.__class__.__name__ - targeted_nodes_description = "" - quantization_config_path = "SHARED_QCONFIG" - support_config_path = self.support_config_path - return QuantizerInfo( - name, - targeted_nodes_description, - quantization_config_path, - support_config_path, - ) - - def _is_annotated(self, node: Node) -> bool: - return Q_ANNOTATION_KEY in node.meta - - def _get_input_nodes_with_float_output(self, node: Node) -> List[Node]: - # Observers only work on floating point tensors, so make sure to skip other dtypes - return [n for n in node.all_input_nodes if has_float_output(n)] - - def _get_user_nodes_with_float_input(self, node: Node) -> List[Node]: - # Observers only work on floating point tensors, so make sure to skip other dtypes - return [n for n in node.users.keys() if has_float_output(node)] - - def _get_shared_clique(self, root_node: Node) -> set[Node]: - """ - Finds a cluster of nodes with targets in self.targets, starting in root_node. - """ - shared_nodes = set() - bfs_queue = [root_node] - adjacent_qspecs = [] - - while bfs_queue: - node = bfs_queue.pop(0) - shared_nodes.add(node) - - # Neighbours may either be other shared nodes, annotated nodes, or non-annotated (float) nodes. 
- for input_node in node.all_input_nodes: - if input_node.target in self.targets and input_node not in shared_nodes: - if not self._is_annotated(input_node): - bfs_queue.append(input_node) - if self._is_annotated(input_node): - output_qspec = input_node.meta.get(Q_ANNOTATION_KEY).output_qspec - if output_qspec is not None: - adjacent_qspecs.append(output_qspec) - - for output_node in node.users.keys(): - if ( - output_node.target in self.targets - and output_node not in shared_nodes - ): - if not self._is_annotated(output_node): - bfs_queue.append(output_node) - if ( - self._is_annotated(output_node) - and node in output_node.meta.get(Q_ANNOTATION_KEY).input_qspec_map - ): - input_qspec = output_node.meta.get( - Q_ANNOTATION_KEY - ).input_qspec_map[node] - if input_qspec is not None: - adjacent_qspecs.append(input_qspec) - - return shared_nodes, adjacent_qspecs - - def _annotate_shared_cluster(self, root_node: Node) -> None: - """ - Finds a cluster of unannotated nodes starting in root_node and annotates them with a common - SharedQuantizationSpec. - """ - - if ( - len(self._get_input_nodes_with_float_output(root_node)) == 0 - and len(self._get_user_nodes_with_float_input(root_node)) == 0 - ): - self.report_reject( - [root_node], - "No float inputs nor outputs to annotate", - ) - mark_node_as_annotated( - root_node, - {}, - None, - ) - return - - shared_nodes, adjacent_qspecs = self._get_shared_clique(root_node) - node_order = {node: index for index, node in enumerate(root_node.graph.nodes)} - ordered_nodes = sorted(shared_nodes, key=lambda node: node_order.get(node, 0)) - - # The selection of root node for the shared_qspec is important for - # torchao.quantization.pt2e.prepare._create_obs_or_fq_from_qspec: - # 1. For regular QuantizationSpecs, it creates a new observer - # 2. For SharedQuantizationSpecs, it returns the observer created for it's root node - # 3. 
It handles nodes in the order they appear in graph.nodes - # This means that we need to make sure that the root node of the shared_qspec - # has an input node with a quantization spec, so that an observer is created. - - if len(adjacent_qspecs) > 0: - # Warn if multiple different adjacent qspecs are found. - if len(adjacent_qspecs) > 1: - logger.warning( - f"Multiple adjacent quantization specs found for {', '.join([n.name for n in ordered_nodes])}, all nodes will share the input quantization spec of {root_node.name}." - ) - - root_node_float_inputs = self._get_input_nodes_with_float_output(root_node) - if len(root_node_float_inputs) == 0: - self.report_reject( - ordered_nodes, - "Couldn't find any floating point input to base shared quantization spec on.", - ) - return - root_node_first_input = root_node_float_inputs[0] - - # Make all nodes share qspec with the root node's first input - shared_qspec = SharedQuantizationSpec((root_node_first_input, root_node)) - for node in shared_nodes: - input_qspec_map: dict[Node, Optional[QuantizationSpec]] = { - n: shared_qspec - for n in self._get_input_nodes_with_float_output(node) - } - if len(self._get_user_nodes_with_float_input(node)) == 0: - output_qspec = None - else: - output_qspec = shared_qspec - mark_node_as_annotated( - node, - input_qspec_map, - output_qspec, - ) - - # Force the root qspec to be the adjacent spec - root_node.meta[Q_ANNOTATION_KEY].input_qspec_map[root_node_first_input] = ( - adjacent_qspecs[0] - ) - self.report_accept(ordered_nodes) - - else: - self.report_reject( - ordered_nodes, - "Couldn't find any adjacent quantization spec to base shared quantization spec on. You may however quantize these nodes manually if required.", - ) - return - - def annotate(self, model: GraphModule) -> None: - """ - Annotate shared quantization spec for supported ops. 
- """ - for node in model.graph.nodes: - if node.target in self.targets and not self._is_annotated(node): - self._annotate_shared_cluster(node) - - def validate(self, model: GraphModule) -> bool: - return True diff --git a/backends/cortex_m/quantizer/quantizer_reporter.py b/backends/cortex_m/quantizer/quantizer_reporter.py index 8c5151a5c7f..57922646d25 100644 --- a/backends/cortex_m/quantizer/quantizer_reporter.py +++ b/backends/cortex_m/quantizer/quantizer_reporter.py @@ -2,8 +2,7 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -""" -Contains classes for reporting quantization decisions made by Quantizers. +"""Contains classes for reporting quantization decisions made by Quantizers. Basic useage: 1. Implement the QuantizerReporterUser API for all quantizers intending to use the reporter. @@ -11,6 +10,7 @@ 3. After annotation, log the report using QuantizerReporter.log_quantizer_report(model). Logs a summary report at INFO level, and a detailed node-per-node report at DEBUG level. 
+ """ from __future__ import annotations @@ -18,50 +18,61 @@ import logging from typing import Dict, List, NamedTuple, Optional -from executorch.backends.cortex_m.quantizer.quantization_configs import ( - __name__ as quantization_configs_module, - INT8_ACTIVATION_PER_CHANNEL_QSPEC, - INT8_ACTIVATION_PER_TENSOR_QSPEC, - INT8_PER_CHANNEL_CONFIG, - INT8_PER_TENSOR_CONFIG, - INT8_WEIGHT_PER_CHANNEL_QSPEC, - INT8_WEIGHT_PER_CHANNEL_TRANSPOSE_QSPEC, - INT8_WEIGHT_PER_TENSOR_QSPEC, - SOFTMAX_OUTPUT_FIXED_QSPEC, -) from tabulate import tabulate from torch.fx import GraphModule, Node -from torchao.quantization.pt2e.quantizer import Quantizer +from torchao.quantization.pt2e.quantizer import ( + DerivedQuantizationSpec, + QuantizationSpec, + QuantizationSpecBase, + Quantizer, + SharedQuantizationSpec, +) from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY logger = logging.getLogger(__name__) # Look-up dicts used to get human readable names for supported quantization configs and specs -SUPPORTED_QCONFIGS = { - INT8_PER_CHANNEL_CONFIG: f"{quantization_configs_module}.INT8_PER_CHANNEL_QCONFIG", - INT8_PER_TENSOR_CONFIG: f"{quantization_configs_module}.INT8_PER_TENSOR_QCONFIG", -} +SUPPORTED_QCONFIGS: dict[any, str] = {} +SUPPORTED_QSPECS: dict[QuantizationSpecBase | None, str] = {} -SUPPORTED_QSPECS = { - INT8_ACTIVATION_PER_TENSOR_QSPEC: "INT8_ACTIVATION_PER_TENSOR_QSPEC", - INT8_ACTIVATION_PER_CHANNEL_QSPEC: "INT8_ACTIVATION_PER_CHANNEL_QSPEC", - INT8_WEIGHT_PER_TENSOR_QSPEC: "INT8_WEIGHT_PER_TENSOR_QSPEC", - INT8_WEIGHT_PER_CHANNEL_QSPEC: "INT8_WEIGHT_PER_CHANNEL_QSPEC", - INT8_WEIGHT_PER_CHANNEL_TRANSPOSE_QSPEC: "INT8_WEIGHT_PER_CHANNEL_TRANSPOSE_QSPEC", - SOFTMAX_OUTPUT_FIXED_QSPEC: "SOFTMAX_OUTPUT_FIXED_QSPEC", - None: "None", -} +def _qspec_repr(qspec): + """Get a human readable representation of QuantizationSpecs. 
+ Note that the observer_or_fake_quant_ctr field is created dynamically with + the qspec so two qspecs created at different times will not evaluate as + equal. Therefore a custom comparison is required. -def _qspec_repr(qspec): - return SUPPORTED_QSPECS.get(qspec, "CUSTOM_QSPEC") + #TODO: Clean up qconfig/ qspec string representation logic in cortex_m/arm + backend. + + """ + if isinstance(qspec, SharedQuantizationSpec): + return "SHARED_QSPEC" + elif isinstance(qspec, DerivedQuantizationSpec): + return "DERIVED_QSPEC" + elif qspec is None: + return "NO_QSPEC" + elif isinstance(qspec, QuantizationSpec): + for key, val in SUPPORTED_QSPECS.items(): + if type(qspec) is not type(key): + continue + if qspec.dtype != key.dtype: + continue + if qspec.quant_min != key.quant_min: + continue + if qspec.quant_max != key.quant_max: + continue + if qspec.qscheme != key.qscheme: + continue + if qspec.is_dynamic != key.is_dynamic: + continue + return val + return "UNREGISTERED_QSPEC" class QuantizerInfo(NamedTuple): - """ - NamedTuple storing information about a Quantizer. - """ + """NamedTuple storing information about a Quantizer.""" name: str targeted_nodes_description: str @@ -70,9 +81,7 @@ class QuantizerInfo(NamedTuple): class NodeQSpecReport(NamedTuple): - """ - NamedTuple storing annotation info for a single node. - """ + """NamedTuple storing annotation info for a single node.""" name: str qspec_input_map_lines: List[str] @@ -80,26 +89,24 @@ class NodeQSpecReport(NamedTuple): class AnnotatedPatternReport(NamedTuple): - """ - NamedTuple storing annotation info for a pattern of nodes. - """ + """NamedTuple storing annotation info for a pattern of nodes.""" nodes: List[NodeQSpecReport] class RejectedPatternReport(NamedTuple): - """ - NamedTuple storing rejection info for a pattern of nodes. 
- """ + """NamedTuple storing rejection info for a pattern of nodes.""" node_names: List[str] rejection_reason: str class QuantizerReport: - """ - Reporter class for collecting and generating quantization reports from a single Quantizer. + """Reporter class for collecting and generating quantization reports from a + single Quantizer. + Used by the QuantizerReporter to aggregate reports from multiple Quantizers. + """ _PREVIOUS_ANNOTATION_REJECT_REASON = "Tried annotating already quantized node." @@ -130,8 +137,8 @@ def rejected_previous_annotation_count(self) -> int: ) def report_accept(self, pattern: List[Node]) -> None: - """ - Stores an AnnotatedPatternReport containing info about the accepted pattern. + """Stores an AnnotatedPatternReport containing info about the accepted + pattern. """ node_reports = [] for node in pattern: @@ -156,8 +163,8 @@ def report_accept(self, pattern: List[Node]) -> None: self.accepted_patterns.append(AnnotatedPatternReport(node_reports)) def report_reject(self, pattern, reason): - """ - Stores an RejectedPatternReport containing info about the rejected pattern. + """Stores an RejectedPatternReport containing info about the rejected + pattern. """ self.rejected_patterns.append( RejectedPatternReport([node.name for node in pattern], reason) @@ -278,19 +285,21 @@ def _rejected_pattern_label(self, rejected: RejectedPatternReport) -> str: class QuantizerReporter: - """ - Reporter class for collecting and generating quantization reports from Quantizers - inheriting from QuantizerReporterUser. + """Reporter class for collecting and generating quantization reports from + Quantizers inheriting from QuantizerReporterUser. 
""" - def __init__(self, quantizers: List[QuantizerReporterUser]): + def __init__( + self, + quantizers: List[QuantizerReporterUser], + report_title: str = "QUANTIZATION REPORT", + ): self.quantizers: Dict[Quantizer, QuantizerReport] = {} + self.report_title = report_title self.set_quantizers(quantizers) def set_quantizers(self, quantizers: List[QuantizerReporterUser]) -> None: - """ - Registers quantizers to report their quantization decisions. - """ + """Registers quantizers to report their quantization decisions.""" self.quantizers = {} for quantizer in quantizers: @@ -306,8 +315,8 @@ def set_quantizers(self, quantizers: List[QuantizerReporterUser]) -> None: def report_reject( self, quantizer: QuantizerReporterUser, pattern: List[Node], reason: str ): - """ - Reports a node pattern rejected by a quantizer with a given reason. + """Reports a node pattern rejected by a quantizer with a given + reason. """ quantizer_entry = self.quantizers.get(quantizer, None) if quantizer_entry is not None: @@ -322,9 +331,7 @@ def report_accept( quantizer: QuantizerReporterUser, pattern: List[Node], ): - """ - Reports a node pattern accepted by a quantizer. - """ + """Reports a node pattern accepted by a quantizer.""" quantizer_entry = self.quantizers.get(quantizer, None) if quantizer_entry is not None: quantizer_entry.report_accept(pattern) @@ -334,11 +341,12 @@ def report_accept( ) def log_quantizer_report(self, model: Optional[GraphModule] = None): - """ - Logs the quantization report for all registered quantizers. + """Logs the quantization report for all registered quantizers. + + If the logger is set to DEBUG level, a node-per-node report is generated + and logged at DEBUG level. Otherwise, a summary report is logged at INFO + level. - If the logger is set to DEBUG level, a node-per-node report is generated and - logged at DEBUG level. Otherwise, a summary report is logged at INFO level. 
""" extended_report = logger.isEnabledFor(logging.DEBUG) @@ -351,13 +359,16 @@ def log_quantizer_report(self, model: Optional[GraphModule] = None): def get_quantization_report( self, model: Optional[GraphModule], extended_report: bool ) -> str: - """ - Generates the quantization report for all registered quantizers - """ + """Generates the quantization report for all registered quantizers.""" report_rows: List[str] = [] separator = "-" * 100 report_rows.append(separator) - report_rows.append(" " * 39 + " QUANTIZATION REPORT " + " " * 40) + assert ( + len(self.report_title) < 100 + ), "Report title must be less than 100 characters to be properly formatted in the report header." + pre_pad = (100 - len(self.report_title)) // 2 + post_pad = 100 - len(self.report_title) - pre_pad + report_rows.append(" " * pre_pad + f"{self.report_title}" + " " * post_pad) report_rows.append(separator) for report in self.quantizers.values(): @@ -373,8 +384,8 @@ def get_quantization_report( def unannotated_nodes_report( self, model: Optional[GraphModule], extended_report: bool ) -> List[str]: - """ - Generates the quantization report for all non-annotated nodes in the model. + """Generates the quantization report for all non-annotated nodes in the + model. """ non_quantized_nodes = [ node for node in model.graph.nodes if Q_ANNOTATION_KEY not in node.meta @@ -400,38 +411,36 @@ def _pattern_repr(self, nodes: List[Node]) -> str: class QuantizerReporterUser: - """ - Mixin class for Quantizers, to be used with QuantizerReporter. + """Mixin class for Quantizers, to be used with QuantizerReporter. 
+ + Handles reporter registration and ensures that the quantizer does not + crash without a reporter registered - Handles reporter registration and ensures that that the quantizer does not crash - without a reporter registred """ def __init__(self): self.reporter: QuantizerReporter = None def register_reporter(self, reporter: QuantizerReporter) -> None: - """ - Used by QuantizerReporter to register itself with the Quantizer. - """ + """Used by QuantizerReporter to register itself with the Quantizer.""" self.reporter = reporter def report_reject(self, pattern: List[Node], reason: str) -> None: - """ - Reports a node pattern rejected by a quantizer, if a reporter is registered. + """Reports a node pattern rejected by a quantizer, if a reporter is + registered. """ if self.reporter is not None: self.reporter.report_reject(self, pattern, reason) def report_accept(self, pattern: List[Node]) -> None: - """ - Reports a node pattern accepted by a quantizer, if a reporter is registered. + """Reports a node pattern accepted by a quantizer, if a reporter is + registered. """ if self.reporter is not None: self.reporter.report_accept(self, pattern) def get_quantizer_info(self) -> "QuantizerInfo": - """ - Returns a QuantizerInfo NamedTuple with information about the quantizer. + """Returns a QuantizerInfo NamedTuple with information about the + quantizer. 
""" raise NotImplementedError("Quantizer must implement get_quantizer_info method.") diff --git a/backends/cortex_m/test/misc/test_portable_int8.py b/backends/cortex_m/test/misc/test_portable_int8.py index b3565882f53..8e08250bb37 100644 --- a/backends/cortex_m/test/misc/test_portable_int8.py +++ b/backends/cortex_m/test/misc/test_portable_int8.py @@ -14,11 +14,9 @@ import torch from executorch.backends.arm._passes import FoldAndAnnotateQParamsPass from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor +from executorch.backends.arm.quantizer.arm_quantizer_utils import SharedQspecQuantizer from executorch.backends.arm.test.common import parametrize -from executorch.backends.cortex_m.quantizer.quantizer import ( - CortexMQuantizer, - SharedQspecQuantizer, -) +from executorch.backends.cortex_m.quantizer.quantizer import CortexMQuantizer from executorch.backends.cortex_m.test.tester import CortexMTester from executorch.backends.test.harness.stages import StageType from executorch.exir import EdgeCompileConfig diff --git a/backends/cortex_m/test/misc/test_quantizer_reporter.py b/backends/cortex_m/test/misc/test_quantizer_reporter.py index 354a1258f1f..368ff78793c 100644 --- a/backends/cortex_m/test/misc/test_quantizer_reporter.py +++ b/backends/cortex_m/test/misc/test_quantizer_reporter.py @@ -101,11 +101,13 @@ def test_debug_log_level(caplog): add_node, {add_node.args[0]: INT8_WEIGHT_PER_TENSOR_QSPEC, add_node.args[1]: None}, None, + is_quantized=True, ) mark_node_as_annotated( relu_node, {}, INT8_ACTIVATION_PER_CHANNEL_QSPEC, + is_quantized=True, ) quantizer1.report_accept([add_node, relu_node]) quantizer2.report_reject( @@ -128,8 +130,8 @@ def test_debug_log_level(caplog): NODE NAME INPUT QSPEC MAP OUTPUT QSPEC MAP -- ----------- ------------------------------- --------------------------------- - ╒ add x: INT8_WEIGHT_PER_TENSOR_QSPEC None - | y: None + ╒ add x: INT8_WEIGHT_PER_TENSOR_QSPEC NO_QSPEC + | y: NO_QSPEC ╘ relu 
INT8_ACTIVATION_PER_CHANNEL_QSPEC ---------------------------------------------------------------------------------------------------- DummyQuantizer using dummy nodes