Skip to content

Commit

Permalink
[VL] Support outer join type for velox NestedLoopJoin
Browse files Browse the repository at this point in the history
  • Loading branch information
leoluan2009 committed Jan 22, 2025
1 parent 0f4489a commit 53deb11
Show file tree
Hide file tree
Showing 6 changed files with 8 additions and 18 deletions.
3 changes: 3 additions & 0 deletions cpp/velox/substrait/SubstraitToVeloxPlan.cc
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,9 @@ core::PlanNodePtr SubstraitToVeloxPlanConverter::toVeloxPlan(const ::substrait::
case ::substrait::CrossRel_JoinType::CrossRel_JoinType_JOIN_TYPE_LEFT:
joinType = core::JoinType::kLeft;
break;
case ::substrait::CrossRel_JoinType::CrossRel_JoinType_JOIN_TYPE_OUTER:
joinType = core::JoinType::kFull;
break;
default:
VELOX_NYI("Unsupported Join type: {}", std::to_string(crossRel.type()));
}
Expand Down
1 change: 1 addition & 0 deletions cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1060,6 +1060,7 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::CrossRel& crossR
switch (crossRel.type()) {
case ::substrait::CrossRel_JoinType_JOIN_TYPE_INNER:
case ::substrait::CrossRel_JoinType_JOIN_TYPE_LEFT:
case ::substrait::CrossRel_JoinType_JOIN_TYPE_OUTER:
break;
default:
LOG_VALIDATION_MSG("Unsupported Join type in CrossRel");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ abstract class BroadcastNestedLoopJoinExecTransformer(

def validateJoinTypeAndBuildSide(): ValidationResult = {
val result = joinType match {
case _: InnerLike | LeftOuter | RightOuter => ValidationResult.succeeded
case _: InnerLike | LeftOuter | RightOuter | FullOuter => ValidationResult.succeeded
case _ =>
ValidationResult.failed(s"$joinType join is not supported with BroadcastNestedLoopJoin")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ import org.apache.gluten.GlutenBuildInfo
import org.apache.gluten.config.GlutenConfig
import org.apache.gluten.events.GlutenPlanFallbackEvent
import org.apache.gluten.execution.FileSourceScanExecTransformer
import org.apache.gluten.utils.BackendTestUtils

import org.apache.spark.SparkConf
import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent}
Expand Down Expand Up @@ -115,14 +114,7 @@ class GlutenFallbackSuite extends GlutenSQLTestsTrait with AdaptiveSparkPlanHelp

val id = runExecution("SELECT * FROM t1 FULL OUTER JOIN t2")
val execution = glutenStore.execution(id)
if (BackendTestUtils.isVeloxBackendLoaded()) {
assert(execution.get.numFallbackNodes == 1)
assert(
execution.get.fallbackNodeToReason.head._2
.contains("FullOuter join is not supported with BroadcastNestedLoopJoin"))
} else {
assert(execution.get.numFallbackNodes == 0)
}
assert(execution.get.numFallbackNodes == 0)
}

// [GLUTEN-4119] Skip add ReusedExchange to fallback node
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,10 +117,7 @@ class GlutenFallbackSuite extends GlutenSQLTestsTrait with AdaptiveSparkPlanHelp
val id = runExecution("SELECT * FROM t1 FULL OUTER JOIN t2")
val execution = glutenStore.execution(id)
if (BackendTestUtils.isVeloxBackendLoaded()) {
assert(execution.get.numFallbackNodes == 1)
assert(
execution.get.fallbackNodeToReason.head._2
.contains("FullOuter join is not supported with BroadcastNestedLoopJoin"))
assert(execution.get.numFallbackNodes == 0)
} else {
assert(execution.get.numFallbackNodes == 2)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,10 +117,7 @@ class GlutenFallbackSuite extends GlutenSQLTestsTrait with AdaptiveSparkPlanHelp
val id = runExecution("SELECT * FROM t1 FULL OUTER JOIN t2")
val execution = glutenStore.execution(id)
if (BackendTestUtils.isVeloxBackendLoaded()) {
assert(execution.get.numFallbackNodes == 1)
assert(
execution.get.fallbackNodeToReason.head._2
.contains("FullOuter join is not supported with BroadcastNestedLoopJoin"))
assert(execution.get.numFallbackNodes == 0)
} else {
assert(execution.get.numFallbackNodes == 2)
}
Expand Down

0 comments on commit 53deb11

Please sign in to comment.