From 5223eda42b90c17d9c586c45a25217898d0929e3 Mon Sep 17 00:00:00 2001 From: ZKK Date: Fri, 16 May 2025 14:16:19 +0800 Subject: [PATCH] Rewrite Exists to add scalar project --- .../sql/catalyst/optimizer/Optimizer.scala | 1 + .../catalyst/optimizer/finishAnalysis.scala | 24 +++++++ .../RewriteExistsAsScalarProjectSuite.scala | 63 +++++++++++++++++++ 3 files changed, 88 insertions(+) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/RewriteExistsAsScalarProjectSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 66c3bfb46530a..ee3c121c22885 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -318,6 +318,7 @@ abstract class Optimizer(catalogManager: CatalogManager) EliminateSQLFunctionNode, ReplaceExpressions, RewriteNonCorrelatedExists, + RewriteExistsAsScalarProject, PullOutGroupingExpressions, // Put `InsertMapSortInGroupingExpressions` after `PullOutGroupingExpressions`, // so the grouping keys can only be attribute and literal which makes diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala index 21e09f2e56d19..168013776e868 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala @@ -76,6 +76,30 @@ object RewriteNonCorrelatedExists extends Rule[LogicalPlan] { } } + +/** + * Rewrite exists subquery to add a scalar project + * WHERE EXISTS (SELECT * FROM TABLE B WHERE COL1 > 10) + * will be rewritten to + * WHERE EXISTS (SELECT 1 FROM (SELECT * FROM TABLE B WHERE COL1 > 10)) + */ +object RewriteExistsAsScalarProject extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressionsWithPruning( + _.containsPattern(EXISTS_SUBQUERY)) { + case exists: Exists if !hasSelectScalar(exists.plan) => + exists.copy(plan = Project(Seq(Alias(Literal(1), "scalar")()), exists.plan)) + } + + private def hasSelectScalar(plan: LogicalPlan): Boolean = plan match { + case Project(projectList, _) => + projectList.exists { + case Alias(Literal(_, IntegerType), _) => true + case _ => false + } + case _ => false + } +} + /** * Computes expressions in inline tables. This rule is supposed to be called at the very end * of the analysis phase, given that all the expressions need to be fully resolved/replaced diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RewriteExistsAsScalarProjectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RewriteExistsAsScalarProjectSuite.scala new file mode 100644 index 0000000000000..02788bfb14120 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/RewriteExistsAsScalarProjectSuite.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSession + +class RewriteExistsAsScalarProjectSuite extends QueryTest with SharedSparkSession { + test("SPARK-50873: Prune column after RewriteSubquery for DSV2") { + import org.apache.spark.sql.functions._ + withTempPath { dir => + spark.range(10) + .withColumn("userid", col("id") + 1) + .withColumn("price", col("id") + 2) + .write + .mode("overwrite") + .parquet(dir.getCanonicalPath + "/sales") + spark.range(5) + .withColumn("age", col("id") + 1) + .withColumn("address", col("id") + 2) + .write.mode("overwrite").parquet(dir.getCanonicalPath + "/customer") + withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key-> "") { + spark.read.parquet(dir.getCanonicalPath + "/sales").createOrReplaceTempView("sales") + spark.read.parquet(dir.getCanonicalPath + "/customer").createOrReplaceTempView("customer") + + val df = sql( + """ + |select * from sales + |where exists (select * from customer where sales.userid == customer.id) + |""".stripMargin) + + withClue(df.queryExecution) { + val plan = df.queryExecution.optimizedPlan + val allRelationV2 = plan.collectWithSubqueries { case b: DataSourceV2ScanRelation => b } + val customerRelation = allRelationV2.find(_.relation.name.endsWith("customer")) + assert(customerRelation.isDefined, "Customer relation not found in the plan") + + val columns = customerRelation.get.output + val columnNames = columns.map(_.name) + assert(columnNames == Seq("id"), + s"Expected only 'id' column in customer relation, " + + s"but found: ${columnNames.mkString(", ")}") + } + } + } + } +}