From 49d197a8e6368802f3a0c86a52cabeaffb154cc7 Mon Sep 17 00:00:00 2001
From: Yi-Hong Lyu
Date: Fri, 10 May 2024 16:07:42 -0700
Subject: [PATCH] Enable ClipQuantFusion exclusively on CPU EP (#20627)

### Motivation and Context

The Intel NPU does not support 16-bit int quantized operators. Consequently,
the execution provider removes the QuantizeLinear/DeQuantizeLinear (Q/DQ)
operators from node units and executes the operation as FP16 in the backend.
However, if a Clip operator was fused into a Q operator in the node unit,
removing the Q/DQ operators introduces inaccuracies because the effect of the
original Clip operator is lost.

Consider the following example:
- FP32 model: -> Op_FP32 -> Clip ->
- QDQ model: -> (DQ -> Op_FP32 -> Q) -> (DQ' -> Clip -> Q') ->
- After ClipQuantFusion: -> (DQ -> Op_FP32 -> Q) -> (DQ' -> Q') ->
- After the Intel execution provider strips Q/DQ: -> Op_FP16 ->

To solve this issue, we have enabled ClipQuantFusion exclusively on the CPU
execution provider.
---
 onnxruntime/core/optimizer/graph_transformer_utils.cc     | 3 +--
 .../core/optimizer/qdq_transformer/clip_quantizelinear.cc | 4 +++-
 onnxruntime/test/optimizer/qdq_transformer_test.cc        | 2 +-
 3 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc
index 9341ade0a2f1d..66ac676fa2f6d 100644
--- a/onnxruntime/core/optimizer/graph_transformer_utils.cc
+++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc
@@ -132,14 +132,13 @@ InlinedVector<std::unique_ptr<RewriteRule>> GenerateRewriteRules(
       rules.push_back(std::make_unique());
       rules.push_back(std::make_unique());
       rules.push_back(std::make_unique());
-      rules.push_back(std::make_unique<ClipQuantFusion>());
       rules.push_back(std::make_unique());
       rules.push_back(std::make_unique());
       break;

     case TransformerLevel::Level2:
+      rules.push_back(std::make_unique<ClipQuantFusion>());
       rules.push_back(std::make_unique());
-      // No level2 rules available today
       break;

     case TransformerLevel::Level3:
diff --git a/onnxruntime/core/optimizer/qdq_transformer/clip_quantizelinear.cc b/onnxruntime/core/optimizer/qdq_transformer/clip_quantizelinear.cc
index 50653b368857d..72ca1cb74f1fd 100644
--- a/onnxruntime/core/optimizer/qdq_transformer/clip_quantizelinear.cc
+++ b/onnxruntime/core/optimizer/qdq_transformer/clip_quantizelinear.cc
@@ -83,13 +83,15 @@ static bool GetQConstantLowerUpper(const Graph& graph, const Node& node, float&
 bool ClipQuantFusion::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& /*logger*/) const {
   if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Clip", {1, 6, 11, 12, 13}) ||
+      !graph_utils::IsSupportedProvider(node, {kCpuExecutionProvider}) ||
       !optimizer_utils::CheckOutputEdges(graph, node, 1)) {
     return false;
   }

   // if Clip is followed by QuantizeLinear, it can be fused into QuantizeLinear potentially
   const auto& next_node = *node.OutputNodesBegin();
-  if (!QDQ::MatchQNode(next_node)) {
+  if (!graph_utils::IsSupportedProvider(next_node, {kCpuExecutionProvider}) ||
+      !QDQ::MatchQNode(next_node)) {
     return false;
   }

diff --git a/onnxruntime/test/optimizer/qdq_transformer_test.cc b/onnxruntime/test/optimizer/qdq_transformer_test.cc
index ae263a7ca7d35..8c138b22bd52b 100644
--- a/onnxruntime/test/optimizer/qdq_transformer_test.cc
+++ b/onnxruntime/test/optimizer/qdq_transformer_test.cc
@@ -2565,7 +2565,7 @@ TEST(QDQTransformerTests, Clip) {
     TransformerTester(build_test_case, check_clip_graph,
                       TransformerLevel::Default,
-                      TransformerLevel::Level1,
+                      TransformerLevel::Level2,
                       opset_version,
                      epsilon,
                      epsilon);
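
The fusion itself rests on a simple invariant, checked via the
`GetQConstantLowerUpper` helper visible in the hunk above: QuantizeLinear
already saturates values to its representable range, so a preceding Clip is
redundant whenever its bounds are at least as wide as that range. Below is a
minimal, self-contained C++ sketch of that invariant for uint8 quantization.
It is illustrative only, not ORT code; `QuantizeU8` and
`ClipIsRedundantBeforeQ` are hypothetical names.

```cpp
// Standalone illustration (not part of the patch): why Clip can be folded
// into a following QuantizeLinear. QuantizeLinear saturates to its
// representable range anyway, so a Clip whose bounds are no tighter than
// that range is a no-op.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <iostream>

// uint8 QuantizeLinear: scale, round (nearest-even under the default
// floating-point rounding mode), add zero point, saturate to [0, 255].
uint8_t QuantizeU8(float x, float scale, uint8_t zero_point) {
  const int32_t q = static_cast<int32_t>(std::nearbyint(x / scale)) + zero_point;
  return static_cast<uint8_t>(std::clamp<int32_t>(q, 0, 255));
}

// Fusion is safe when [clip_min, clip_max] contains the float interval that
// uint8 can represent, i.e. the Clip clamps no tighter than the saturation
// QuantizeLinear performs anyway.
bool ClipIsRedundantBeforeQ(float clip_min, float clip_max,
                            float scale, uint8_t zero_point) {
  const float q_lo = (0 - zero_point) * scale;
  const float q_hi = (255 - zero_point) * scale;
  return clip_min <= q_lo && clip_max >= q_hi;
}

int main() {
  const float scale = 1.0f / 255.0f;  // representable float range: [0, 1]
  const uint8_t zero_point = 0;

  // Clip to [0, 6] (e.g. ReLU6): wider than [0, 1], so fusion is safe.
  std::cout << ClipIsRedundantBeforeQ(0.0f, 6.0f, scale, zero_point) << "\n";  // 1

  // Clip to [0, 0.5]: tighter than [0, 1], so dropping it changes results.
  std::cout << ClipIsRedundantBeforeQ(0.0f, 0.5f, scale, zero_point) << "\n";  // 0

  // With the tighter Clip, 0.8 must quantize as 0.5 would, not as 0.8 would.
  std::cout << static_cast<int>(QuantizeU8(std::clamp(0.8f, 0.0f, 0.5f), scale, zero_point))
            << " vs "
            << static_cast<int>(QuantizeU8(0.8f, scale, zero_point)) << "\n";  // 128 vs 204
}
```

This is also why the NPU scenario in the motivation breaks: once the fusion
has delegated the Clip's clamping to Q's saturation, stripping Q' removes
exactly that saturation, so on such providers the fusion must not run.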