From d810bc39c6af120a259fcf83d07184458d642994 Mon Sep 17 00:00:00 2001
From: Rahul Chandra
Date: Mon, 8 Dec 2025 14:06:36 -0800
Subject: [PATCH 1/4] Adding Test for CadenceWith16BitMatmulActivationsQuantizer (#16089)

Summary:
We test that the quantizer added in D87996796 correctly annotates the graph.
We use the graph builder to build the graph with metadata (which
quantizer.annotate needs to recognize the nodes), and we verify that the
quantization params are as expected.

Reviewed By: zonglinpeng, hsharma35

Differential Revision: D88053808
---
 backends/cadence/aot/TARGETS                  |  3 ++
 .../cadence/aot/tests/test_quantizer_ops.py   | 50 +++++++++++++++++++
 2 files changed, 53 insertions(+)

diff --git a/backends/cadence/aot/TARGETS b/backends/cadence/aot/TARGETS
index be74a8d957f..8363f022946 100644
--- a/backends/cadence/aot/TARGETS
+++ b/backends/cadence/aot/TARGETS
@@ -641,6 +641,9 @@ python_unittest(
     typing = True,
     deps = [
         "//caffe2:torch",
+        "//executorch/backends/cadence/aot:graph_builder",
         "//executorch/backends/cadence/aot/quantizer:quantizer",
+        "//executorch/exir:pass_base",
+        "//pytorch/ao:torchao",
     ],
 )
diff --git a/backends/cadence/aot/tests/test_quantizer_ops.py b/backends/cadence/aot/tests/test_quantizer_ops.py
index f0df592558f..c7a4cc0f2dc 100644
--- a/backends/cadence/aot/tests/test_quantizer_ops.py
+++ b/backends/cadence/aot/tests/test_quantizer_ops.py
@@ -9,14 +9,64 @@
 import unittest
 
 import torch
+from executorch.backends.cadence.aot.graph_builder import GraphBuilder
 from executorch.backends.cadence.aot.quantizer.patterns import AddmmPattern
 from executorch.backends.cadence.aot.quantizer.quantizer import (
     CadenceAtenQuantizer,
     CadenceDefaultQuantizer,
     CadenceW8A32MixedQuantizer,
+    CadenceWith16BitMatmulActivationsQuantizer,
+    qconfig_A16,
     qconfig_A8W8,
 )
+from executorch.exir.pass_base import NodeMetadata
+from torchao.quantization.pt2e.quantizer.quantizer import (
+    Q_ANNOTATION_KEY,
+    QuantizationAnnotation,
+)
+
+
+class QuantizerAnnotationTest(unittest.TestCase):
+    """Unit tests for verifying quantizer annotations are correctly applied."""
+
+    def _build_matmul_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
+        """Build a simple graph with a matmul operation."""
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(4, 8))
+        y = builder.placeholder("y", torch.randn(8, 4))
+        matmul = builder.call_operator(
+            op=torch.ops.aten.matmul.default,
+            args=(x, y),
+            meta=NodeMetadata(
+                {"source_fn_stack": [("matmul", torch.ops.aten.matmul.default)]}
+            ),
+        )
+        builder.output([matmul])
+        gm = builder.get_graph_module()
+
+        matmul_nodes = gm.graph.find_nodes(
+            op="call_function",
+            target=torch.ops.aten.matmul.default,
+        )
+        self.assertEqual(len(matmul_nodes), 1, "Should find exactly one matmul node")
+        return gm, matmul_nodes[0]
+
+    def test_matmul_16bit_quantizer_annotation(self) -> None:
+        """Test that CadenceWith16BitMatmulActivationsQuantizer correctly annotates matmul."""
+        gm, matmul_node = self._build_matmul_graph()
+
+        quantizer = CadenceWith16BitMatmulActivationsQuantizer()
+        quantizer.annotate(gm)
+
+        annotation: QuantizationAnnotation = matmul_node.meta[Q_ANNOTATION_KEY]
+        self.assertTrue(annotation._annotated)
+
+        self.assertEqual(annotation.output_qspec, qconfig_A16.output_activation)
+
+        self.assertEqual(len(annotation.input_qspec_map), 2)
+        for _, input_qspec in annotation.input_qspec_map.items():
+            self.assertEqual(input_qspec, qconfig_A16.input_activation)
 
 
 class QuantizerOpsPreserveTest(unittest.TestCase):
From d300b0f8b74a3d9f7aefdaebb0d8dede3a207c53 Mon Sep 17 00:00:00 2001
From: Rahul Chandra
Date: Mon, 8 Dec 2025 14:06:36 -0800
Subject: [PATCH 2/4] Adding Test For CadenceWith16BitLinearActivationsQuantizer (#16097)

Summary:
We test the CadenceWith16BitLinearActivationsQuantizer. We use the graph
builder to build the graph with metadata (which quantizer.annotate needs to
recognize the nodes), and we verify that the quantization params are as
expected.

Reviewed By: zonglinpeng, hsharma35

Differential Revision: D88054651
---
 .../cadence/aot/tests/test_quantizer_ops.py   | 46 +++++++++++++++++++
 1 file changed, 46 insertions(+)

diff --git a/backends/cadence/aot/tests/test_quantizer_ops.py b/backends/cadence/aot/tests/test_quantizer_ops.py
index c7a4cc0f2dc..d1f5389ddb1 100644
--- a/backends/cadence/aot/tests/test_quantizer_ops.py
+++ b/backends/cadence/aot/tests/test_quantizer_ops.py
@@ -16,6 +16,7 @@
     CadenceAtenQuantizer,
     CadenceDefaultQuantizer,
     CadenceW8A32MixedQuantizer,
+    CadenceWith16BitLinearActivationsQuantizer,
     CadenceWith16BitMatmulActivationsQuantizer,
     qconfig_A16,
     qconfig_A8W8,
@@ -68,6 +69,51 @@ def test_matmul_16bit_quantizer_annotation(self) -> None:
         for _, input_qspec in annotation.input_qspec_map.items():
             self.assertEqual(input_qspec, qconfig_A16.input_activation)
 
+    def _build_linear_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
+        """Build a simple graph with a linear operation (no bias)."""
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(1, 10))
+        weight = builder.placeholder("weight", torch.randn(5, 10))
+        linear = builder.call_operator(
+            op=torch.ops.aten.linear.default,
+            args=(x, weight),
+            meta=NodeMetadata(
+                {"source_fn_stack": [("linear", torch.ops.aten.linear.default)]}
+            ),
+        )
+        builder.output([linear])
+        gm = builder.get_graph_module()
+
+        linear_nodes = gm.graph.find_nodes(
+            op="call_function",
+            target=torch.ops.aten.linear.default,
+        )
+        self.assertEqual(len(linear_nodes), 1, "Should find exactly one linear node")
+        return gm, linear_nodes[0]
+
+    def test_linear_16bit_quantizer_annotation(self) -> None:
+        """Test that CadenceWith16BitLinearActivationsQuantizer correctly annotates linear."""
+        gm, linear_node = self._build_linear_graph()
+
+        quantizer = CadenceWith16BitLinearActivationsQuantizer()
+        quantizer.annotate(gm)
+
+        annotation: QuantizationAnnotation = linear_node.meta[Q_ANNOTATION_KEY]
+        self.assertTrue(annotation._annotated)
+
+        # Verify output is annotated with qconfig_A16.output_activation (INT16)
+        self.assertEqual(annotation.output_qspec, qconfig_A16.output_activation)
+
+        # Verify inputs: activation (INT16) and weight (INT8)
+        self.assertEqual(len(annotation.input_qspec_map), 2)
+        for input_node, input_qspec in annotation.input_qspec_map.items():
+            if input_node == linear_node.args[0]:
+                # Activation input - should be INT16
+                self.assertEqual(input_qspec, qconfig_A16.input_activation)
+            elif input_node == linear_node.args[1]:
+                # Weight - should be INT8
+                self.assertEqual(input_qspec, qconfig_A16.weight)
+
 
 class QuantizerOpsPreserveTest(unittest.TestCase):
     def test_mixed_w8a32_ops_to_preserve(self) -> None:

From 7c2f3fbde3fba7a9a380544d418fc745ac3723d2 Mon Sep 17 00:00:00 2001
From: Rahul Chandra
Date: Mon, 8 Dec 2025 14:06:36 -0800
Subject: [PATCH 3/4] Creating Parameterized Test For Quantizers For Easier Testing (#16098)

Summary:
We consolidate the two tests we created into a single testing function using
parameterization.
This will make testing future Quantizers much easier and significantly
reduce code duplication.

Reviewed By: hsharma35, zonglinpeng

Differential Revision: D88054917
---
 backends/cadence/aot/TARGETS                  |  1 +
 .../cadence/aot/tests/test_quantizer_ops.py   | 98 ++++++++++++-------
 2 files changed, 66 insertions(+), 33 deletions(-)

diff --git a/backends/cadence/aot/TARGETS b/backends/cadence/aot/TARGETS
index 8363f022946..8c24fd88af4 100644
--- a/backends/cadence/aot/TARGETS
+++ b/backends/cadence/aot/TARGETS
@@ -640,6 +640,7 @@ python_unittest(
     ],
     typing = True,
     deps = [
+        "fbsource//third-party/pypi/parameterized:parameterized",
         "//caffe2:torch",
         "//executorch/backends/cadence/aot:graph_builder",
         "//executorch/backends/cadence/aot/quantizer:quantizer",
diff --git a/backends/cadence/aot/tests/test_quantizer_ops.py b/backends/cadence/aot/tests/test_quantizer_ops.py
index d1f5389ddb1..8f69eb270ce 100644
--- a/backends/cadence/aot/tests/test_quantizer_ops.py
+++ b/backends/cadence/aot/tests/test_quantizer_ops.py
@@ -7,6 +7,7 @@
 # pyre-strict
 
 import unittest
+from typing import Callable
 
 import torch
 from executorch.backends.cadence.aot.graph_builder import GraphBuilder
@@ -15,6 +16,7 @@
 from executorch.backends.cadence.aot.quantizer.quantizer import (
     CadenceAtenQuantizer,
     CadenceDefaultQuantizer,
+    CadenceQuantizer,
     CadenceW8A32MixedQuantizer,
     CadenceWith16BitLinearActivationsQuantizer,
     CadenceWith16BitMatmulActivationsQuantizer,
@@ -22,11 +24,19 @@
     qconfig_A8W8,
 )
 from executorch.exir.pass_base import NodeMetadata
+from parameterized import parameterized
+from torch._ops import OpOverload
 from torchao.quantization.pt2e.quantizer.quantizer import (
     Q_ANNOTATION_KEY,
     QuantizationAnnotation,
+    QuantizationSpec,
 )
 
+# Type alias for graph builder functions
+GraphBuilderFn = Callable[
+    ["QuantizerAnnotationTest"], tuple[torch.fx.GraphModule, torch.fx.Node]
+]
+
 
 class QuantizerAnnotationTest(unittest.TestCase):
     """Unit tests for verifying quantizer annotations are correctly applied."""
@@ -53,22 +63,6 @@ def _build_matmul_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
         self.assertEqual(len(matmul_nodes), 1, "Should find exactly one matmul node")
         return gm, matmul_nodes[0]
 
-    def test_matmul_16bit_quantizer_annotation(self) -> None:
-        """Test that CadenceWith16BitMatmulActivationsQuantizer correctly annotates matmul."""
-        gm, matmul_node = self._build_matmul_graph()
-
-        quantizer = CadenceWith16BitMatmulActivationsQuantizer()
-        quantizer.annotate(gm)
-
-        annotation: QuantizationAnnotation = matmul_node.meta[Q_ANNOTATION_KEY]
-        self.assertTrue(annotation._annotated)
-
-        self.assertEqual(annotation.output_qspec, qconfig_A16.output_activation)
-
-        self.assertEqual(len(annotation.input_qspec_map), 2)
-        for _, input_qspec in annotation.input_qspec_map.items():
-            self.assertEqual(input_qspec, qconfig_A16.input_activation)
-
     def _build_linear_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
         """Build a simple graph with a linear operation (no bias)."""
         builder = GraphBuilder()
@@ -91,28 +85,66 @@ def _build_linear_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
         self.assertEqual(len(linear_nodes), 1, "Should find exactly one linear node")
         return gm, linear_nodes[0]
 
-    def test_linear_16bit_quantizer_annotation(self) -> None:
-        """Test that CadenceWith16BitLinearActivationsQuantizer correctly annotates linear."""
-        gm, linear_node = self._build_linear_graph()
+    @parameterized.expand(
+        [
+            (
+                "matmul_A16",
+                lambda self: self._build_matmul_graph(),
+                CadenceWith16BitMatmulActivationsQuantizer(),
+                torch.ops.aten.matmul.default,
+                qconfig_A16.output_activation,
+                # For matmul, both inputs are activations
+                [qconfig_A16.input_activation, qconfig_A16.input_activation],
+            ),
+            (
+                "linear_A16",
+                lambda self: self._build_linear_graph(),
+                CadenceWith16BitLinearActivationsQuantizer(),
+                torch.ops.aten.linear.default,
+                qconfig_A16.output_activation,
+                # For linear: [input_activation, weight]
+                [qconfig_A16.input_activation, qconfig_A16.weight],
+            ),
+        ]
+    )
+    def test_quantizer_annotation(
+        self,
+        name: str,
+        graph_builder_fn: GraphBuilderFn,
+        quantizer: CadenceQuantizer,
+        target: OpOverload,
+        expected_output_qspec: QuantizationSpec,
+        expected_input_qspecs: list[QuantizationSpec],
+    ) -> None:
+        """Parameterized test for quantizer annotations."""
+        gm, op_node = graph_builder_fn(self)
 
-        quantizer = CadenceWith16BitLinearActivationsQuantizer()
         quantizer.annotate(gm)
 
-        annotation: QuantizationAnnotation = linear_node.meta[Q_ANNOTATION_KEY]
+        annotation: QuantizationAnnotation = op_node.meta[Q_ANNOTATION_KEY]
         self.assertTrue(annotation._annotated)
 
-        # Verify output is annotated with qconfig_A16.output_activation (INT16)
-        self.assertEqual(annotation.output_qspec, qconfig_A16.output_activation)
-
-        # Verify inputs: activation (INT16) and weight (INT8)
-        self.assertEqual(len(annotation.input_qspec_map), 2)
-        for input_node, input_qspec in annotation.input_qspec_map.items():
-            if input_node == linear_node.args[0]:
-                # Activation input - should be INT16
-                self.assertEqual(input_qspec, qconfig_A16.input_activation)
-            elif input_node == linear_node.args[1]:
-                # Weight - should be INT8
-                self.assertEqual(input_qspec, qconfig_A16.weight)
+        # Verify output annotation
+        self.assertEqual(annotation.output_qspec, expected_output_qspec)
+
+        # Verify input annotations
+        # Build actual_specs in the fixed order defined by op_node.args
+        self.assertEqual(len(annotation.input_qspec_map), len(expected_input_qspecs))
+        actual_specs = []
+        for i in range(len(expected_input_qspecs)):
+            arg = op_node.args[i]
+            assert isinstance(arg, torch.fx.Node)
+            actual_specs.append(annotation.input_qspec_map[arg])
+
+        # Compare expected vs actual specs
+        for i, (expected, actual) in enumerate(
+            zip(expected_input_qspecs, actual_specs)
+        ):
+            self.assertEqual(
+                actual,
+                expected,
+                f"Input qspec mismatch at index {i}",
+            )
 
 
 class QuantizerOpsPreserveTest(unittest.TestCase):

From 34d959726a620654690045a3c1d1e7cf3972347e Mon Sep 17 00:00:00 2001
From: Rahul Chandra
Date: Mon, 8 Dec 2025 14:06:36 -0800
Subject: [PATCH 4/4] Adding Test To Ensure All Future Quantizers Are Tested (#16099)

Summary:
We first create a list of quantizers that are currently not tested (we'll
gradually reduce this to zero), and then we add a test to ensure that all
future quantizers get tested through this framework. To do this, we needed
to refactor how the current test is set up, specifically the
parameterization.
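With this framework in place, covering a future quantizer should only require
appending one tuple to QUANTIZER_ANNOTATION_TEST_CASES (plus a graph-builder
helper if a suitable one doesn't exist yet). As a rough sketch, a hypothetical
conv entry could look like the following; _build_conv_graph and the expected
qspecs are illustrative placeholders, not part of this diff:

    (
        "conv_A16",
        # Hypothetical helper, analogous to _build_matmul_graph/_build_linear_graph
        lambda self: self._build_conv_graph(),
        CadenceWith16BitConvActivationsQuantizer(),
        torch.ops.aten.conv1d.default,
        qconfig_A16.output_activation,
        # Assumed input order: [input_activation, weight]
        [qconfig_A16.input_activation, qconfig_A16.weight],
    ),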
Reviewed By: mcremon-meta, zonglinpeng, hsharma35

Differential Revision: D88055443
---
 .../cadence/aot/tests/test_quantizer_ops.py   | 135 +++++++++++++-----
 1 file changed, 100 insertions(+), 35 deletions(-)

diff --git a/backends/cadence/aot/tests/test_quantizer_ops.py b/backends/cadence/aot/tests/test_quantizer_ops.py
index 8f69eb270ce..c0c926548ed 100644
--- a/backends/cadence/aot/tests/test_quantizer_ops.py
+++ b/backends/cadence/aot/tests/test_quantizer_ops.py
@@ -6,20 +6,28 @@
 # pyre-strict
 
+import inspect
 import unittest
 from typing import Callable
 
 import torch
 from executorch.backends.cadence.aot.graph_builder import GraphBuilder
+from executorch.backends.cadence.aot.quantizer import quantizer as quantizer_module
 from executorch.backends.cadence.aot.quantizer.patterns import AddmmPattern
 from executorch.backends.cadence.aot.quantizer.quantizer import (
     CadenceAtenQuantizer,
     CadenceDefaultQuantizer,
+    CadenceFusedConvReluQuantizer,
+    CadenceNopQuantizer,
     CadenceQuantizer,
     CadenceW8A32MixedQuantizer,
+    CadenceWakeWordQuantizer,
+    CadenceWith16BitConvActivationsQuantizer,
     CadenceWith16BitLinearActivationsQuantizer,
     CadenceWith16BitMatmulActivationsQuantizer,
+    CadenceWithLayerNormQuantizer,
+    CadenceWithSoftmaxQuantizer,
     qconfig_A16,
     qconfig_A8W8,
 )
@@ -32,12 +40,67 @@
     QuantizationSpec,
 )
 
-# Type alias for graph builder functions
+# Type alias for graph builder functions.
+# These functions take a test instance and return a graph module and the target op node.
 GraphBuilderFn = Callable[
     ["QuantizerAnnotationTest"], tuple[torch.fx.GraphModule, torch.fx.Node]
 ]
 
+
+# Quantizers intentionally excluded from annotation testing.
+# These should be explicitly justified when added.
+EXCLUDED_FROM_ANNOTATION_TESTING: set[type[CadenceQuantizer]] = {
+    CadenceDefaultQuantizer,  # TODO: T247438143 Add test coverage
+    CadenceFusedConvReluQuantizer,  # TODO: T247438151 Add test coverage
+    CadenceNopQuantizer,  # No-op quantizer, doesn't annotate anything
+    CadenceW8A32MixedQuantizer,  # TODO: T247438158 Add test coverage
+    CadenceWakeWordQuantizer,  # TODO: T247438162 Add test coverage
+    CadenceWith16BitConvActivationsQuantizer,  # TODO: T247438221 Add test coverage
+    CadenceWithLayerNormQuantizer,  # TODO: T247438410 Add test coverage
+    CadenceWithSoftmaxQuantizer,  # TODO: T247438418 Add test coverage
+}
+
+
+# Test case definitions for quantizer annotation tests.
+# Format: (name, graph_builder_fn, quantizer_instance, target_op, expected_output_qspec, expected_input_qspecs)
+# Adding a new quantizer test only requires adding a tuple to this list.
+QUANTIZER_ANNOTATION_TEST_CASES: list[
+    tuple[
+        str,
+        GraphBuilderFn,
+        CadenceQuantizer,
+        OpOverload,
+        QuantizationSpec,
+        list[QuantizationSpec],
+    ]
+] = [
+    (
+        "matmul_A16",
+        lambda self: self._build_matmul_graph(),
+        CadenceWith16BitMatmulActivationsQuantizer(),
+        torch.ops.aten.matmul.default,
+        qconfig_A16.output_activation,
+        # For matmul, both inputs are activations
+        [qconfig_A16.input_activation, qconfig_A16.input_activation],
+    ),
+    (
+        "linear_A16",
+        lambda self: self._build_linear_graph(),
+        CadenceWith16BitLinearActivationsQuantizer(),
+        torch.ops.aten.linear.default,
+        qconfig_A16.output_activation,
+        # For linear: [input_activation, weight]
+        [qconfig_A16.input_activation, qconfig_A16.weight],
+    ),
+]
+
+# Derive the set of tested quantizer classes from the test cases.
+# This ensures TESTED_QUANTIZER_CLASSES stays in sync with actual tests.
+TESTED_QUANTIZER_CLASSES: set[type[CadenceQuantizer]] = {
+    type(case[2]) for case in QUANTIZER_ANNOTATION_TEST_CASES
+}
+
 
 class QuantizerAnnotationTest(unittest.TestCase):
     """Unit tests for verifying quantizer annotations are correctly applied."""
 
@@ -85,28 +148,7 @@ def _build_linear_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
         self.assertEqual(len(linear_nodes), 1, "Should find exactly one linear node")
         return gm, linear_nodes[0]
 
-    @parameterized.expand(
-        [
-            (
-                "matmul_A16",
-                lambda self: self._build_matmul_graph(),
-                CadenceWith16BitMatmulActivationsQuantizer(),
-                torch.ops.aten.matmul.default,
-                qconfig_A16.output_activation,
-                # For matmul, both inputs are activations
-                [qconfig_A16.input_activation, qconfig_A16.input_activation],
-            ),
-            (
-                "linear_A16",
-                lambda self: self._build_linear_graph(),
-                CadenceWith16BitLinearActivationsQuantizer(),
-                torch.ops.aten.linear.default,
-                qconfig_A16.output_activation,
-                # For linear: [input_activation, weight]
-                [qconfig_A16.input_activation, qconfig_A16.weight],
-            ),
-        ]
-    )
+    @parameterized.expand(QUANTIZER_ANNOTATION_TEST_CASES)
     def test_quantizer_annotation(
         self,
         name: str,
@@ -128,24 +170,47 @@ def test_quantizer_annotation(
         self.assertEqual(annotation.output_qspec, expected_output_qspec)
 
         # Verify input annotations
-        # Build actual_specs in the fixed order defined by op_node.args
         self.assertEqual(len(annotation.input_qspec_map), len(expected_input_qspecs))
-        actual_specs = []
-        for i in range(len(expected_input_qspecs)):
-            arg = op_node.args[i]
-            assert isinstance(arg, torch.fx.Node)
-            actual_specs.append(annotation.input_qspec_map[arg])
-
-        # Compare expected vs actual specs
-        for i, (expected, actual) in enumerate(
-            zip(expected_input_qspecs, actual_specs)
+        for i, (input_node, input_qspec) in enumerate(
+            annotation.input_qspec_map.items()
         ):
+            expected_arg = op_node.args[i]
+            assert isinstance(expected_arg, torch.fx.Node)
+            self.assertEqual(
+                input_node,
+                expected_arg,
+                f"Input node mismatch at index {i}",
+            )
             self.assertEqual(
-                actual,
-                expected,
+                input_qspec,
+                expected_input_qspecs[i],
                 f"Input qspec mismatch at index {i}",
             )
 
+    def test_all_quantizers_have_annotation_tests(self) -> None:
+        """Ensure every CadenceQuantizer subclass is either tested or explicitly excluded."""
+        # Get all CadenceQuantizer subclasses defined in the quantizer module
+        all_quantizers: set[type[CadenceQuantizer]] = set()
+        for _, obj in inspect.getmembers(quantizer_module, inspect.isclass):
+            if (
+                issubclass(obj, CadenceQuantizer)
+                and obj is not CadenceQuantizer
+                and obj.__module__ == quantizer_module.__name__
+            ):
+                all_quantizers.add(obj)
+
+        # Check for missing tests
+        untested = (
+            all_quantizers - TESTED_QUANTIZER_CLASSES - EXCLUDED_FROM_ANNOTATION_TESTING
+        )
+        if untested:
+            untested_names = sorted(cls.__name__ for cls in untested)
+            self.fail(
+                f"The following CadenceQuantizer subclasses are not tested in "
+                f"test_quantizer_annotation and not in EXCLUDED_FROM_ANNOTATION_TESTING: "
+                f"{untested_names}. Please add test cases or explicitly exclude them."
+            )
+
 
 class QuantizerOpsPreserveTest(unittest.TestCase):
     def test_mixed_w8a32_ops_to_preserve(self) -> None:
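Note: once this guard lands, adding a quantizer with neither a test case nor an
exclusion entry should fail test_all_quantizers_have_annotation_tests. For a
hypothetical CadenceWith16BitAddActivationsQuantizer (placeholder name) added
without coverage, the self.fail above would report roughly:

    The following CadenceQuantizer subclasses are not tested in
    test_quantizer_annotation and not in EXCLUDED_FROM_ANNOTATION_TESTING:
    ['CadenceWith16BitAddActivationsQuantizer']. Please add test cases or
    explicitly exclude them.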