diff --git a/onnxruntime/python/tools/quantization/execution_providers/qnn/quant_config.py b/onnxruntime/python/tools/quantization/execution_providers/qnn/quant_config.py index 9384bab13f3cc..ea995d4707ba3 100644 --- a/onnxruntime/python/tools/quantization/execution_providers/qnn/quant_config.py +++ b/onnxruntime/python/tools/quantization/execution_providers/qnn/quant_config.py @@ -55,6 +55,7 @@ def get_qnn_qdq_config( stride: int | None = None, calibration_providers: list[str] | None = None, op_types_to_quantize: list[str] | None = None, + nodes_to_exclude: list[str] | None = None, ) -> StaticQuantConfig: """ Returns a static quantization configuration suitable for running QDQ models on QNN EP. @@ -122,6 +123,8 @@ def get_qnn_qdq_config( calibration_providers: Execution providers to run the session during calibration. Default is None which uses [ "CPUExecutionProvider" ]. op_types_to_quantize: If set to None, all operator types will be quantized except for OP_TYPES_TO_EXCLUDE + nodes_to_exclude: List of node names to exclude from quantization. The nodes in this list will be excluded from + quantization when it is not None. 
Returns: A StaticQuantConfig object @@ -167,10 +170,13 @@ def get_qnn_qdq_config( ) op_types_to_quantize_set = set(op_types_to_quantize) if op_types_to_quantize else None + nodes_to_exclude_set = set(nodes_to_exclude) if nodes_to_exclude else None for node in model.graph.node: if op_types_to_quantize_set and node.op_type not in op_types_to_quantize_set: continue + if nodes_to_exclude_set and node.name in nodes_to_exclude_set: + continue op_types.add(node.op_type) qnn_compat.process_node(node) @@ -201,6 +207,7 @@ def get_qnn_qdq_config( op_types_to_quantize=op_types_to_quantize if op_types_to_quantize else list(op_types.difference(OP_TYPES_TO_EXCLUDE)), + nodes_to_exclude=nodes_to_exclude, per_channel=per_channel, use_external_data_format=(model_has_external_data or model.ByteSize() >= MODEL_SIZE_THRESHOLD), calibration_providers=calibration_providers,