
Commit 6bd81e9

RahulC7 authored and facebook-github-bot committed
Adding Test To Ensure All Future Quantizers Are Tested (#16099)
Summary: We first create a list of quantizers that are currently not tested (we'll slowly reduce this to zero), and then we create a test to ensure that all future quantizers get tested using this framework. To do this, we needed to refactor how the current test is set up, specifically the parameterization.

Reviewed By: mcremon-meta, zonglinpeng, hsharma35

Differential Revision: D88055443
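The enforcement test works by enumerating every quantizer subclass in the quantizer module via reflection, subtracting the classes covered by the parameterized test cases and the ones explicitly excluded, and failing if anything is left over. The following is a minimal, self-contained sketch of that coverage-enforcement pattern, using simplified placeholder classes rather than the actual Cadence quantizers (the real code in the diff below is the authoritative version):

# A hypothetical, simplified sketch of the coverage-enforcement pattern; the
# class names below are placeholders, not the actual Cadence quantizer classes.
import inspect
import sys
import unittest


class BaseQuantizer:
    """Stand-in for the CadenceQuantizer base class."""


class MatmulQuantizer(BaseQuantizer):
    """Stand-in for a quantizer that already has an annotation test."""


class ExperimentalQuantizer(BaseQuantizer):
    """Stand-in for a quantizer that is deliberately excluded for now."""


# Classes covered by the parameterized test cases (derived from the case list
# in the real code so it cannot drift out of sync).
TESTED_CLASSES: set[type[BaseQuantizer]] = {MatmulQuantizer}

# Classes intentionally excluded; the goal is to shrink this set to zero.
EXCLUDED_FROM_TESTING: set[type[BaseQuantizer]] = {ExperimentalQuantizer}


class CoverageTest(unittest.TestCase):
    def test_every_quantizer_is_tested_or_excluded(self) -> None:
        module = sys.modules[__name__]
        # Discover every BaseQuantizer subclass defined in this module.
        all_quantizers = {
            obj
            for _, obj in inspect.getmembers(module, inspect.isclass)
            if issubclass(obj, BaseQuantizer)
            and obj is not BaseQuantizer
            and obj.__module__ == module.__name__
        }
        # Anything neither tested nor explicitly excluded fails the test.
        untested = all_quantizers - TESTED_CLASSES - EXCLUDED_FROM_TESTING
        self.assertFalse(
            untested,
            f"Add a test case or an exclusion for: {sorted(c.__name__ for c in untested)}",
        )


if __name__ == "__main__":
    unittest.main()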
1 parent: b996edd · commit 6bd81e9

File tree: 1 file changed


backends/cadence/aot/tests/test_quantizer_ops.py

Lines changed: 100 additions & 35 deletions
@@ -6,20 +6,28 @@
 
 # pyre-strict
 
+import inspect
 import unittest
 from typing import Callable
 
 import torch
 from executorch.backends.cadence.aot.graph_builder import GraphBuilder
+from executorch.backends.cadence.aot.quantizer import quantizer as quantizer_module
 from executorch.backends.cadence.aot.quantizer.patterns import AddmmPattern
 
 from executorch.backends.cadence.aot.quantizer.quantizer import (
     CadenceAtenQuantizer,
     CadenceDefaultQuantizer,
+    CadenceFusedConvReluQuantizer,
+    CadenceNopQuantizer,
     CadenceQuantizer,
     CadenceW8A32MixedQuantizer,
+    CadenceWakeWordQuantizer,
+    CadenceWith16BitConvActivationsQuantizer,
     CadenceWith16BitLinearActivationsQuantizer,
     CadenceWith16BitMatmulActivationsQuantizer,
+    CadenceWithLayerNormQuantizer,
+    CadenceWithSoftmaxQuantizer,
     qconfig_A16,
     qconfig_A8W8,
 )
@@ -32,12 +40,67 @@
     QuantizationSpec,
 )
 
-# Type alias for graph builder functions
+# Type alias for graph builder functions.
+# These functions take a test instance and return a graph module and the target op node.
 GraphBuilderFn = Callable[
     ["QuantizerAnnotationTest"], tuple[torch.fx.GraphModule, torch.fx.Node]
 ]
 
 
+# Quantizers intentionally excluded from annotation testing.
+# These should be explicitly justified when added.
+EXCLUDED_FROM_ANNOTATION_TESTING: set[type[CadenceQuantizer]] = {
+    CadenceDefaultQuantizer,  # TODO: T247438143 Add test coverage
+    CadenceFusedConvReluQuantizer,  # TODO: T247438151 Add test coverage
+    CadenceNopQuantizer,  # No-op quantizer, doesn't annotate anything
+    CadenceW8A32MixedQuantizer,  # TODO: T247438158 Add test coverage
+    CadenceWakeWordQuantizer,  # TODO: T247438162 Add test coverage
+    CadenceWith16BitConvActivationsQuantizer,  # TODO: T247438221 Add test coverage
+    CadenceWithLayerNormQuantizer,  # TODO: T247438410 Add test coverage
+    CadenceWithSoftmaxQuantizer,  # TODO: T247438418 Add test coverage
+}
+
+
+# Test case definitions for quantizer annotation tests.
+# Format: (name, graph_builder_fn, quantizer_instance, target_op, expected_output_qspec, expected_input_qspecs)
+# Adding a new quantizer test only requires adding a tuple to this list.
+QUANTIZER_ANNOTATION_TEST_CASES: list[
+    tuple[
+        str,
+        GraphBuilderFn,
+        CadenceQuantizer,
+        OpOverload,
+        QuantizationSpec,
+        list[QuantizationSpec],
+    ]
+] = [
+    (
+        "matmul_A16",
+        lambda self: self._build_matmul_graph(),
+        CadenceWith16BitMatmulActivationsQuantizer(),
+        torch.ops.aten.matmul.default,
+        qconfig_A16.output_activation,
+        # For matmul, both inputs are activations
+        [qconfig_A16.input_activation, qconfig_A16.input_activation],
+    ),
+    (
+        "linear_A16",
+        lambda self: self._build_linear_graph(),
+        CadenceWith16BitLinearActivationsQuantizer(),
+        torch.ops.aten.linear.default,
+        qconfig_A16.output_activation,
+        # For linear: [input_activation, weight]
+        [qconfig_A16.input_activation, qconfig_A16.weight],
+    ),
+]
+
+# Derive the set of tested quantizer classes from the test cases.
+# This ensures TESTED_QUANTIZER_CLASSES stays in sync with actual tests.
+TESTED_QUANTIZER_CLASSES: set[type[CadenceQuantizer]] = {
+    type(case[2]) for case in QUANTIZER_ANNOTATION_TEST_CASES
+}
+
+
 class QuantizerAnnotationTest(unittest.TestCase):
     """Unit tests for verifying quantizer annotations are correctly applied."""
 
@@ -85,28 +148,7 @@ def _build_linear_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
         self.assertEqual(len(linear_nodes), 1, "Should find exactly one linear node")
         return gm, linear_nodes[0]
 
-    @parameterized.expand(
-        [
-            (
-                "matmul_A16",
-                lambda self: self._build_matmul_graph(),
-                CadenceWith16BitMatmulActivationsQuantizer(),
-                torch.ops.aten.matmul.default,
-                qconfig_A16.output_activation,
-                # For matmul, both inputs are activations
-                [qconfig_A16.input_activation, qconfig_A16.input_activation],
-            ),
-            (
-                "linear_A16",
-                lambda self: self._build_linear_graph(),
-                CadenceWith16BitLinearActivationsQuantizer(),
-                torch.ops.aten.linear.default,
-                qconfig_A16.output_activation,
-                # For linear: [input_activation, weight]
-                [qconfig_A16.input_activation, qconfig_A16.weight],
-            ),
-        ]
-    )
+    @parameterized.expand(QUANTIZER_ANNOTATION_TEST_CASES)
     def test_quantizer_annotation(
         self,
         name: str,
@@ -128,24 +170,47 @@ def test_quantizer_annotation(
         self.assertEqual(annotation.output_qspec, expected_output_qspec)
 
         # Verify input annotations
-        # Build actual_specs in the fixed order defined by op_node.args
         self.assertEqual(len(annotation.input_qspec_map), len(expected_input_qspecs))
-        actual_specs = []
-        for i in range(len(expected_input_qspecs)):
-            arg = op_node.args[i]
-            assert isinstance(arg, torch.fx.Node)
-            actual_specs.append(annotation.input_qspec_map[arg])
-
-        # Compare expected vs actual specs
-        for i, (expected, actual) in enumerate(
-            zip(expected_input_qspecs, actual_specs)
+        for i, (input_node, input_qspec) in enumerate(
+            annotation.input_qspec_map.items()
         ):
+            expected_arg = op_node.args[i]
+            assert isinstance(expected_arg, torch.fx.Node)
+            self.assertEqual(
+                input_node,
+                expected_arg,
+                f"Input node mismatch at index {i}",
+            )
             self.assertEqual(
-                actual,
-                expected,
+                input_qspec,
+                expected_input_qspecs[i],
                 f"Input qspec mismatch at index {i}",
             )
 
+    def test_all_quantizers_have_annotation_tests(self) -> None:
+        """Ensure every CadenceQuantizer subclass is either tested or explicitly excluded."""
+        # Get all CadenceQuantizer subclasses defined in the quantizer module
+        all_quantizers: set[type[CadenceQuantizer]] = set()
+        for _, obj in inspect.getmembers(quantizer_module, inspect.isclass):
+            if (
+                issubclass(obj, CadenceQuantizer)
+                and obj is not CadenceQuantizer
+                and obj.__module__ == quantizer_module.__name__
+            ):
+                all_quantizers.add(obj)
+
+        # Check for missing tests
+        untested = (
+            all_quantizers - TESTED_QUANTIZER_CLASSES - EXCLUDED_FROM_ANNOTATION_TESTING
+        )
+        if untested:
+            untested_names = sorted(cls.__name__ for cls in untested)
+            self.fail(
+                f"The following CadenceQuantizer subclasses are not tested in "
+                f"test_quantizer_annotation and not in EXCLUDED_FROM_ANNOTATION_TESTING: "
+                f"{untested_names}. Please add test cases or explicitly exclude them."
            )
+
 
 class QuantizerOpsPreserveTest(unittest.TestCase):
     def test_mixed_w8a32_ops_to_preserve(self) -> None:

0 commit comments
