
Commit c90f33e

introduce join pushdown for dsv2
1 parent 5b07e52 commit c90f33e

File tree

19 files changed: +851 -18 lines changed

common/utils/src/main/resources/error/error-conditions.json

+5
@@ -9403,6 +9403,11 @@
       "The number of fields (<numFields>) in the partition identifier is not equal to the partition schema length (<schemaLen>). The identifier might not refer to one partition."
     ]
   },
+  "_LEGACY_ERROR_TEMP_3209" : {
+    "message" : [
+      "Unexpected join type: <joinType>"
+    ]
+  },
   "_LEGACY_ERROR_TEMP_3215" : {
     "message" : [
       "Expected a Boolean type expression in replaceNullWithFalse, but got the type <dataType> in <expr>."
sql/catalyst/src/main/java/org/apache/spark/sql/connector/join/Inner.java

+28

@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.join;
+
+import org.apache.spark.annotation.Evolving;
+
+/**
+ * Represents an INNER join in the public Join type API.
+ *
+ * @since 4.0.0
+ */
+@Evolving
+public final class Inner implements JoinType { }
sql/catalyst/src/main/java/org/apache/spark/sql/connector/join/JoinColumn.java

+51

@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.join;
+
+import org.apache.spark.annotation.Evolving;
+import org.apache.spark.sql.connector.expressions.Expression;
+import org.apache.spark.sql.connector.expressions.NamedReference;
+
+/**
+ * Represents a column reference used in DSv2 Join pushdown.
+ *
+ * @since 4.0.0
+ */
+@Evolving
+public final class JoinColumn implements NamedReference {
+  public JoinColumn(String[] qualifier, String name, Boolean isInLeftSideOfJoin) {
+    this.qualifier = qualifier;
+    this.name = name;
+    this.isInLeftSideOfJoin = isInLeftSideOfJoin;
+  }
+
+  public String[] qualifier;
+  public String name;
+  public Boolean isInLeftSideOfJoin;
+
+  @Override
+  public String[] fieldNames() {
+    String[] fullyQualified = new String[qualifier.length + 1];
+    System.arraycopy(qualifier, 0, fullyQualified, 0, qualifier.length);
+    fullyQualified[qualifier.length] = name;
+    return fullyQualified;
+  }
+
+  @Override
+  public Expression[] children() { return EMPTY_EXPRESSION; }
+}
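
A quick sketch of how the reference behaves (hypothetical values, Scala REPL style):

import org.apache.spark.sql.connector.join.JoinColumn

// A column `id` from the left side of the join, qualified by a hypothetical
// pushed-down subquery alias.
val col = new JoinColumn(Array("join_subquery_0"), "id", true)
col.fieldNames()   // Array("join_subquery_0", "id")
col.children()     // empty: a column reference is a leaf expression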
sql/catalyst/src/main/java/org/apache/spark/sql/connector/join/JoinType.java

+28

@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.join;
+
+import org.apache.spark.annotation.Evolving;
+
+/**
+ * Base interface of the public Join type API.
+ *
+ * @since 4.0.0
+ */
+@Evolving
+public interface JoinType { }
sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownJoin.java

+44

@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.read;
+
+import java.util.Optional;
+
+import org.apache.spark.annotation.Evolving;
+import org.apache.spark.sql.connector.expressions.filter.Predicate;
+import org.apache.spark.sql.connector.join.JoinType;
+import org.apache.spark.sql.types.StructType;
+
+/**
+ * A mix-in interface for {@link ScanBuilder}. Data sources can implement this interface to
+ * push down join operators.
+ *
+ * @since 4.0.0
+ */
+@Evolving
+public interface SupportsPushDownJoin extends ScanBuilder {
+  boolean isRightSideCompatibleForJoin(SupportsPushDownJoin other);
+
+  boolean pushJoin(
+      SupportsPushDownJoin other,
+      JoinType joinType,
+      Optional<Predicate> condition,
+      StructType leftRequiredSchema,
+      StructType rightRequiredSchema
+  );
+}
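
A minimal connector-side sketch of the interface (the class MyScanBuilder and its wiring are hypothetical; the real implementer in this change is the JDBC scan builder):

import java.util.Optional

import org.apache.spark.sql.connector.expressions.filter.Predicate
import org.apache.spark.sql.connector.join.{Inner, JoinType}
import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownJoin}
import org.apache.spark.sql.types.StructType

class MyScanBuilder(var schema: StructType, buildScan: StructType => Scan)
    extends SupportsPushDownJoin {

  // Only another scan of the same source can be compiled into one query.
  override def isRightSideCompatibleForJoin(other: SupportsPushDownJoin): Boolean =
    other.isInstanceOf[MyScanBuilder]

  override def pushJoin(
      other: SupportsPushDownJoin,
      joinType: JoinType,
      condition: Optional[Predicate],
      leftRequiredSchema: StructType,
      rightRequiredSchema: StructType): Boolean = joinType match {
    case _: Inner =>
      // The pushed join's output is the concatenation of both required schemas.
      schema = StructType(leftRequiredSchema.fields ++ rightRequiredSchema.fields)
      true
    case _ => false // keep the join in Spark's plan
  }

  override def build(): Scan = buildScan(schema)
}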
sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/JoinTypeSQLBuilder.java

+49

@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.util;
+
+import java.util.HashMap;
+import java.util.Map;
+
+import org.apache.spark.SparkIllegalArgumentException;
+import org.apache.spark.sql.connector.join.*;
+
+/**
+ * The builder to generate SQL for a specific Join type.
+ *
+ * @since 4.0.0
+ */
+public class JoinTypeSQLBuilder {
+  public String build(JoinType joinType) {
+    if (joinType instanceof Inner inner) {
+      return visitInnerJoin(inner);
+    } else {
+      return visitUnexpectedJoinType(joinType);
+    }
+  }
+
+  protected String visitInnerJoin(Inner inner) {
+    return "INNER JOIN";
+  }
+
+  protected String visitUnexpectedJoinType(JoinType joinType) throws IllegalArgumentException {
+    Map<String, String> params = new HashMap<>();
+    params.put("joinType", String.valueOf(joinType));
+    throw new SparkIllegalArgumentException("_LEGACY_ERROR_TEMP_3209", params);
+  }
+}
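
Usage sketch (INNER is the only join type supported so far):

import org.apache.spark.sql.connector.join.Inner
import org.apache.spark.sql.connector.util.JoinTypeSQLBuilder

val sqlBuilder = new JoinTypeSQLBuilder
sqlBuilder.build(new Inner)  // "INNER JOIN"
// Any other JoinType implementation raises SparkIllegalArgumentException with
// error condition _LEGACY_ERROR_TEMP_3209 ("Unexpected join type: ...").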

sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java

+10
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.connector.util;
 
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.List;
 import java.util.Map;
 import java.util.StringJoiner;
@@ -42,6 +43,7 @@
 import org.apache.spark.sql.connector.expressions.aggregate.GeneralAggregateFunc;
 import org.apache.spark.sql.connector.expressions.aggregate.Sum;
 import org.apache.spark.sql.connector.expressions.aggregate.UserDefinedAggregateFunc;
+import org.apache.spark.sql.connector.join.JoinColumn;
 import org.apache.spark.sql.types.DataType;
 
 /**
@@ -75,6 +77,8 @@ protected String escapeSpecialCharsForLikePattern(String str) {
   public String build(Expression expr) {
     if (expr instanceof Literal literal) {
       return visitLiteral(literal);
+    } else if (expr instanceof JoinColumn column) {
+      return visitJoinColumn(column);
     } else if (expr instanceof NamedReference namedReference) {
       return visitNamedReference(namedReference);
     } else if (expr instanceof Cast cast) {
@@ -174,6 +178,12 @@ protected String visitNamedReference(NamedReference namedRef) {
     return namedRef.toString();
   }
 
+  protected String visitJoinColumn(JoinColumn column) {
+    List<String> fullyQualifiedName = new ArrayList<>(Arrays.asList(column.qualifier));
+    fullyQualifiedName.add(column.name);
+    return joinListToString(fullyQualifiedName, ".", "", "");
+  }
+
   protected String visitIn(String v, List<String> list) {
     if (list.isEmpty()) {
       return "CASE WHEN " + v + " IS NULL THEN NULL ELSE FALSE END";

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala

+8
@@ -394,6 +394,14 @@ case class AttributeReference(
   }
 }
 
+case class JoinColumnReference(
+    originalReference: AttributeReference,
+    isReferringColumnFromLeftSubquery: Boolean = true)
+  extends LeafExpression with Unevaluable {
+  override def nullable: Boolean = originalReference.nullable
+  override def dataType: DataType = originalReference.dataType
+}
+
 /**
  * A place holder used when printing expressions without debugging information such as the
  * expression id or the unresolved indicator.
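
JoinColumnReference wraps an already-resolved attribute and records which side of the join it refers to, delegating type and nullability. A small sketch (hypothetical attribute):

import org.apache.spark.sql.catalyst.expressions.{AttributeReference, JoinColumnReference}
import org.apache.spark.sql.types.DoubleType

val attr = AttributeReference("price", DoubleType)()
val ref = JoinColumnReference(attr, isReferringColumnFromLeftSubquery = false)
ref.dataType  // DoubleType, delegated to the wrapped attribute
ref.nullable  // true, the AttributeReference default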

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala

+7
@@ -26,6 +26,7 @@ import org.apache.spark.sql.connector.catalog.functions.ScalarFunction
 import org.apache.spark.sql.connector.expressions.{Cast => V2Cast, Expression => V2Expression, Extract => V2Extract, FieldReference, GeneralScalarExpression, LiteralValue, NullOrdering, SortDirection, SortValue, UserDefinedScalarFunc}
 import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Avg, Count, CountStar, GeneralAggregateFunc, Max, Min, Sum, UserDefinedAggregateFunc}
 import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse, AlwaysTrue, And => V2And, Not => V2Not, Or => V2Or, Predicate => V2Predicate}
+import org.apache.spark.sql.connector.join.JoinColumn
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{BooleanType, DataType, IntegerType, StringType}
 
@@ -79,6 +80,12 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) extends L
     case Literal(true, BooleanType) => Some(new AlwaysTrue())
     case Literal(false, BooleanType) => Some(new AlwaysFalse())
     case Literal(value, dataType) => Some(LiteralValue(value, dataType))
+    case joinRefColumn: JoinColumnReference =>
+      Some(new JoinColumn(
+        Array(),
+        joinRefColumn.originalReference.name,
+        joinRefColumn.isReferringColumnFromLeftSubquery
+      ))
     case col @ ColumnOrField(nameParts) =>
       val ref = FieldReference(nameParts)
       if (isPredicate && col.dataType.isInstanceOf[BooleanType]) {
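
Continuing the sketch above, the builder turns the wrapper into a connector-side JoinColumn; the qualifier is left empty at this stage and presumably attached later by the pushdown machinery (not shown in this excerpt):

import org.apache.spark.sql.catalyst.util.V2ExpressionBuilder

new V2ExpressionBuilder(ref).build()
// => Some(JoinColumn(qualifier = [], name = "price", isInLeftSideOfJoin = false))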

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

+10
@@ -1640,6 +1640,14 @@
     .booleanConf
     .createWithDefault(!Utils.isTesting)
 
+  val DATA_SOURCE_V2_JOIN_PUSHDOWN =
+    buildConf("spark.sql.optimizer.datasourceV2JoinPushdown")
+      .internal()
+      .doc("When set to true, the optimizer tries to push down joins for DSv2 " +
+        "data sources in the V2ScanRelationPushDown optimization rule.")
+      .booleanConf
+      .createWithDefault(false)
+
   // This is used to set the default data source
   val DEFAULT_DATA_SOURCE_NAME = buildConf("spark.sql.sources.default")
     .doc("The default data source to use in input/output.")
@@ -5988,6 +5996,8 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
 
   def expressionTreeChangeLogLevel: Level = getConf(EXPRESSION_TREE_CHANGE_LOG_LEVEL)
 
+  def dataSourceV2JoinPushdown: Boolean = getConf(DATA_SOURCE_V2_JOIN_PUSHDOWN)
+
   def dynamicPartitionPruningEnabled: Boolean = getConf(DYNAMIC_PARTITION_PRUNING_ENABLED)
 
   def dynamicPartitionPruningUseStats: Boolean = getConf(DYNAMIC_PARTITION_PRUNING_USE_STATS)
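
The feature is therefore off by default and opt-in per session:

// Internal flag introduced above; defaults to false.
spark.conf.set("spark.sql.optimizer.datasourceV2JoinPushdown", "true")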

sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala

+8-1
@@ -189,6 +189,12 @@ case class RowDataSourceScanExec(
       seqToString(markedFilters.toSeq)
     }
 
+    val pushedJoins = if (pushedDownOperators.pushedJoins.nonEmpty) {
+      Map("PushedJoins" -> seqToString(pushedDownOperators.pushedJoins))
+    } else {
+      Map()
+    }
+
     Map("ReadSchema" -> requiredSchema.catalogString,
       "PushedFilters" -> pushedFilters) ++
       pushedDownOperators.aggregation.fold(Map[String, String]()) { v =>
@@ -200,7 +206,8 @@
       offsetInfo ++
       pushedDownOperators.sample.map(v => "PushedSample" ->
         s"SAMPLE (${(v.upperBound - v.lowerBound) * 100}) ${v.withReplacement} SEED(${v.seed})"
-      )
+      ) ++
+      pushedJoins
   }
 
   // Don't care about `rdd` and `tableIdentifier`, and `stream` when canonicalizing.
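
With pushdown active, the scan node's metadata gains a "PushedJoins" entry, which then surfaces in explain output. An illustrative sketch with hypothetical tables (the exact rendering comes from seqToString):

val joined = spark.table("db.orders").join(spark.table("db.customers"), "cid")
joined.explain("formatted")
// the scan node metadata would include a line like:
//   PushedJoins: [db.orders, db.customers]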

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala

+9
@@ -37,6 +37,7 @@ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.planning.PhysicalOperation
+import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
 import org.apache.spark.sql.catalyst.plans.logical.{AppendData, InsertIntoDir, InsertIntoStatement, LogicalPlan, Project}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2
@@ -47,6 +48,7 @@ import org.apache.spark.sql.connector.catalog.{SupportsRead, V1Table}
 import org.apache.spark.sql.connector.catalog.TableCapability._
 import org.apache.spark.sql.connector.expressions.{Expression => V2Expression, NullOrdering, SortDirection, SortOrder => V2SortOrder, SortValue}
 import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Aggregation}
+import org.apache.spark.sql.connector.join.{Inner => V2Inner, JoinType => V2JoinType}
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.execution
 import org.apache.spark.sql.execution.{RowDataSourceScanExec, SparkPlan}
@@ -501,6 +503,13 @@
     }
   }
 
+  def translateJoinType(joinType: JoinType): Option[V2JoinType] = {
+    joinType match {
+      case Inner => Some(new V2Inner)
+      case _ => None
+    }
+  }
+
   /**
    * Convert RDD of Row into RDD of InternalRow with objects in catalyst types
   */
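
translateJoinType follows the pattern of the other translate helpers in this object: a Catalyst join type maps to its V2 counterpart, and None keeps the join in Spark's plan. For example:

import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner}
import org.apache.spark.sql.execution.datasources.DataSourceStrategy

DataSourceStrategy.translateJoinType(Inner)      // Some(V2 Inner)
DataSourceStrategy.translateJoinType(FullOuter)  // None: only INNER is supported so far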

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala

+7
@@ -104,6 +104,8 @@
     }
   }
 
+  var containsJoinInQuery: Boolean = false
+
   // ------------------------------------------------------------
   // Optional parameters
   // ------------------------------------------------------------
@@ -215,6 +217,10 @@
   // This only applies to Data Source V2 JDBC
   val pushDownTableSample = parameters.getOrElse(JDBC_PUSHDOWN_TABLESAMPLE, "true").toBoolean
 
+  // An option to allow/disallow pushing down JOIN into JDBC data source
+  // This only applies to Data Source V2 JDBC
+  val pushDownJoin = parameters.getOrElse(JDBC_PUSHDOWN_JOIN, "true").toBoolean
+
   // The local path of user's keytab file, which is assumed to be pre-uploaded to all nodes either
   // by --files option of spark-submit or manually
   val keytab = {
@@ -321,6 +327,7 @@
   val JDBC_PUSHDOWN_LIMIT = newOption("pushDownLimit")
   val JDBC_PUSHDOWN_OFFSET = newOption("pushDownOffset")
   val JDBC_PUSHDOWN_TABLESAMPLE = newOption("pushDownTableSample")
+  val JDBC_PUSHDOWN_JOIN = newOption("pushDownJoin")
   val JDBC_KEYTAB = newOption("keytab")
   val JDBC_PRINCIPAL = newOption("principal")
   val JDBC_TABLE_COMMENT = newOption("tableComment")
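
On the JDBC side the option defaults to true, so whether a join is actually pushed is governed by the new SQLConf flag; it can still be disabled per read (URL and table below are hypothetical):

val df = spark.read
  .format("jdbc")
  .option("url", "jdbc:postgresql://localhost/db")
  .option("dbtable", "orders")
  .option("pushDownJoin", "false")  // opt out of join pushdown for this read
  .load()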
