
Commit ea86140: introduce join pushdown for dsv2
Committed May 16, 2025 · 1 parent 5b07e52

19 files changed: +840 −18 lines
 

‎common/utils/src/main/resources/error/error-conditions.json

Lines changed: 5 additions & 0 deletions
@@ -9403,6 +9403,11 @@
       "The number of fields (<numFields>) in the partition identifier is not equal to the partition schema length (<schemaLen>). The identifier might not refer to one partition."
     ]
   },
+  "_LEGACY_ERROR_TEMP_3209" : {
+    "message" : [
+      "Unexpected join type: <joinType>"
+    ]
+  },
   "_LEGACY_ERROR_TEMP_3215" : {
     "message" : [
       "Expected a Boolean type expression in replaceNullWithFalse, but got the type <dataType> in <expr>."
Inner.java (new file)

Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.join;
+
+import org.apache.spark.annotation.Evolving;
+
+/**
+ * Represents an INNER join in the public Join type API.
+ *
+ * @since 4.0.0
+ */
+@Evolving
+public final class Inner implements JoinType { }
JoinColumn.java (new file)

Lines changed: 51 additions & 0 deletions
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.join;
+
+import org.apache.spark.annotation.Evolving;
+import org.apache.spark.sql.connector.expressions.Expression;
+import org.apache.spark.sql.connector.expressions.NamedReference;
+
+/**
+ * Represents a column reference used in DSv2 Join pushdown.
+ *
+ * @since 4.0.0
+ */
+@Evolving
+public final class JoinColumn implements NamedReference {
+  public JoinColumn(String[] qualifier, String name, Boolean isInLeftSideOfJoin) {
+    this.qualifier = qualifier;
+    this.name = name;
+    this.isInLeftSideOfJoin = isInLeftSideOfJoin;
+  }
+
+  public String[] qualifier;
+  public String name;
+  public Boolean isInLeftSideOfJoin;
+
+  @Override
+  public String[] fieldNames() {
+    String[] fullyQualified = new String[qualifier.length + 1];
+    System.arraycopy(qualifier, 0, fullyQualified, 0, qualifier.length);
+    fullyQualified[qualifier.length] = name;
+    return fullyQualified;
+  }
+
+  @Override
+  public Expression[] children() { return EMPTY_EXPRESSION; }
+}
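
For illustration only (not part of the commit, and assuming the fieldNames() fix above that returns the fully qualified array), a minimal Scala sketch of how a JoinColumn combines its qualifier and name; the alias "subquery_0" is hypothetical:

import org.apache.spark.sql.connector.join.JoinColumn

// Hypothetical column DEPT coming from the left join subquery, qualified with
// the generated subquery alias "subquery_0".
val col = new JoinColumn(Array("subquery_0"), "DEPT", true)
// fieldNames() appends the name to the qualifier.
assert(col.fieldNames().sameElements(Array("subquery_0", "DEPT")))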
JoinType.java (new file)

Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.join;
+
+import org.apache.spark.annotation.Evolving;
+
+/**
+ * Base interface of the public Join type API.
+ *
+ * @since 4.0.0
+ */
+@Evolving
+public interface JoinType { }
SupportsPushDownJoin.java (new file)

Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.read;
+
+import java.util.Optional;
+
+import org.apache.spark.annotation.Evolving;
+import org.apache.spark.sql.connector.expressions.filter.Predicate;
+import org.apache.spark.sql.connector.join.JoinType;
+import org.apache.spark.sql.types.StructType;
+
+/**
+ * A mix-in interface for {@link ScanBuilder}. Data sources can implement this interface to
+ * push down join operators.
+ *
+ * @since 4.0.0
+ */
+@Evolving
+public interface SupportsPushDownJoin extends ScanBuilder {
+  boolean isRightSideCompatibleForJoin(SupportsPushDownJoin other);
+
+  boolean pushJoin(
+    SupportsPushDownJoin other,
+    JoinType joinType,
+    Optional<Predicate> condition,
+    StructType leftRequiredSchema,
+    StructType rightRequiredSchema
+  );
+}
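
A hedged sketch (not part of the commit) of how a connector's ScanBuilder might wire up this mix-in. The class MyScanBuilder and its endpoint field are hypothetical; only the two interface methods matter here:

import java.util.Optional

import org.apache.spark.sql.connector.expressions.filter.Predicate
import org.apache.spark.sql.connector.join.JoinType
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownJoin}
import org.apache.spark.sql.types.StructType

// Hypothetical connector: a join is only accepted when both sides talk to the
// same endpoint; the pushed join is remembered so build() could later emit a
// single joined query.
class MyScanBuilder(val endpoint: String) extends ScanBuilder with SupportsPushDownJoin {
  private var pushedJoin: Option[(MyScanBuilder, JoinType, Option[Predicate])] = None

  override def isRightSideCompatibleForJoin(other: SupportsPushDownJoin): Boolean =
    other match {
      case o: MyScanBuilder => o.endpoint == endpoint
      case _ => false
    }

  override def pushJoin(
      other: SupportsPushDownJoin,
      joinType: JoinType,
      condition: Optional[Predicate],
      leftRequiredSchema: StructType,
      rightRequiredSchema: StructType): Boolean = {
    // A real connector would also fold the required schemas into its output schema.
    pushedJoin = Some((other.asInstanceOf[MyScanBuilder], joinType,
      if (condition.isPresent) Some(condition.get) else None))
    true
  }

  // Minimal Scan so the sketch compiles; a real implementation would return a
  // scan over the joined query.
  override def build(): Scan = new Scan {
    override def readSchema(): StructType = new StructType()
  }
}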
JoinTypeSQLBuilder.java (new file)

Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.util;
+
+import org.apache.spark.SparkIllegalArgumentException;
+import org.apache.spark.sql.connector.join.*;
+
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * The builder to generate SQL for a specific join type.
+ *
+ * @since 4.0.0
+ */
+public class JoinTypeSQLBuilder {
+  public String build(JoinType joinType) {
+    if (joinType instanceof Inner inner) {
+      return visitInnerJoin(inner);
+    } else {
+      return visitUnexpectedJoinType(joinType);
+    }
+  }
+
+  protected String visitInnerJoin(Inner inner) {
+    return "INNER JOIN";
+  }
+
+  protected String visitUnexpectedJoinType(JoinType joinType) throws IllegalArgumentException {
+    Map<String, String> params = new HashMap<>();
+    params.put("joinType", String.valueOf(joinType));
+    throw new SparkIllegalArgumentException("_LEGACY_ERROR_TEMP_3209", params);
+  }
+}
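
For illustration, the builder can be exercised directly (this snippet is not in the diff); only the inner join type is translated by this commit, everything else goes through visitUnexpectedJoinType and throws:

import org.apache.spark.sql.connector.join.Inner
import org.apache.spark.sql.connector.util.JoinTypeSQLBuilder

val builder = new JoinTypeSQLBuilder
// The only join type supported by this commit compiles to "INNER JOIN".
assert(builder.build(new Inner) == "INNER JOIN")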

‎sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java

Lines changed: 10 additions & 0 deletions
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.connector.util;
 
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.List;
 import java.util.Map;
 import java.util.StringJoiner;
@@ -42,6 +43,7 @@
 import org.apache.spark.sql.connector.expressions.aggregate.GeneralAggregateFunc;
 import org.apache.spark.sql.connector.expressions.aggregate.Sum;
 import org.apache.spark.sql.connector.expressions.aggregate.UserDefinedAggregateFunc;
+import org.apache.spark.sql.connector.join.JoinColumn;
 import org.apache.spark.sql.types.DataType;
 
 /**
@@ -75,6 +77,8 @@ protected String escapeSpecialCharsForLikePattern(String str) {
   public String build(Expression expr) {
     if (expr instanceof Literal literal) {
       return visitLiteral(literal);
+    } else if (expr instanceof JoinColumn column) {
+      return visitJoinColumn(column);
     } else if (expr instanceof NamedReference namedReference) {
       return visitNamedReference(namedReference);
     } else if (expr instanceof Cast cast) {
@@ -174,6 +178,12 @@ protected String visitNamedReference(NamedReference namedRef) {
     return namedRef.toString();
   }
 
+  protected String visitJoinColumn(JoinColumn column) {
+    List<String> fullyQualifiedName = new ArrayList<>(Arrays.asList(column.qualifier));
+    fullyQualifiedName.add(column.name);
+    return joinListToString(fullyQualifiedName, ".", "", "");
+  }
+
   protected String visitIn(String v, List<String> list) {
     if (list.isEmpty()) {
       return "CASE WHEN " + v + " IS NULL THEN NULL ELSE FALSE END";

‎sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala

Lines changed: 8 additions & 0 deletions
@@ -394,6 +394,14 @@ case class AttributeReference(
   }
 }
 
+case class JoinColumnReference(
+    originalReference: AttributeReference,
+    isReferringColumnFromLeftSubquery: Boolean = true)
+  extends LeafExpression with Unevaluable {
+  override def nullable: Boolean = originalReference.nullable
+  override def dataType: DataType = originalReference.dataType
+}
+
 /**
  * A place holder used when printing expressions without debugging information such as the
  * expression id or the unresolved indicator.

‎sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala

Lines changed: 7 additions & 0 deletions
@@ -26,6 +26,7 @@ import org.apache.spark.sql.connector.catalog.functions.ScalarFunction
 import org.apache.spark.sql.connector.expressions.{Cast => V2Cast, Expression => V2Expression, Extract => V2Extract, FieldReference, GeneralScalarExpression, LiteralValue, NullOrdering, SortDirection, SortValue, UserDefinedScalarFunc}
 import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Avg, Count, CountStar, GeneralAggregateFunc, Max, Min, Sum, UserDefinedAggregateFunc}
 import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse, AlwaysTrue, And => V2And, Not => V2Not, Or => V2Or, Predicate => V2Predicate}
+import org.apache.spark.sql.connector.join.JoinColumn
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{BooleanType, DataType, IntegerType, StringType}
 
@@ -79,6 +80,12 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) extends L
     case Literal(true, BooleanType) => Some(new AlwaysTrue())
     case Literal(false, BooleanType) => Some(new AlwaysFalse())
     case Literal(value, dataType) => Some(LiteralValue(value, dataType))
+    case joinRefColumn: JoinColumnReference =>
+      Some(new JoinColumn(
+        Array(),
+        joinRefColumn.originalReference.name,
+        joinRefColumn.isReferringColumnFromLeftSubquery
+      ))
     case col @ ColumnOrField(nameParts) =>
       val ref = FieldReference(nameParts)
       if (isPredicate && col.dataType.isInstanceOf[BooleanType]) {

‎sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 10 additions & 0 deletions
@@ -1640,6 +1640,14 @@ object SQLConf {
     .booleanConf
     .createWithDefault(!Utils.isTesting)
 
+  val DATA_SOURCE_V2_JOIN_PUSHDOWN =
+    buildConf("spark.sql.optimizer.datasourceV2JoinPushdown")
+      .internal()
+      .doc("When this config is set to true, Spark tries to push down joins " +
+        "for DSv2 data sources in the V2ScanRelationPushdown optimization rule.")
+      .booleanConf
+      .createWithDefault(false)
+
   // This is used to set the default data source
   val DEFAULT_DATA_SOURCE_NAME = buildConf("spark.sql.sources.default")
     .doc("The default data source to use in input/output.")
@@ -5988,6 +5996,8 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
 
   def expressionTreeChangeLogLevel: Level = getConf(EXPRESSION_TREE_CHANGE_LOG_LEVEL)
 
+  def dataSourceV2JoinPushdown: Boolean = getConf(DATA_SOURCE_V2_JOIN_PUSHDOWN)
+
  def dynamicPartitionPruningEnabled: Boolean = getConf(DYNAMIC_PARTITION_PRUNING_ENABLED)
 
  def dynamicPartitionPruningUseStats: Boolean = getConf(DYNAMIC_PARTITION_PRUNING_USE_STATS)
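
As a usage note (not part of the diff, and assuming an active `spark` session), the flag can be toggled per session with the key defined above:

// Join pushdown for DSv2 sources is off by default; opt in per session.
spark.conf.set("spark.sql.optimizer.datasourceV2JoinPushdown", "true")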

‎sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala

Lines changed: 8 additions & 1 deletion
@@ -189,6 +189,12 @@ case class RowDataSourceScanExec(
       seqToString(markedFilters.toSeq)
     }
 
+    val pushedJoins = if (pushedDownOperators.pushedJoins.nonEmpty) {
+      Map("PushedJoins" -> seqToString(pushedDownOperators.pushedJoins))
+    } else {
+      Map()
+    }
+
     Map("ReadSchema" -> requiredSchema.catalogString,
       "PushedFilters" -> pushedFilters) ++
       pushedDownOperators.aggregation.fold(Map[String, String]()) { v =>
@@ -200,7 +206,8 @@ case class RowDataSourceScanExec(
       offsetInfo ++
       pushedDownOperators.sample.map(v => "PushedSample" ->
         s"SAMPLE (${(v.upperBound - v.lowerBound) * 100}) ${v.withReplacement} SEED(${v.seed})"
-      )
+      ) ++
+      pushedJoins
   }
 
   // Don't care about `rdd` and `tableIdentifier`, and `stream` when canonicalizing.

‎sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala

Lines changed: 9 additions & 0 deletions
@@ -37,6 +37,7 @@ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.planning.PhysicalOperation
+import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
 import org.apache.spark.sql.catalyst.plans.logical.{AppendData, InsertIntoDir, InsertIntoStatement, LogicalPlan, Project}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2
@@ -47,6 +48,7 @@ import org.apache.spark.sql.connector.catalog.{SupportsRead, V1Table}
 import org.apache.spark.sql.connector.catalog.TableCapability._
 import org.apache.spark.sql.connector.expressions.{Expression => V2Expression, NullOrdering, SortDirection, SortOrder => V2SortOrder, SortValue}
 import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Aggregation}
+import org.apache.spark.sql.connector.join.{Inner => V2Inner, JoinType => V2JoinType}
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.execution
 import org.apache.spark.sql.execution.{RowDataSourceScanExec, SparkPlan}
@@ -501,6 +503,13 @@ object DataSourceStrategy
     }
   }
 
+  def translateJoinType(joinType: JoinType): Option[V2JoinType] = {
+    joinType match {
+      case Inner => Some(new V2Inner)
+      case _ => None
+    }
+  }
+
   /**
   * Convert RDD of Row into RDD of InternalRow with objects in catalyst types
   */
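
A small illustration (not in the diff) of how the translation behaves; only inner joins map to the DSv2 join type, everything else falls through the catch-all and stays in Spark:

import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner}
import org.apache.spark.sql.execution.datasources.DataSourceStrategy

// Inner translates to the V2 join type; unsupported types return None, so the
// join is executed by Spark instead of being pushed to the source.
assert(DataSourceStrategy.translateJoinType(Inner).isDefined)
assert(DataSourceStrategy.translateJoinType(FullOuter).isEmpty)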

‎sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala

Lines changed: 7 additions & 0 deletions
@@ -104,6 +104,8 @@ class JDBCOptions(
     }
   }
 
+  var containsJoinInQuery: Boolean = false
+
   // ------------------------------------------------------------
   // Optional parameters
   // ------------------------------------------------------------
@@ -215,6 +217,10 @@ class JDBCOptions(
   // This only applies to Data Source V2 JDBC
   val pushDownTableSample = parameters.getOrElse(JDBC_PUSHDOWN_TABLESAMPLE, "true").toBoolean
 
+  // An option to allow/disallow pushing down JOIN into JDBC data source
+  // This only applies to Data Source V2 JDBC
+  val pushDownJoin = parameters.getOrElse(JDBC_PUSHDOWN_JOIN, "true").toBoolean
+
   // The local path of user's keytab file, which is assumed to be pre-uploaded to all nodes either
   // by --files option of spark-submit or manually
   val keytab = {
@@ -321,6 +327,7 @@ object JDBCOptions {
   val JDBC_PUSHDOWN_LIMIT = newOption("pushDownLimit")
   val JDBC_PUSHDOWN_OFFSET = newOption("pushDownOffset")
   val JDBC_PUSHDOWN_TABLESAMPLE = newOption("pushDownTableSample")
+  val JDBC_PUSHDOWN_JOIN = newOption("pushDownJoin")
   val JDBC_KEYTAB = newOption("keytab")
   val JDBC_PRINCIPAL = newOption("principal")
   val JDBC_TABLE_COMMENT = newOption("tableComment")

‎sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushedDownOperators.scala

Lines changed: 2 additions & 1 deletion
@@ -30,6 +30,7 @@ case class PushedDownOperators(
     limit: Option[Int],
     offset: Option[Int],
     sortValues: Seq[SortOrder],
-    pushedPredicates: Seq[Predicate]) {
+    pushedPredicates: Seq[Predicate],
+    pushedJoins: Seq[String] = Seq()) {
   assert((limit.isEmpty && sortValues.isEmpty) || limit.isDefined)
 }

‎sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala

Lines changed: 147 additions & 12 deletions
@@ -18,20 +18,21 @@
 package org.apache.spark.sql.execution.datasources.v2
 
 import scala.collection.mutable
+import scala.jdk.OptionConverters._
 
 import org.apache.spark.internal.LogKeys.{AGGREGATE_FUNCTIONS, GROUP_BY_EXPRS, POST_SCAN_FILTERS, PUSHED_FILTERS, RELATION_NAME, RELATION_OUTPUT}
 import org.apache.spark.internal.MDC
-import org.apache.spark.sql.catalyst.expressions.{aggregate, Alias, And, Attribute, AttributeMap, AttributeReference, AttributeSet, Cast, Expression, IntegerLiteral, Literal, NamedExpression, PredicateHelper, ProjectionOverSchema, SortOrder, SubqueryExpression}
+import org.apache.spark.sql.catalyst.expressions.{aggregate, Alias, And, Attribute, AttributeMap, AttributeReference, AttributeSet, Cast, Expression, ExprId, IntegerLiteral, JoinColumnReference, Literal, NamedExpression, PredicateHelper, ProjectionOverSchema, SortOrder, SubqueryExpression}
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.optimizer.CollapseProject
 import org.apache.spark.sql.catalyst.planning.{PhysicalOperation, ScanOperation}
-import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LeafNode, Limit, LimitAndOffset, LocalLimit, LogicalPlan, Offset, OffsetAndLimit, Project, Sample, Sort}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, Join, LeafNode, Limit, LimitAndOffset, LocalLimit, LogicalPlan, Offset, OffsetAndLimit, Project, Sample, Sort}
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
+import org.apache.spark.sql.catalyst.types.DataTypeUtils.{fromAttributes, toAttributes}
 import org.apache.spark.sql.connector.expressions.{SortOrder => V2SortOrder}
 import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, Avg, Count, CountStar, Max, Min, Sum}
 import org.apache.spark.sql.connector.expressions.filter.Predicate
-import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, V1Scan}
+import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownJoin, V1Scan}
 import org.apache.spark.sql.execution.datasources.DataSourceStrategy
 import org.apache.spark.sql.sources
 import org.apache.spark.sql.types.{DataType, DecimalType, IntegerType, StructType}
@@ -46,9 +47,11 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
       createScanBuilder,
       pushDownSample,
       pushDownFilters,
+      pushDownJoin,
       pushDownAggregates,
       pushDownLimitAndOffset,
       buildScanWithPushedAggregate,
+      buildScanWithPushedJoin,
       pruneColumns)
 
     pushdownRules.foldLeft(plan) { (newPlan, pushDownRule) =>
@@ -58,7 +61,14 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
 
   private def createScanBuilder(plan: LogicalPlan) = plan.transform {
     case r: DataSourceV2Relation =>
-      ScanBuilderHolder(r.output, r, r.table.asReadable.newScanBuilder(r.options))
+      val sHolder = ScanBuilderHolder(r.output, r, r.table.asReadable.newScanBuilder(r.options))
+      sHolder.output.foreach { e =>
+        // Join column names change when joins are pushed down. At the end, we need to keep the
+        // original names of the plan, so we store the original name for each exprId.
+        sHolder.exprIdToOriginalName.put(e.exprId, e.name)
+      }
+
+      sHolder
   }
 
   private def pushDownFilters(plan: LogicalPlan) = plan.transform {
@@ -98,6 +108,89 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
       filterCondition.map(Filter(_, sHolder)).getOrElse(sHolder)
   }
 
+  def pushDownJoin(plan: LogicalPlan): LogicalPlan = plan.transformUp {
+    // A join can be pushed down only if the left and right sides of the join are
+    // compatible (same data source, for example). Another requirement is that any
+    // projections between the Join and the ScanBuilderHolder are simple attribute references.
+    case node @ Join(
+        PhysicalOperation(
+          leftProjections,
+          Nil,
+          leftHolder @ ScanBuilderHolder(_, _, lBuilder: SupportsPushDownJoin)
+        ),
+        PhysicalOperation(
+          rightProjections,
+          Nil,
+          rightHolder @ ScanBuilderHolder(_, _, rBuilder: SupportsPushDownJoin)
+        ),
+        joinType,
+        condition,
+        _) if conf.dataSourceV2JoinPushdown &&
+          // TODO: projections should always be Seq[AttributeReference] because when
+          // SELECT tbl1.col + 2, tbl2.* FROM tbl1 JOIN tbl2
+          // is executed, col is pruned down, but col + 2 is projected on top of the join.
+          leftProjections.forall(_.isInstanceOf[AttributeReference]) &&
+          rightProjections.forall(_.isInstanceOf[AttributeReference]) &&
+          lBuilder.isRightSideCompatibleForJoin(rBuilder) =>
+      val normalizedLeftProjections = DataSourceStrategy.normalizeExprs(
+        leftProjections,
+        leftHolder.output
+      ).asInstanceOf[Seq[AttributeReference]]
+      val leftRequiredSchema = fromAttributes(normalizedLeftProjections)
+
+      val normalizedRightProjections = DataSourceStrategy.normalizeExprs(
+        rightProjections,
+        rightHolder.output
+      ).asInstanceOf[Seq[AttributeReference]]
+      val rightRequiredSchema = fromAttributes(normalizedRightProjections)
+
+      val normalizedCondition = condition.map { e =>
+        DataSourceStrategy.normalizeExprs(
+          Seq(e),
+          leftHolder.output ++ rightHolder.output
+        ).head
+      }
+
+      val conditionWithJoinColumns = normalizedCondition.map { cond =>
+        cond.transformUp {
+          case a: AttributeReference =>
+            val isInLeftSide = leftProjections.exists(_.exprId == a.exprId)
+            JoinColumnReference(a, isInLeftSide)
+        }
+      }
+
+      val translatedCondition =
+        conditionWithJoinColumns.flatMap(DataSourceV2Strategy.translateFilterV2(_))
+      val translatedJoinType = DataSourceStrategy.translateJoinType(joinType)
+
+      if (translatedCondition.isDefined == condition.isDefined &&
+        translatedJoinType.isDefined &&
+        lBuilder.pushJoin(
+          rBuilder,
+          translatedJoinType.get,
+          translatedCondition.toJava,
+          leftRequiredSchema,
+          rightRequiredSchema
+        )) {
+        leftHolder.joinedRelations = leftHolder.joinedRelations :+ rightHolder.relation
+
+        val newSchema = leftHolder.builder.build().readSchema()
+        val newOutput = (leftProjections ++ rightProjections).asInstanceOf[Seq[AttributeReference]]
+          .zip(newSchema.fields)
+          .map { case (attr, schemaField) =>
+            attr.withName(schemaField.name)
+          }
+
+        leftHolder.exprIdToOriginalName ++= rightHolder.exprIdToOriginalName
+        leftHolder.output = newOutput
+        leftHolder.isJoinPushed = true
+        leftHolder
+      } else {
+        node
+      }
+  }
+
   def pushDownAggregates(plan: LogicalPlan): LogicalPlan = plan.transform {
     // update the scan builder with agg pushdown and return a new plan with agg pushed
     case agg: Aggregate => rewriteAggregate(agg)
@@ -113,10 +206,19 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
 
     val aggExprToOutputOrdinal = mutable.HashMap.empty[Expression, Int]
     val aggregates = collectAggregates(actualResultExprs, aggExprToOutputOrdinal)
-    val normalizedAggExprs = DataSourceStrategy.normalizeExprs(
-      aggregates, holder.relation.output).asInstanceOf[Seq[AggregateExpression]]
-    val normalizedGroupingExpr = DataSourceStrategy.normalizeExprs(
-      actualGroupExprs, holder.relation.output)
+    val normalizedAggExprs = if (holder.isJoinPushed) {
+      DataSourceStrategy.normalizeExprs(aggregates, holder.output)
+        .asInstanceOf[Seq[AggregateExpression]]
+    } else {
+      DataSourceStrategy.normalizeExprs(aggregates, holder.relation.output)
+        .asInstanceOf[Seq[AggregateExpression]]
+    }
+    val normalizedGroupingExpr =
+      if (holder.isJoinPushed) {
+        DataSourceStrategy.normalizeExprs(actualGroupExprs, holder.output)
+      } else {
+        DataSourceStrategy.normalizeExprs(actualGroupExprs, holder.relation.output)
+      }
     val translatedAggOpt = DataSourceStrategy.translateAggregation(
       normalizedAggExprs, normalizedGroupingExpr)
     if (translatedAggOpt.isEmpty) {
@@ -356,6 +458,26 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
       Project(projectList, scanRelation)
   }
 
+  def buildScanWithPushedJoin(plan: LogicalPlan): LogicalPlan = plan.transform {
+    case holder: ScanBuilderHolder if holder.isJoinPushed && !holder.isStreaming =>
+      val scan = holder.builder.build()
+      val realOutput = toAttributes(scan.readSchema())
+      assert(realOutput.length == holder.output.length,
+        "The data source returns unexpected number of columns")
+      val wrappedScan = getWrappedScan(scan, holder)
+      val scanRelation = DataSourceV2ScanRelation(holder.relation, wrappedScan, realOutput)
+
+      // When a join is pushed down, the output of ScanBuilderHolder is, for example,
+      // subquery_2_col_0#0, subquery_2_col_1#1, subquery_2_col_2#2.
+      // We should revert these names back to the original names, for example
+      // SALARY#0, NAME#1, DEPT#2. This is done by adding a projection with appropriate aliases.
+      val projectList = realOutput.zip(holder.output).map { case (a1, a2) =>
+        val originalName = holder.exprIdToOriginalName(a2.exprId)
+        Alias(a1, originalName)(a2.exprId)
+      }
+      Project(projectList, scanRelation)
+  }
+
   def pruneColumns(plan: LogicalPlan): LogicalPlan = plan.transform {
     case ScanOperation(project, filtersStayUp, filtersPushDown, sHolder: ScanBuilderHolder) =>
       // column pruning
@@ -441,8 +563,13 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
       } else {
         aliasReplacedOrder.asInstanceOf[Seq[SortOrder]]
       }
-      val normalizedOrders = DataSourceStrategy.normalizeExprs(
-        newOrder, sHolder.relation.output).asInstanceOf[Seq[SortOrder]]
+      val normalizedOrders = if (sHolder.isJoinPushed) {
+        DataSourceStrategy.normalizeExprs(
+          newOrder, sHolder.output).asInstanceOf[Seq[SortOrder]]
+      } else {
+        DataSourceStrategy.normalizeExprs(
+          newOrder, sHolder.relation.output).asInstanceOf[Seq[SortOrder]]
+      }
       val orders = DataSourceStrategy.translateSortOrders(normalizedOrders)
       if (orders.length == order.length) {
         val (isPushed, isPartiallyPushed) =
@@ -549,7 +676,8 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
         case _ => Array.empty[sources.Filter]
       }
       val pushedDownOperators = PushedDownOperators(sHolder.pushedAggregate, sHolder.pushedSample,
-        sHolder.pushedLimit, sHolder.pushedOffset, sHolder.sortOrders, sHolder.pushedPredicates)
+        sHolder.pushedLimit, sHolder.pushedOffset, sHolder.sortOrders, sHolder.pushedPredicates,
+        sHolder.joinedRelations.map(_.name))
       V1ScanWrapper(v1, pushedFilters.toImmutableArraySeq, pushedDownOperators)
     case _ => scan
   }
@@ -573,6 +701,13 @@ case class ScanBuilderHolder(
   var pushedAggregate: Option[Aggregation] = None
 
   var pushedAggOutputMap: AttributeMap[Expression] = AttributeMap.empty[Expression]
+
+  var joinedRelations: Seq[DataSourceV2RelationBase] = Seq()
+
+  var isJoinPushed: Boolean = false
+
+  var exprIdToOriginalName: scala.collection.mutable.Map[ExprId, String] =
+    scala.collection.mutable.Map.empty[ExprId, String]
 }
 
 // A wrapper for v1 scan to carry the translated filters and the handled ones, along with

‎sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala

Lines changed: 141 additions & 2 deletions
@@ -16,14 +16,18 @@
  */
 package org.apache.spark.sql.execution.datasources.v2.jdbc
 
+import java.util.Optional
+
+import scala.jdk.OptionConverters._
 import scala.util.control.NonFatal
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.connector.expressions.{FieldReference, SortOrder}
 import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
 import org.apache.spark.sql.connector.expressions.filter.Predicate
-import org.apache.spark.sql.connector.read.{ScanBuilder, SupportsPushDownAggregates, SupportsPushDownLimit, SupportsPushDownOffset, SupportsPushDownRequiredColumns, SupportsPushDownTableSample, SupportsPushDownTopN, SupportsPushDownV2Filters}
+import org.apache.spark.sql.connector.join.{JoinColumn, JoinType}
+import org.apache.spark.sql.connector.read.{ScanBuilder, SupportsPushDownAggregates, SupportsPushDownJoin, SupportsPushDownLimit, SupportsPushDownOffset, SupportsPushDownRequiredColumns, SupportsPushDownTableSample, SupportsPushDownTopN, SupportsPushDownV2Filters}
 import org.apache.spark.sql.execution.datasources.PartitioningUtils
 import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCRDD, JDBCRelation}
 import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo
@@ -33,7 +37,7 @@ import org.apache.spark.sql.types.StructType
 case class JDBCScanBuilder(
     session: SparkSession,
     schema: StructType,
-    jdbcOptions: JDBCOptions)
+    var jdbcOptions: JDBCOptions)
   extends ScanBuilder
   with SupportsPushDownV2Filters
   with SupportsPushDownRequiredColumns
@@ -42,6 +46,7 @@ case class JDBCScanBuilder(
   with SupportsPushDownOffset
   with SupportsPushDownTableSample
   with SupportsPushDownTopN
+  with SupportsPushDownJoin
   with Logging {
 
   private val dialect = JdbcDialects.get(jdbcOptions.url)
@@ -121,6 +126,128 @@ case class JDBCScanBuilder(
     }
   }
 
+  override def isRightSideCompatibleForJoin(other: SupportsPushDownJoin): Boolean = {
+    other.isInstanceOf[JDBCScanBuilder] &&
+      jdbcOptions.url == other.asInstanceOf[JDBCScanBuilder].jdbcOptions.url
+  }
+
+  override def pushJoin(
+      other: SupportsPushDownJoin,
+      joinType: JoinType,
+      condition: Optional[Predicate],
+      leftRequiredSchema: StructType,
+      rightRequiredSchema: StructType
+  ): Boolean = {
+    if (!jdbcOptions.pushDownJoin || !dialect.supportsJoin) return false
+
+    val leftNodeSQLQuery = buildSQLQuery()
+    val rightNodeSQLQuery = other.asInstanceOf[JDBCScanBuilder].buildSQLQuery()
+
+    val leftSideQualifier = JoinOutputAliasIterator.get
+    val rightSideQualifier = JoinOutputAliasIterator.get
+
+    val leftProjections: Seq[JoinColumn] = leftRequiredSchema.fields.map { e =>
+      new JoinColumn(Array(leftSideQualifier), e.name, true)
+    }.toSeq
+    val rightProjections: Seq[JoinColumn] = rightRequiredSchema.fields.map { e =>
+      new JoinColumn(Array(rightSideQualifier), e.name, false)
+    }.toSeq
+
+    var aliasedLeftSchema = StructType(Seq())
+    var aliasedRightSchema = StructType(Seq())
+    val outputAliasPrefix = JoinOutputAliasIterator.get
+
+    val aliasedOutput = (leftProjections ++ rightProjections)
+      .zipWithIndex
+      .map { case (proj, i) =>
+        val name = s"${outputAliasPrefix}_col_$i"
+        val output = FieldReference(name)
+        if (i < leftProjections.length) {
+          val field = leftRequiredSchema.fields(i)
+          aliasedLeftSchema =
+            aliasedLeftSchema.add(name, field.dataType, field.nullable, field.metadata)
+        } else {
+          val field = rightRequiredSchema.fields(i - leftRequiredSchema.fields.length)
+          aliasedRightSchema =
+            aliasedRightSchema.add(name, field.dataType, field.nullable, field.metadata)
+        }
+
+        s"""${dialect.compileExpression(proj).get} AS ${dialect.compileExpression(output).get}"""
+      }.mkString(",")
+
+    val compiledJoinType = dialect.compileJoinType(joinType)
+    if (!compiledJoinType.isDefined) return false
+
+    val conditionString = condition.toScala match {
+      case Some(cond) =>
+        qualifyCondition(cond, leftSideQualifier, rightSideQualifier)
+        s"ON ${dialect.compileExpression(cond).get}"
+      case _ => ""
+    }
+
+    val subqueryASKeyword = if (dialect.needsASKeywordForJoinSubquery) {
+      " AS "
+    } else {
+      ""
+    }
+
+    val compiledLeftSideQualifier =
+      dialect.compileExpression(FieldReference(leftSideQualifier)).get
+    val compiledRightSideQualifier =
+      dialect.compileExpression(FieldReference(rightSideQualifier)).get
+
+    val joinQuery =
+      s"""
+        |SELECT $aliasedOutput FROM
+        |($leftNodeSQLQuery)$subqueryASKeyword$compiledLeftSideQualifier
+        |${compiledJoinType.get}
+        |($rightNodeSQLQuery)$subqueryASKeyword$compiledRightSideQualifier
+        |$conditionString
+        |""".stripMargin
+
+    val newMap = jdbcOptions.parameters.originalMap +
+      (JDBCOptions.JDBC_QUERY_STRING -> joinQuery) - (JDBCOptions.JDBC_TABLE_NAME)
+
+    jdbcOptions = new JDBCOptions(newMap)
+    jdbcOptions.containsJoinInQuery = true
+
+    // We can merge schemas since there are no fields with duplicate names
+    finalSchema = aliasedLeftSchema.merge(aliasedRightSchema)
+    pushedPredicate = Array.empty[Predicate]
+    pushedAggregateList = Array()
+    pushedGroupBys = None
+    tableSample = None
+    pushedLimit = 0
+    sortOrders = Array.empty[String]
+    pushedOffset = 0
+
+    true
+  }
+
+  def buildSQLQuery(): String = {
+    build()
+      .toV1TableScan(session.sqlContext).asInstanceOf[JDBCV1RelationFromV2Scan]
+      .buildScan().asInstanceOf[JDBCRDD]
+      .getExternalEngineQuery
+  }
+
+  // Fully qualify the condition. For example:
+  // DEPT = SALARY turns into leftSideQualifier.DEPT = rightSideQualifier.SALARY
+  def qualifyCondition(condition: Predicate, leftSideQualifier: String, rightSideQualifier: String)
+      : Unit = {
+    condition.references()
+      .filter(_.isInstanceOf[JoinColumn])
+      .foreach { e =>
+        val qualifier = if (e.asInstanceOf[JoinColumn].isInLeftSideOfJoin) {
+          leftSideQualifier
+        } else {
+          rightSideQualifier
+        }
+
+        e.asInstanceOf[JoinColumn].qualifier = Array(qualifier)
+      }
+  }
+
   override def pushTableSample(
       lowerBound: Double,
       upperBound: Double,
@@ -195,3 +322,15 @@ case class JDBCScanBuilder(
       pushedAggregateList, pushedGroupBys, tableSample, pushedLimit, sortOrders, pushedOffset)
   }
 }
+
+object JoinOutputAliasIterator {
+  private var curId = new java.util.concurrent.atomic.AtomicLong()
+
+  def get: String = {
+    "subquery_" + curId.getAndIncrement()
+  }
+
+  def reset(): Unit = {
+    curId = new java.util.concurrent.atomic.AtomicLong()
+  }
+}
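
For orientation only (this example is not in the diff, the table name and aliases are hypothetical, and the exact quoting depends on the dialect), the query that pushJoin assembles for two JDBC-side scans has roughly this shape:

// Shape of the SQL stored in JDBC_QUERY_STRING after pushJoin, assuming two
// scans over a table named "employee" joined on an equality condition.
val exampleJoinQuery =
  """SELECT "subquery_0"."DEPT" AS "subquery_2_col_0", "subquery_1"."DEPT" AS "subquery_2_col_1" FROM
    |(SELECT "DEPT" FROM "employee") AS "subquery_0"
    |INNER JOIN
    |(SELECT "DEPT" FROM "employee") AS "subquery_1"
    |ON "subquery_0"."DEPT" = "subquery_1"."DEPT"
    |""".stripMargin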

‎sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala

Lines changed: 2 additions & 0 deletions
@@ -298,4 +298,6 @@ private[sql] case class H2Dialect() extends JdbcDialect with NoLegacyJDBCError {
   override def supportsLimit: Boolean = true
 
   override def supportsOffset: Boolean = true
+
+  override def supportsJoin: Boolean = true
 }

‎sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala

Lines changed: 35 additions & 1 deletion
@@ -43,7 +43,8 @@ import org.apache.spark.sql.connector.catalog.index.TableIndex
 import org.apache.spark.sql.connector.expressions.{Expression, Literal, NamedReference}
 import org.apache.spark.sql.connector.expressions.aggregate.AggregateFunc
 import org.apache.spark.sql.connector.expressions.filter.Predicate
-import org.apache.spark.sql.connector.util.V2ExpressionSQLBuilder
+import org.apache.spark.sql.connector.join.{JoinColumn, JoinType}
+import org.apache.spark.sql.connector.util.{JoinTypeSQLBuilder, V2ExpressionSQLBuilder}
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.execution.datasources.jdbc.{DriverRegistry, JDBCOptions, JdbcOptionsInWrite, JdbcUtils}
 import org.apache.spark.sql.execution.datasources.jdbc.connection.ConnectionProvider
@@ -405,6 +406,10 @@ abstract class JdbcDialect extends Serializable with Logging {
       quoteIdentifier(namedRef.fieldNames.head)
     }
 
+    override def visitJoinColumn(column: JoinColumn): String = {
+      (column.qualifier.toSeq ++ Seq(column.name)).map(quoteIdentifier(_)).mkString(".")
+    }
+
    override def visitCast(expr: String, exprDataType: DataType, dataType: DataType): String = {
      val databaseTypeDefinition =
        getJDBCType(dataType).map(_.databaseTypeDefinition).getOrElse(dataType.typeName)
@@ -491,6 +496,8 @@ abstract class JdbcDialect extends Serializable with Logging {
     }
   }
 
+  private[jdbc] class JDBCJoinTypeSQLBuilder extends JoinTypeSQLBuilder {}
+
   /**
   * Returns whether the database supports function.
   * @param funcName Upper-cased function name
@@ -516,6 +523,18 @@ abstract class JdbcDialect extends Serializable with Logging {
     }
   }
 
+  @Since("4.0.0")
+  def compileJoinType(joinType: JoinType): Option[String] = {
+    val joinTypeBuilder = new JDBCJoinTypeSQLBuilder()
+    try {
+      Some(joinTypeBuilder.build(joinType))
+    } catch {
+      case NonFatal(e) =>
+        logWarning("Error occurred while compiling join type", e)
+        None
+    }
+  }
+
   /**
   * Converts aggregate function to String representing a SQL expression.
   * @param aggFunction The aggregate function to be converted.
@@ -837,6 +856,21 @@ abstract class JdbcDialect extends Serializable with Logging {
 
   def supportsHint: Boolean = false
 
+  /**
+   * Returns true if the dialect supports the JOIN operator.
+   */
+  def supportsJoin: Boolean = false
+
+  /**
+   * If true, the left/right subqueries of a JOIN need the AS keyword before their alias.
+   * For example:
+   *   SELECT * FROM (subquery1) AS alias1 JOIN ...
+   *
+   * If false, the generated SQL omits the AS keyword, so the query looks like:
+   *   SELECT * FROM (subquery1) alias1 JOIN ...
+   */
+  def needsASKeywordForJoinSubquery: Boolean = true
+
   /**
   * Return the DB-specific quoted and fully qualified table name
   */
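
A hedged sketch (not in the commit) of how another dialect could opt in; the dialect class and URL prefix here are hypothetical, only the two new hooks are shown:

import org.apache.spark.sql.jdbc.JdbcDialect

// Hypothetical dialect that accepts join pushdown and whose SQL engine does
// not want the AS keyword in front of subquery aliases.
case class MyDialect() extends JdbcDialect {
  override def canHandle(url: String): Boolean = url.startsWith("jdbc:mydb")

  // Advertise join support so JDBCScanBuilder.pushJoin does not bail out early.
  override def supportsJoin: Boolean = true

  // Generate "(subquery) alias" instead of "(subquery) AS alias".
  override def needsASKeywordForJoinSubquery: Boolean = false
}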

‎sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala

Lines changed: 249 additions & 1 deletion
@@ -29,7 +29,7 @@ import org.apache.spark.{SparkConf, SparkException, SparkIllegalArgumentExceptio
 import org.apache.spark.sql.{AnalysisException, DataFrame, ExplainSuiteHelper, QueryTest, Row}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, IndexAlreadyExistsException, NoSuchIndexException}
-import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, GlobalLimit, LocalLimit, Offset, Sort}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, GlobalLimit, Join, LocalLimit, Offset, Sort}
 import org.apache.spark.sql.connector.{IntegralAverage, StrLen}
 import org.apache.spark.sql.connector.catalog.{Catalogs, Identifier, TableCatalog}
 import org.apache.spark.sql.connector.catalog.functions.{ScalarFunction, UnboundFunction}
@@ -141,6 +141,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     .set("spark.sql.catalog.h2.pushDownAggregate", "true")
     .set("spark.sql.catalog.h2.pushDownLimit", "true")
     .set("spark.sql.catalog.h2.pushDownOffset", "true")
+    .set("spark.sql.catalog.h2.pushDownJoin", "true")
 
   private def withConnection[T](f: Connection => T): T = {
     val conn = DriverManager.getConnection(url, new Properties())
@@ -265,6 +266,253 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     super.afterAll()
   }
 
+  test("Test 2-way join") {
+    val rows = withSQLConf(
+      SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") {
+      sql("SELECT * FROM h2.test.employee a, h2.test.employee b").collect().toSeq
+    }
+
+    withSQLConf(
+      SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") {
+      val df = sql("SELECT * FROM h2.test.employee a, h2.test.employee b")
+      val joinNodes = df.queryExecution.optimizedPlan.collect {
+        case j: Join => j
+      }
+
+      assert(joinNodes.isEmpty)
+      checkAnswer(df, rows)
+    }
+  }
+
+  test("Test multi way join") {
+    val rows = withSQLConf(
+      SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") {
+      sql("SELECT * FROM " +
+        "h2.test.employee a, " +
+        "h2.test.employee b, " +
+        "h2.test.employee c, " +
+        "h2.test.employee d, " +
+        "h2.test.employee e")
+        .collect().toSeq
+    }
+
+    withSQLConf(
+      SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") {
+      val df = sql("SELECT * FROM " +
+        "h2.test.employee a, " +
+        "h2.test.employee b, " +
+        "h2.test.employee c, " +
+        "h2.test.employee d, " +
+        "h2.test.employee e")
+
+      val joinNodes = df.queryExecution.optimizedPlan.collect {
+        case j: Join => j
+      }
+
+      assert(joinNodes.isEmpty)
+      checkPushedInfo(df,
+        "PushedJoins: [h2.test.employee, h2.test.employee, h2.test.employee, h2.test.employee]")
+      checkAnswer(df, rows)
+    }
+  }
+
+  test("Test join with condition") {
+    val rows = withSQLConf(
+      SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") {
+      sql("SELECT * FROM h2.test.employee a join h2.test.employee b on a.dept = b.dept + 1")
+        .collect().toSeq
+    }
+
+    withSQLConf(
+      SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") {
+      val df = sql(
+        "SELECT * FROM h2.test.employee a join h2.test.employee b on a.dept = b.dept + 1"
+      )
+      val joinNodes = df.queryExecution.optimizedPlan.collect {
+        case j: Join => j
+      }
+
+      assert(joinNodes.isEmpty)
+      checkAnswer(df, rows)
+    }
+  }
+
+  test("Test multi-way-join with conditions") {
+    val rows = withSQLConf(
+      SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") {
+      sql("SELECT * FROM " +
+        "h2.test.employee a " +
+        "join h2.test.employee b on b.dept = a.dept + 1 " +
+        "join h2.test.employee c on c.dept = b.dept - 1 ")
+        .collect().toSeq
+    }
+
+    assert(!rows.isEmpty)
+
+    withSQLConf(
+      SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") {
+      val df = sql("SELECT * FROM " +
+        "h2.test.employee a " +
+        "join h2.test.employee b on b.dept = a.dept + 1 " +
+        "join h2.test.employee c on c.dept = b.dept - 1 ")
+      val joinNodes = df.queryExecution.optimizedPlan.collect {
+        case j: Join => j
+      }
+
+      assert(joinNodes.isEmpty)
+      checkAnswer(df, rows)
+    }
+  }
+
+  test("Test join with column pruning") {
+    val rows = withSQLConf(
+      SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") {
+      sql("SELECT a.dept + 2, b.dept, b.salary FROM " +
+        "h2.test.employee a join h2.test.employee b " +
+        "on a.dept = b.dept + 1")
+        .collect().toSeq
+    }
+
+    withSQLConf(
+      SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") {
+      val df = sql("" +
+        "SELECT a.dept + 2, b.dept, b.salary FROM " +
+        "h2.test.employee a join h2.test.employee b " +
+        "on a.dept = b.dept + 1")
+
+      val joinNodes = df.queryExecution.optimizedPlan.collect {
+        case j: Join => j
+      }
+
+      assert(joinNodes.isEmpty)
+      checkAnswer(df, rows)
+    }
+  }
+
+  test("Test multi way join with column pruning") {
+    val rows = withSQLConf(
+      SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") {
+      sql("SELECT a.dept, b.*, c.dept, c.salary + a.salary FROM " +
+        "h2.test.employee a " +
+        "join h2.test.employee b on b.dept = a.dept + 1 " +
+        "join h2.test.employee c on c.dept = b.dept - 1 ")
+        .collect().toSeq
+    }
+
+    withSQLConf(
+      SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") {
+      val df = sql("" +
+        "SELECT a.dept, b.*, c.dept, c.salary + a.salary FROM " +
+        "h2.test.employee a " +
+        "join h2.test.employee b on b.dept = a.dept + 1 " +
+        "join h2.test.employee c on c.dept = b.dept - 1 ")
+
+      val joinNodes = df.queryExecution.optimizedPlan.collect {
+        case j: Join => j
+      }
+
+      assert(joinNodes.isEmpty)
+      checkAnswer(df, rows)
+    }
+  }
+
+  test("Test aggregate on top of 2 way join") {
+    val rows = withSQLConf(
+      SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") {
+      sql("SELECT min(a.dept + b.dept), min(a.dept) " +
+        "FROM h2.test.employee a " +
+        "join h2.test.employee b on a.dept = b.dept + 1")
+        .collect().toSeq
+    }
+
+    withSQLConf(
+      SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") {
+      val df = sql("SELECT min(a.dept + b.dept), min(a.dept) " +
+        "FROM h2.test.employee a " +
+        "join h2.test.employee b on a.dept = b.dept + 1")
+      val joinNodes = df.queryExecution.optimizedPlan.collect {
+        case j: Join => j
+      }
+
+      val aggNodes = df.queryExecution.optimizedPlan.collect {
+        case a: Aggregate => a
+      }
+
+      assert(joinNodes.isEmpty)
+      assert(aggNodes.isEmpty)
+      checkAnswer(df, rows)
+    }
+  }
+
+  test("Test aggregate on top of multi way join") {
+    val rows = withSQLConf(
+      SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") {
+      sql("SELECT min(a.dept + b.dept), min(a.dept), min(c.dept - 2) " +
+        "from h2.test.employee a " +
+        "join h2.test.employee b on b.dept = a.dept + 1 " +
+        "join h2.test.employee c on c.dept = b.dept - 1 ")
+        .collect().toSeq
+    }
+
+    withSQLConf(
+      SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") {
+      val df = sql("SELECT min(a.dept + b.dept), min(a.dept), min(c.dept - 2) " +
+        "from h2.test.employee a " +
+        "join h2.test.employee b on b.dept = a.dept + 1 " +
+        "join h2.test.employee c on c.dept = b.dept - 1 ")
+      val joinNodes = df.queryExecution.optimizedPlan.collect {
+        case j: Join => j
+      }
+
+      val aggNodes = df.queryExecution.optimizedPlan.collect {
+        case a: Aggregate => a
+      }
+
+      assert(joinNodes.isEmpty)
+      assert(aggNodes.isEmpty)
+      checkAnswer(df, rows)
+    }
+  }
+
+  test("Test sort limit on top of join is pushed down") {
+    val rows = withSQLConf(
+      SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "false") {
+      sql("SELECT min(a.dept + b.dept), a.dept, b.dept " +
+        "from h2.test.employee a " +
+        "join h2.test.employee b on b.dept = a.dept + 1 " +
+        "GROUP BY a.dept, b.dept " +
+        "ORDER BY a.dept " +
+        "LIMIT 1")
+        .collect().toSeq
+    }
+
+    withSQLConf(
+      SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") {
+      val df = sql("SELECT min(a.dept + b.dept), a.dept, b.dept " +
+        "from h2.test.employee a " +
+        "join h2.test.employee b on b.dept = a.dept + 1 " +
+        "GROUP BY a.dept, b.dept " +
+        "ORDER BY a.dept " +
+        "LIMIT 1")
+      val joinNodes = df.queryExecution.optimizedPlan.collect {
+        case j: Join => j
+      }
+
+      val sortNodes = df.queryExecution.optimizedPlan.collect {
+        case s: Sort => s
+      }
+
+      val limitNodes = df.queryExecution.optimizedPlan.collect {
+        case l: GlobalLimit => l
+      }
+
+      assert(joinNodes.isEmpty)
+      assert(sortNodes.isEmpty)
+      assert(limitNodes.isEmpty)
+      checkAnswer(df, rows)
+    }
+  }
+
   test("simple scan") {
     checkAnswer(sql("SELECT * FROM h2.test.empty_table"), Seq())
     checkAnswer(sql("SELECT * FROM h2.test.people"), Seq(Row("fred", 1), Row("mary", 2)))
