diff --git a/.github/workflows/velox_docker.yml b/.github/workflows/velox_docker.yml
index 6329750d22fb..1ed7794aa501 100644
--- a/.github/workflows/velox_docker.yml
+++ b/.github/workflows/velox_docker.yml
@@ -71,8 +71,8 @@ jobs:
strategy:
fail-fast: false
matrix:
- os: ["ubuntu:20.04", "ubuntu:22.04"]
- spark: ["spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5"]
+ os: [ "ubuntu:20.04", "ubuntu:22.04" ]
+ spark: [ "spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5" ]
java: [ "java-8", "java-17" ]
# Spark supports JDK 17 on 3.3 and later, see https://issues.apache.org/jira/browse/SPARK-33772
exclude:
@@ -119,9 +119,9 @@ jobs:
strategy:
fail-fast: false
matrix:
- os: ["centos:7", "centos:8"]
- spark: ["spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5"]
- java: ["java-8", "java-17"]
+ os: [ "centos:7", "centos:8" ]
+ spark: [ "spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5" ]
+ java: [ "java-8", "java-17" ]
# Spark supports JDK 17 on 3.3 and later, see https://issues.apache.org/jira/browse/SPARK-33772
exclude:
- spark: spark-3.2
@@ -156,24 +156,40 @@ jobs:
wget https://downloads.apache.org/maven/maven-3/3.8.8/binaries/apache-maven-3.8.8-bin.tar.gz
tar -xvf apache-maven-3.8.8-bin.tar.gz
mv apache-maven-3.8.8 /usr/lib/maven
- - name: Build and run TPCH/DS
+ - name: Set environment variables
run: |
- cd $GITHUB_WORKSPACE/
- export MAVEN_HOME=/usr/lib/maven
- export PATH=${PATH}:${MAVEN_HOME}/bin
+ echo "MAVEN_HOME=/usr/lib/maven" >> $GITHUB_ENV
+ echo "PATH=${PATH}:/usr/lib/maven/bin" >> $GITHUB_ENV
if [ "${{ matrix.java }}" = "java-17" ]; then
- export JAVA_HOME=/usr/lib/jvm/java-17-openjdk
+ echo "JAVA_HOME=/usr/lib/jvm/java-17-openjdk" >> $GITHUB_ENV
else
- export JAVA_HOME=/usr/lib/jvm/java-1.8.0-openjdk
+ echo "JAVA_HOME=/usr/lib/jvm/java-1.8.0-openjdk" >> $GITHUB_ENV
fi
+ - name: Build gluten-it
+ run: |
echo "JAVA_HOME: $JAVA_HOME"
+ cd $GITHUB_WORKSPACE/
mvn clean install -P${{ matrix.spark }} -P${{ matrix.java }} -Pbackends-velox -DskipTests
- cd $GITHUB_WORKSPACE/tools/gluten-it
- mvn clean install -P${{ matrix.spark }} -P${{ matrix.java }} \
- && GLUTEN_IT_JVM_ARGS=-Xmx5G sbin/gluten-it.sh queries-compare \
+ cd $GITHUB_WORKSPACE/tools/gluten-it
+ mvn clean install -P${{ matrix.spark }} -P${{ matrix.java }}
+ - name: Run TPC-H / TPC-DS
+ run: |
+ echo "JAVA_HOME: $JAVA_HOME"
+ cd $GITHUB_WORKSPACE/tools/gluten-it
+ GLUTEN_IT_JVM_ARGS=-Xmx5G sbin/gluten-it.sh queries-compare \
--local --preset=velox --benchmark-type=h --error-on-memleak --off-heap-size=10g -s=1.0 --threads=16 --iterations=1 \
&& GLUTEN_IT_JVM_ARGS=-Xmx5G sbin/gluten-it.sh queries-compare \
--local --preset=velox --benchmark-type=ds --error-on-memleak --off-heap-size=10g -s=1.0 --threads=16 --iterations=1
+ - name: Run TPC-H / TPC-DS with ACBO
+ run: |
+ echo "JAVA_HOME: $JAVA_HOME"
+ cd $GITHUB_WORKSPACE/tools/gluten-it
+ GLUTEN_IT_JVM_ARGS=-Xmx5G sbin/gluten-it.sh queries-compare \
+ --local --preset=velox --benchmark-type=h --error-on-memleak --off-heap-size=10g -s=1.0 --threads=16 --iterations=1 \
+ --extra-conf=spark.gluten.sql.advanced.cbo.enabled=true \
+ && GLUTEN_IT_JVM_ARGS=-Xmx5G sbin/gluten-it.sh queries-compare \
+ --local --preset=velox --benchmark-type=ds --error-on-memleak --off-heap-size=10g -s=1.0 --threads=16 --iterations=1 \
+ --extra-conf=spark.gluten.sql.advanced.cbo.enabled=true
# run-tpc-test-centos8-oom-randomkill:
# needs: build-native-lib
diff --git a/backends-velox/pom.xml b/backends-velox/pom.xml
index f1455f851c8b..01e0280ad755 100755
--- a/backends-velox/pom.xml
+++ b/backends-velox/pom.xml
@@ -50,6 +50,13 @@
      <type>test-jar</type>
      <scope>test</scope>
    </dependency>
+    <dependency>
+      <groupId>io.glutenproject</groupId>
+      <artifactId>gluten-cbo-common</artifactId>
+      <version>${project.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
    <dependency>
      <groupId>org.apache.spark</groupId>
      <artifactId>spark-core_${scala.binary.version}</artifactId>
diff --git a/backends-velox/src/test/scala/io/glutenproject/planner/VeloxCboSuite.scala b/backends-velox/src/test/scala/io/glutenproject/planner/VeloxCboSuite.scala
new file mode 100644
index 000000000000..83ce7c69ad91
--- /dev/null
+++ b/backends-velox/src/test/scala/io/glutenproject/planner/VeloxCboSuite.scala
@@ -0,0 +1,125 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.glutenproject.planner
+
+import io.glutenproject.cbo.{Cbo, CboSuiteBase}
+import io.glutenproject.cbo.path.CboPath
+import io.glutenproject.cbo.property.PropertySet
+import io.glutenproject.cbo.rule.{CboRule, Shape, Shapes}
+import io.glutenproject.planner.property.GlutenProperties.{Conventions, Schemas}
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.execution._
+import org.apache.spark.sql.test.SharedSparkSession
+
+class VeloxCboSuite extends SharedSparkSession {
+ import VeloxCboSuite._
+
+ test("C2R, R2C - basic") {
+ val in = RowUnary(RowLeaf())
+ val planner = newCbo().newPlanner(in)
+ val out = planner.plan()
+ assert(out == RowUnary(RowLeaf()))
+ }
+
+ test("C2R, R2C - explicitly requires any properties") {
+ val in = RowUnary(RowLeaf())
+ val planner =
+ newCbo().newPlanner(in, PropertySet(List(Conventions.ANY, Schemas.ANY)))
+ val out = planner.plan()
+ assert(out == RowUnary(RowLeaf()))
+ }
+
+ test("C2R, R2C - requires columnar output") {
+ val in = RowUnary(RowLeaf())
+ val planner =
+ newCbo().newPlanner(in, PropertySet(List(Conventions.VANILLA_COLUMNAR, Schemas.ANY)))
+ val out = planner.plan()
+ assert(out == RowToColumnarExec(RowUnary(RowLeaf())))
+ }
+
+ test("C2R, R2C - insert c2rs / r2cs") {
+ val in =
+ ColumnarUnary(RowUnary(RowUnary(ColumnarUnary(RowUnary(RowUnary(ColumnarUnary(RowLeaf())))))))
+ val planner =
+ newCbo().newPlanner(in, PropertySet(List(Conventions.ROW_BASED, Schemas.ANY)))
+ val out = planner.plan()
+ assert(out == ColumnarToRowExec(ColumnarUnary(
+ RowToColumnarExec(RowUnary(RowUnary(ColumnarToRowExec(ColumnarUnary(RowToColumnarExec(
+ RowUnary(RowUnary(ColumnarToRowExec(ColumnarUnary(RowToColumnarExec(RowLeaf()))))))))))))))
+ val paths = planner.newState().memoState().collectAllPaths(CboPath.INF_DEPTH).toList
+ val pathCount = paths.size
+ assert(pathCount == 165)
+ }
+
+ test("C2R, R2C - Row unary convertible to Columnar") {
+ object ConvertRowUnaryToColumnar extends CboRule[SparkPlan] {
+ override def shift(node: SparkPlan): Iterable[SparkPlan] = node match {
+ case RowUnary(child) => List(ColumnarUnary(child))
+ case other => List.empty
+ }
+
+ override def shape(): Shape[SparkPlan] = Shapes.fixedHeight(1)
+ }
+
+ val in =
+ ColumnarUnary(RowUnary(RowUnary(ColumnarUnary(RowUnary(RowUnary(ColumnarUnary(RowLeaf())))))))
+ val planner =
+ newCbo(List(ConvertRowUnaryToColumnar))
+ .newPlanner(in, PropertySet(List(Conventions.ROW_BASED, Schemas.ANY)))
+ val out = planner.plan()
+ assert(out == ColumnarToRowExec(ColumnarUnary(ColumnarUnary(ColumnarUnary(
+ ColumnarUnary(ColumnarUnary(ColumnarUnary(ColumnarUnary(RowToColumnarExec(RowLeaf()))))))))))
+ val paths = planner.newState().memoState().collectAllPaths(CboPath.INF_DEPTH).toList
+ val pathCount = paths.size
+ assert(pathCount == 1094)
+ }
+}
+
+object VeloxCboSuite extends CboSuiteBase {
+ def newCbo(): Cbo[SparkPlan] = {
+ GlutenOptimization().asInstanceOf[Cbo[SparkPlan]]
+ }
+
+ def newCbo(cboRules: Seq[CboRule[SparkPlan]]): Cbo[SparkPlan] = {
+ GlutenOptimization(cboRules).asInstanceOf[Cbo[SparkPlan]]
+ }
+
+ case class RowLeaf() extends LeafExecNode {
+ override def supportsColumnar: Boolean = false
+ override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException()
+ override def output: Seq[Attribute] = Seq.empty
+ }
+
+ case class RowUnary(child: SparkPlan) extends UnaryExecNode {
+ override def supportsColumnar: Boolean = false
+ override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException()
+ override def output: Seq[Attribute] = child.output
+ override protected def withNewChildInternal(newChild: SparkPlan): RowUnary =
+ copy(child = newChild)
+ }
+
+ case class ColumnarUnary(child: SparkPlan) extends UnaryExecNode {
+ override def supportsColumnar: Boolean = true
+ override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException()
+ override def output: Seq[Attribute] = child.output
+ override protected def withNewChildInternal(newChild: SparkPlan): ColumnarUnary =
+ copy(child = newChild)
+ }
+}
diff --git a/docs/Configuration.md b/docs/Configuration.md
index b68717d048e3..e1f8a8fccf1b 100644
--- a/docs/Configuration.md
+++ b/docs/Configuration.md
@@ -22,6 +22,7 @@ You can add these configurations into spark-defaults.conf to enable or disable t
| spark.plugins | To load Gluten's components by Spark's plug-in loader | io.glutenproject.GlutenPlugin |
| spark.shuffle.manager | To turn on Gluten Columnar Shuffle Plugin | org.apache.spark.shuffle.sort.ColumnarShuffleManager |
| spark.gluten.enabled | Enable Gluten, default is true. Just an experimental property. Recommend to enable/disable Gluten through the setting for `spark.plugins`. | true |
+| spark.gluten.sql.advanced.cbo.enabled | Experimental: Enables Gluten's advanced CBO features during physical planning, e.g., a more efficient fallback strategy. The option can be turned on and off independently of vanilla Spark's CBO settings. Note that Gluten's query optimizer may still adopt a subset of its advanced CBO capabilities even when this option is off; enabling it makes Gluten consider using CBO for optimization more aggressively. This feature is still under development and may not bring performance benefits. | false |
| spark.gluten.sql.columnar.maxBatchSize | Number of rows to be processed in each batch. Default value is 4096. | 4096 |
| spark.gluten.memory.isolation | (Experimental) Enable isolated memory mode. If true, Gluten controls the maximum off-heap memory can be used by each task to X, X = executor memory / max task slots. It's recommended to set true if Gluten serves concurrent queries within a single session, since not all memory Gluten allocated is guaranteed to be spillable. In the case, the feature should be enabled to avoid OOM. Note when true, setting spark.memory.storageFraction to a lower value is suggested since storage memory is considered non-usable by Gluten. | false |
| spark.gluten.sql.columnar.scanOnly | When enabled, this config will overwrite all other operators' enabling, and only Scan and Filter pushdown will be offloaded to native. | false |
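As a usage sketch (assuming an active SparkSession named `spark`; only the property key comes from the table above), the new flag can also be toggled at runtime:

    // Enable Gluten's advanced CBO for subsequent queries in this session.
    spark.conf.set("spark.gluten.sql.advanced.cbo.enabled", "true")
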
diff --git a/gluten-cbo/common/pom.xml b/gluten-cbo/common/pom.xml
new file mode 100644
index 000000000000..7888ced36bb4
--- /dev/null
+++ b/gluten-cbo/common/pom.xml
@@ -0,0 +1,12 @@
+<project xmlns="http://maven.apache.org/POM/4.0.0"
+         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+  <modelVersion>4.0.0</modelVersion>
+  <parent>
+    <groupId>io.glutenproject</groupId>
+    <artifactId>gluten-cbo</artifactId>
+    <version>1.2.0-SNAPSHOT</version>
+  </parent>
+  <artifactId>gluten-cbo-common</artifactId>
+  <name>Gluten Cbo Common</name>
+</project>
diff --git a/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/Cbo.scala b/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/Cbo.scala
new file mode 100644
index 000000000000..fa735ec2fbed
--- /dev/null
+++ b/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/Cbo.scala
@@ -0,0 +1,216 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.glutenproject.cbo
+
+import io.glutenproject.cbo.property.PropertySet
+import io.glutenproject.cbo.rule.CboRule
+
+import scala.collection.mutable
+
+/**
+ * Entrypoint of ACBO (Advanced CBO)'s search engine. For a basic introduction to ACBO, see:
+ * https://github.com/apache/incubator-gluten/issues/5057.
+ */
+trait Optimization[T <: AnyRef] {
+ def newPlanner(
+ plan: T,
+ constraintSet: PropertySet[T],
+ altConstraintSets: Seq[PropertySet[T]]): CboPlanner[T]
+
+ def propSetsOf(plan: T): PropertySet[T]
+
+ def withNewConfig(confFunc: CboConfig => CboConfig): Optimization[T]
+}
+
+object Optimization {
+ def apply[T <: AnyRef](
+ costModel: CostModel[T],
+ planModel: PlanModel[T],
+ propertyModel: PropertyModel[T],
+ explain: CboExplain[T],
+ ruleFactory: CboRule.Factory[T]): Optimization[T] = {
+ Cbo(costModel, planModel, propertyModel, explain, ruleFactory)
+ }
+
+ implicit class OptimizationImplicits[T <: AnyRef](opt: Optimization[T]) {
+ def newPlanner(plan: T): CboPlanner[T] = {
+ opt.newPlanner(plan, opt.propSetsOf(plan), List.empty)
+ }
+ def newPlanner(plan: T, constraintSet: PropertySet[T]): CboPlanner[T] = {
+ opt.newPlanner(plan, constraintSet, List.empty)
+ }
+ }
+}
+
+class Cbo[T <: AnyRef] private (
+ val config: CboConfig,
+ val costModel: CostModel[T],
+ val planModel: PlanModel[T],
+ val propertyModel: PropertyModel[T],
+ val explain: CboExplain[T],
+ val ruleFactory: CboRule.Factory[T])
+ extends Optimization[T] {
+ import Cbo._
+
+ override def withNewConfig(confFunc: CboConfig => CboConfig): Cbo[T] = {
+ new Cbo(confFunc(config), costModel, planModel, propertyModel, explain, ruleFactory)
+ }
+
+ // Normal groups start with ID 0, so it's safe to use -1 to do validation.
+ private val dummyGroup: T =
+ planModel.newGroupLeaf(-1, PropertySet(Seq.empty))
+ private val infCost: Cost = costModel.makeInfCost()
+
+ validateModels()
+
+ private def assertThrows(message: String)(u: => Unit): Unit = {
+ var notThrew: Boolean = false
+ try {
+ u
+ notThrew = true
+ } catch {
+ case _: Exception =>
+ }
+ assert(!notThrew, message)
+ }
+
+ private def validateModels(): Unit = {
+ // Group leaves have no children.
+ assert(planModel.childrenOf(dummyGroup) == List.empty)
+ assertThrows(
+ "Group is not allowed to have cost. It's expected to throw an exception when " +
+ "getting its cost but not") {
+ // Node groups don't have user-defined cost, expect exception here.
+ costModel.costOf(dummyGroup)
+ }
+ propertyModel.propertyDefs.foreach {
+ propDef =>
+ // Node groups don't have user-defined property, expect exception here.
+ assertThrows(
+ "Group is not allowed to return its property directly to optimizer (optimizer already" +
+ " knew that). It's expected to throw an exception when getting its property but not") {
+ propDef.getProperty(dummyGroup)
+ }
+ }
+ }
+
+ private val propSetFactory: PropertySetFactory[T] = PropertySetFactory(this)
+
+ override def newPlanner(
+ plan: T,
+ constraintSet: PropertySet[T],
+ altConstraintSets: Seq[PropertySet[T]]): CboPlanner[T] = {
+ CboPlanner(this, altConstraintSets, constraintSet, plan)
+ }
+
+ override def propSetsOf(plan: T): PropertySet[T] = propertySetFactory().get(plan)
+
+ private[cbo] def withNewChildren(node: T, newChildren: Seq[T]): T = {
+ val oldChildren = planModel.childrenOf(node)
+ assert(newChildren.size == oldChildren.size)
+ val out = planModel.withNewChildren(node, newChildren)
+ assert(planModel.childrenOf(out).size == newChildren.size)
+ out
+ }
+
+ private[cbo] def isGroupLeaf(node: T): Boolean = {
+ planModel.isGroupLeaf(node)
+ }
+
+ private[cbo] def isLeaf(node: T): Boolean = {
+ planModel.childrenOf(node).isEmpty
+ }
+
+ private[cbo] def isCanonical(node: T): Boolean = {
+ assert(!planModel.isGroupLeaf(node))
+ planModel.childrenOf(node).forall(child => planModel.isGroupLeaf(child))
+ }
+
+ private[cbo] def getChildrenGroupIds(n: T): Seq[Int] = {
+ assert(isCanonical(n))
+ planModel
+ .childrenOf(n)
+ .map(child => planModel.getGroupId(child))
+ }
+
+ private[cbo] def propertySetFactory(): PropertySetFactory[T] = propSetFactory
+
+ private[cbo] def dummyGroupLeaf(): T = {
+ dummyGroup
+ }
+
+ private[cbo] def getInfCost(): Cost = infCost
+
+ private[cbo] def isInfCost(cost: Cost): Boolean =
+ costModel.costComparator().equiv(cost, infCost)
+}
+
+object Cbo {
+ private[cbo] def apply[T <: AnyRef](
+ costModel: CostModel[T],
+ planModel: PlanModel[T],
+ propertyModel: PropertyModel[T],
+ explain: CboExplain[T],
+ ruleFactory: CboRule.Factory[T]): Cbo[T] = {
+ new Cbo[T](CboConfig(), costModel, planModel, propertyModel, explain, ruleFactory)
+ }
+
+ trait PropertySetFactory[T <: AnyRef] {
+ def get(node: T): PropertySet[T]
+ def childrenConstraintSets(constraintSet: PropertySet[T], node: T): Seq[PropertySet[T]]
+ }
+
+ private object PropertySetFactory {
+ def apply[T <: AnyRef](cbo: Cbo[T]): PropertySetFactory[T] = new PropertySetFactoryImpl[T](cbo)
+
+ private class PropertySetFactoryImpl[T <: AnyRef](val cbo: Cbo[T])
+ extends PropertySetFactory[T] {
+ private val propDefs: Seq[PropertyDef[T, _ <: Property[T]]] = cbo.propertyModel.propertyDefs
+
+ override def get(node: T): PropertySet[T] = {
+ val m: Map[PropertyDef[T, _ <: Property[T]], Property[T]] =
+ propDefs.map(propDef => (propDef, propDef.getProperty(node))).toMap
+ PropertySet[T](m)
+ }
+
+ override def childrenConstraintSets(
+ constraintSet: PropertySet[T],
+ node: T): Seq[PropertySet[T]] = {
+ val builder: Seq[mutable.Map[PropertyDef[T, _ <: Property[T]], Property[T]]] =
+ cbo.planModel
+ .childrenOf(node)
+ .map(_ => mutable.Map[PropertyDef[T, _ <: Property[T]], Property[T]]())
+
+ propDefs
+ .foldLeft(builder) {
+ (
+ builder: Seq[mutable.Map[PropertyDef[T, _ <: Property[T]], Property[T]]],
+ propDef: PropertyDef[T, _ <: Property[T]]) =>
+ val constraint = constraintSet.get(propDef)
+ val childrenConstraints = propDef.getChildrenConstraints(constraint, node)
+ builder.zip(childrenConstraints).map {
+ case (childBuilder, childConstraint) =>
+ childBuilder += (propDef -> childConstraint)
+ }
+ }
+ .map {
+ builder: mutable.Map[PropertyDef[T, _ <: Property[T]], Property[T]] =>
+ PropertySet[T](builder.toMap)
+ }
+ }
+ }
+ }
+}
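A minimal sketch of driving the optimizer defined above, mirroring the usage in VeloxCboSuite earlier in this patch; costModel, planModel, propertyModel, explain and ruleFactory stand for user-supplied implementations of the corresponding traits:

    val opt: Optimization[SparkPlan] =
      Optimization(costModel, planModel, propertyModel, explain, ruleFactory)
    // The constraint set defaults to the input plan's own property set.
    val planner = opt.newPlanner(inputPlan)
    val optimized: SparkPlan = planner.plan()
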
diff --git a/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/CboCluster.scala b/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/CboCluster.scala
new file mode 100644
index 000000000000..153050cefb62
--- /dev/null
+++ b/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/CboCluster.scala
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.glutenproject.cbo
+
+import io.glutenproject.cbo.memo.MemoTable
+import io.glutenproject.cbo.property.PropertySet
+
+import scala.collection.mutable
+
+trait CboClusterKey
+
+object CboClusterKey {
+ implicit class CboClusterKeyImplicits[T <: AnyRef](key: CboClusterKey) {
+ def propSets(memoTable: MemoTable[T]): Set[PropertySet[T]] = {
+ memoTable.getClusterPropSets(key)
+ }
+ }
+}
+
+trait CboCluster[T <: AnyRef] {
+ def nodes(): Iterable[CanonicalNode[T]]
+}
+
+object CboCluster {
+ // Node cluster.
+ trait MutableCboCluster[T <: AnyRef] extends CboCluster[T] {
+ def cbo(): Cbo[T]
+ def contains(t: CanonicalNode[T]): Boolean
+ def add(t: CanonicalNode[T]): Unit
+ }
+
+ object MutableCboCluster {
+ def apply[T <: AnyRef](cbo: Cbo[T]): MutableCboCluster[T] = {
+ new RegularMutableCboCluster(cbo)
+ }
+
+ private class RegularMutableCboCluster[T <: AnyRef](val cbo: Cbo[T])
+ extends MutableCboCluster[T] {
+ private val buffer: mutable.Set[CanonicalNode[T]] =
+ mutable.Set()
+
+ override def contains(t: CanonicalNode[T]): Boolean = {
+ buffer.contains(t)
+ }
+
+ override def add(t: CanonicalNode[T]): Unit = {
+ assert(!buffer.contains(t))
+ buffer += t
+ }
+
+ override def nodes(): Iterable[CanonicalNode[T]] = {
+ buffer
+ }
+ }
+ }
+
+ case class ImmutableCboCluster[T <: AnyRef] private (
+ cbo: Cbo[T],
+ override val nodes: Set[CanonicalNode[T]])
+ extends CboCluster[T]
+
+ object ImmutableCboCluster {
+ def apply[T <: AnyRef](cbo: Cbo[T], cluster: CboCluster[T]): ImmutableCboCluster[T] = {
+ ImmutableCboCluster(cbo, cluster.nodes().toSet)
+ }
+ }
+}
diff --git a/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/CboConfig.scala b/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/CboConfig.scala
new file mode 100644
index 000000000000..cc6fe3792319
--- /dev/null
+++ b/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/CboConfig.scala
@@ -0,0 +1,31 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.glutenproject.cbo
+
+import io.glutenproject.cbo.CboConfig._
+
+case class CboConfig(
+ plannerType: PlannerType = PlannerType.Dp
+)
+
+object CboConfig {
+ sealed trait PlannerType
+ object PlannerType {
+ case object Exhaustive extends PlannerType
+ case object Dp extends PlannerType
+ }
+}
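A short sketch of switching planner types through Cbo.withNewConfig (defined in Cbo.scala above); `opt` stands for an existing Cbo instance:

    // Swap the default DP planner for the exhaustive planner; all models carry over.
    val exhaustiveOpt = opt.withNewConfig(_ => CboConfig(plannerType = PlannerType.Exhaustive))
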
diff --git a/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/CboExplain.scala b/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/CboExplain.scala
new file mode 100644
index 000000000000..8a2f908789a9
--- /dev/null
+++ b/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/CboExplain.scala
@@ -0,0 +1,21 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.glutenproject.cbo
+
+trait CboExplain[T <: AnyRef] {
+ def describeNode(node: T): String
+}
diff --git a/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/CboGroup.scala b/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/CboGroup.scala
new file mode 100644
index 000000000000..025e664ec207
--- /dev/null
+++ b/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/CboGroup.scala
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.glutenproject.cbo
+
+import io.glutenproject.cbo.memo.MemoStore
+import io.glutenproject.cbo.property.PropertySet
+
+trait CboGroup[T <: AnyRef] {
+ def id(): Int
+ def clusterKey(): CboClusterKey
+ def propSet(): PropertySet[T]
+ def self(): T
+ def nodes(store: MemoStore[T]): Iterable[CanonicalNode[T]]
+}
+
+object CboGroup {
+ def apply[T <: AnyRef](
+ cbo: Cbo[T],
+ clusterKey: CboClusterKey,
+ id: Int,
+ propSet: PropertySet[T]): CboGroup[T] = {
+ new CboGroupImpl[T](cbo, clusterKey, id, propSet)
+ }
+
+ private class CboGroupImpl[T <: AnyRef](
+ cbo: Cbo[T],
+ clusterKey: CboClusterKey,
+ override val id: Int,
+ override val propSet: PropertySet[T])
+ extends CboGroup[T] {
+ private val groupLeaf: T = cbo.planModel.newGroupLeaf(id, propSet)
+
+ override def clusterKey(): CboClusterKey = clusterKey
+ override def self(): T = groupLeaf
+ override def nodes(store: MemoStore[T]): Iterable[CanonicalNode[T]] = {
+ store.getCluster(clusterKey).nodes().filter(n => n.propSet().satisfies(propSet))
+ }
+ override def toString(): String = {
+ s"CboGroup(id=$id, clusterKey=$clusterKey, propSet=$propSet))"
+ }
+ }
+}
diff --git a/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/CboNode.scala b/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/CboNode.scala
new file mode 100644
index 000000000000..e5d4985c2590
--- /dev/null
+++ b/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/CboNode.scala
@@ -0,0 +1,136 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.glutenproject.cbo
+
+import io.glutenproject.cbo.property.PropertySet
+
+trait CboNode[T <: AnyRef] {
+ def cbo(): Cbo[T]
+ def self(): T
+ def propSet(): PropertySet[T]
+}
+
+object CboNode {
+ implicit class CboNodeImplicits[T <: AnyRef](node: CboNode[T]) {
+ def isCanonical: Boolean = {
+ node.isInstanceOf[CanonicalNode[T]]
+ }
+
+ def asCanonical(): CanonicalNode[T] = {
+ node.asInstanceOf[CanonicalNode[T]]
+ }
+
+ def isGroup: Boolean = {
+ node.isInstanceOf[GroupNode[T]]
+ }
+
+ def asGroup(): GroupNode[T] = {
+ node.asInstanceOf[GroupNode[T]]
+ }
+ }
+}
+
+trait CanonicalNode[T <: AnyRef] extends CboNode[T] {
+ def childrenCount: Int
+}
+
+object CanonicalNode {
+ def apply[T <: AnyRef](cbo: Cbo[T], canonical: T): CanonicalNode[T] = {
+ assert(cbo.isCanonical(canonical))
+ val propSet = cbo.propSetsOf(canonical)
+ val children = cbo.planModel.childrenOf(canonical)
+ CanonicalNodeImpl[T](cbo, canonical, propSet, children.size)
+ }
+
+ // We put CboNode's API methods that accept mutable input in implicit classes.
+ // Do not break this rule during further development.
+ implicit class CanonicalNodeImplicits[T <: AnyRef](node: CanonicalNode[T]) {
+ def isLeaf(): Boolean = {
+ node.childrenCount == 0
+ }
+
+ def getChildrenGroups(allGroups: Int => CboGroup[T]): Seq[GroupNode[T]] = {
+ val cbo = node.cbo()
+ cbo.getChildrenGroupIds(node.self()).map(allGroups(_)).map(g => GroupNode(cbo, g))
+ }
+
+ def getChildrenGroupIds(): Seq[Int] = {
+ val cbo = node.cbo()
+ cbo.getChildrenGroupIds(node.self())
+ }
+ }
+
+ private case class CanonicalNodeImpl[T <: AnyRef](
+ cbo: Cbo[T],
+ override val self: T,
+ override val propSet: PropertySet[T],
+ override val childrenCount: Int)
+ extends CanonicalNode[T]
+}
+
+trait GroupNode[T <: AnyRef] extends CboNode[T] {
+ def groupId(): Int
+}
+
+object GroupNode {
+ def apply[T <: AnyRef](cbo: Cbo[T], group: CboGroup[T]): GroupNode[T] = {
+ GroupNodeImpl[T](cbo, group.self(), group.propSet(), group.id())
+ }
+
+ private case class GroupNodeImpl[T <: AnyRef](
+ cbo: Cbo[T],
+ override val self: T,
+ override val propSet: PropertySet[T],
+ override val groupId: Int)
+ extends GroupNode[T] {}
+
+ // We put CboNode's API methods that accept mutable input in implicit classes.
+ // Do not break this rule during further development.
+ implicit class GroupNodeImplicits[T <: AnyRef](gn: GroupNode[T]) {
+ def group(allGroups: Int => CboGroup[T]): CboGroup[T] = {
+ allGroups(gn.groupId())
+ }
+ }
+}
+
+trait InGroupNode[T <: AnyRef] {
+ def groupId: Int
+ def can: CanonicalNode[T]
+}
+
+object InGroupNode {
+ def apply[T <: AnyRef](groupId: Int, node: CanonicalNode[T]): InGroupNode[T] = {
+ InGroupNodeImpl(groupId, node)
+ }
+ private case class InGroupNodeImpl[T <: AnyRef](groupId: Int, can: CanonicalNode[T])
+ extends InGroupNode[T]
+}
+
+trait InClusterNode[T <: AnyRef] {
+ def clusterKey: CboClusterKey
+ def can: CanonicalNode[T]
+}
+
+object InClusterNode {
+ def apply[T <: AnyRef](clusterId: CboClusterKey, node: CanonicalNode[T]): InClusterNode[T] = {
+ InClusterNodeImpl(clusterId, node)
+ }
+ private case class InClusterNodeImpl[T <: AnyRef](
+ clusterKey: CboClusterKey,
+ can: CanonicalNode[T])
+ extends InClusterNode[T]
+}
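A short sketch of the wrappers above; `cbo` and `canonicalPlan` (a node whose children are all group leaves) are assumed:

    // Wrap a canonical plan node, then read the group IDs of its children.
    val can: CanonicalNode[SparkPlan] = CanonicalNode(cbo, canonicalPlan)
    val childGroupIds: Seq[Int] = can.getChildrenGroupIds()
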
diff --git a/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/CboPlanner.scala b/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/CboPlanner.scala
new file mode 100644
index 000000000000..fe2edb2b4639
--- /dev/null
+++ b/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/CboPlanner.scala
@@ -0,0 +1,154 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.glutenproject.cbo
+
+import io.glutenproject.cbo.CboConfig.PlannerType
+import io.glutenproject.cbo.dp.DpPlanner
+import io.glutenproject.cbo.exaustive.ExhaustivePlanner
+import io.glutenproject.cbo.memo.MemoState
+import io.glutenproject.cbo.path.CboPath
+import io.glutenproject.cbo.property.PropertySet
+import io.glutenproject.cbo.vis.GraphvizVisualizer
+
+import scala.collection.mutable
+
+trait CboPlanner[T <: AnyRef] {
+ def plan(): T
+ def newState(): PlannerState[T]
+}
+
+object CboPlanner {
+ def apply[T <: AnyRef](
+ cbo: Cbo[T],
+ altConstraintSets: Seq[PropertySet[T]],
+ constraintSet: PropertySet[T],
+ plan: T): CboPlanner[T] = {
+ cbo.config.plannerType match {
+ case PlannerType.Exhaustive =>
+ ExhaustivePlanner(cbo, altConstraintSets, constraintSet, plan)
+ case PlannerType.Dp =>
+ DpPlanner(cbo, altConstraintSets, constraintSet, plan)
+ }
+ }
+}
+
+trait Best[T <: AnyRef] {
+ import Best._
+ def rootGroupId(): Int
+ def bestNodes(): Set[InGroupNode[T]]
+ def winnerNodes(): Set[InGroupNode[T]]
+ def costs(): InGroupNode[T] => Option[Cost]
+ def path(): KnownCostPath[T]
+}
+
+object Best {
+ def apply[T <: AnyRef](
+ cbo: Cbo[T],
+ rootGroupId: Int,
+ bestPath: KnownCostPath[T],
+ winnerNodes: Seq[InGroupNode[T]],
+ costs: InGroupNode[T] => Option[Cost]): Best[T] = {
+ val bestNodes = mutable.Set[InGroupNode[T]]()
+
+ def dfs(groupId: Int, cursor: CboPath.PathNode[T]): Unit = {
+ val can = cursor.self().asCanonical()
+ bestNodes += InGroupNode(groupId, can)
+ cursor.zipChildrenWithGroupIds().foreach {
+ case (childPathNode, childGroupId) =>
+ dfs(childGroupId, childPathNode)
+ }
+ }
+
+ dfs(rootGroupId, bestPath.cboPath.node())
+
+ val winnerNodeSet = winnerNodes.toSet
+
+ BestImpl(cbo, rootGroupId, bestPath, bestNodes.toSet, winnerNodeSet, costs)
+ }
+
+ private case class BestImpl[T <: AnyRef](
+ cbo: Cbo[T],
+ override val rootGroupId: Int,
+ override val path: KnownCostPath[T],
+ override val bestNodes: Set[InGroupNode[T]],
+ override val winnerNodes: Set[InGroupNode[T]],
+ override val costs: InGroupNode[T] => Option[Cost])
+ extends Best[T]
+
+ trait KnownCostPath[T <: AnyRef] {
+ def cboPath: CboPath[T]
+ def cost: Cost
+ }
+
+ object KnownCostPath {
+ def apply[T <: AnyRef](cbo: Cbo[T], cboPath: CboPath[T]): KnownCostPath[T] = {
+ KnownCostPathImpl(cboPath, cbo.costModel.costOf(cboPath.plan()))
+ }
+
+ def apply[T <: AnyRef](cboPath: CboPath[T], cost: Cost): KnownCostPath[T] = {
+ KnownCostPathImpl(cboPath, cost)
+ }
+
+ private case class KnownCostPathImpl[T <: AnyRef](cboPath: CboPath[T], cost: Cost)
+ extends KnownCostPath[T]
+ }
+
+ case class BestNotFoundException(message: String, cause: Exception)
+ extends RuntimeException(message, cause)
+ object BestNotFoundException {
+ def apply(message: String): BestNotFoundException = {
+ BestNotFoundException(message, null)
+ }
+ def apply(): BestNotFoundException = {
+ BestNotFoundException(null, null)
+ }
+ }
+}
+
+trait PlannerState[T <: AnyRef] {
+ def cbo(): Cbo[T]
+ def memoState(): MemoState[T]
+ def rootGroupId(): Int
+ def best(): Best[T]
+}
+
+object PlannerState {
+ def apply[T <: AnyRef](
+ cbo: Cbo[T],
+ memoState: MemoState[T],
+ rootGroupId: Int,
+ best: Best[T]): PlannerState[T] = {
+ PlannerStateImpl(cbo, memoState, rootGroupId, best)
+ }
+
+ implicit class PlannerStateImplicits[T <: AnyRef](state: PlannerState[T]) {
+ def formatGraphviz(): String = {
+ formatGraphvizWithBest()
+ }
+
+ private def formatGraphvizWithBest(): String = {
+ GraphvizVisualizer(state.cbo(), state.memoState(), state.best()).format()
+ }
+ }
+
+ private case class PlannerStateImpl[T <: AnyRef] private (
+ override val cbo: Cbo[T],
+ override val memoState: MemoState[T],
+ override val rootGroupId: Int,
+ override val best: Best[T])
+ extends PlannerState[T]
+}
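For debugging, PlannerState exposes a Graphviz dump of the memo with the best path highlighted; a sketch, assuming `planner` was created as in the Cbo.scala example above:

    val state = planner.newState()
    // DOT-format rendering of the memo and the chosen best path.
    val dot: String = state.formatGraphviz()
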
diff --git a/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/CostModel.scala b/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/CostModel.scala
new file mode 100644
index 000000000000..74d7c8c66376
--- /dev/null
+++ b/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/CostModel.scala
@@ -0,0 +1,25 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.glutenproject.cbo
+
+trait Cost
+
+trait CostModel[T <: AnyRef] {
+ def costOf(node: T): Cost
+ def costComparator(): Ordering[Cost]
+ def makeInfCost(): Cost
+}
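To make the contract concrete, a toy CostModel that scores a plan by its node count could look as follows; a hedged sketch, not part of this patch, with TestNode as a hypothetical plan type:

    case class TestNode(children: Seq[TestNode])

    case class LongCost(value: Long) extends Cost

    object NodeCountCostModel extends CostModel[TestNode] {
      private def count(n: TestNode): Long = 1L + n.children.map(count).sum
      // Cost of a plan is the number of nodes in its subtree; fewer nodes win.
      override def costOf(node: TestNode): Cost = LongCost(count(node))
      override def costComparator(): Ordering[Cost] =
        Ordering.by((c: Cost) => c.asInstanceOf[LongCost].value)
      override def makeInfCost(): Cost = LongCost(Long.MaxValue)
    }
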
diff --git a/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/PlanModel.scala b/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/PlanModel.scala
new file mode 100644
index 000000000000..366d1575f152
--- /dev/null
+++ b/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/PlanModel.scala
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.glutenproject.cbo
+
+import io.glutenproject.cbo.property.PropertySet
+
+trait PlanModel[T <: AnyRef] {
+ // Trivial tree operations.
+ def childrenOf(node: T): Seq[T]
+ def withNewChildren(node: T, children: Seq[T]): T
+ def hashCode(node: T): Int
+ def equals(one: T, other: T): Boolean
+
+ // Group operations.
+ def newGroupLeaf(groupId: Int, propSet: PropertySet[T]): T
+ def isGroupLeaf(node: T): Boolean
+ def getGroupId(node: T): Int
+}
diff --git a/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/PropertyModel.scala b/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/PropertyModel.scala
new file mode 100644
index 000000000000..a260d8dd41cc
--- /dev/null
+++ b/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/PropertyModel.scala
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.glutenproject.cbo
+
+import io.glutenproject.cbo.rule.CboRule
+
+// TODO Use class tags to restrict runtime user-defined class types.
+
+trait Property[T <: AnyRef] {
+ def satisfies(other: Property[T]): Boolean
+ def definition(): PropertyDef[T, _ <: Property[T]]
+}
+
+trait PropertyDef[T <: AnyRef, P <: Property[T]] {
+ def getProperty(plan: T): P
+ def getChildrenConstraints(constraint: Property[T], plan: T): Seq[P]
+}
+
+trait EnforcerRuleFactory[T <: AnyRef] {
+ def newEnforcerRules(constraint: Property[T]): Seq[CboRule[T]]
+}
+
+trait PropertyModel[T <: AnyRef] {
+ def propertyDefs: Seq[PropertyDef[T, _ <: Property[T]]]
+ def newEnforcerRuleFactory(propertyDef: PropertyDef[T, _ <: Property[T]]): EnforcerRuleFactory[T]
+}
+
+object PropertyModel {}
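Continuing the hypothetical TestNode example from CostModel.scala above, a hedged sketch of a single-property model wiring these traits together (no enforcer rules in this toy setup):

    sealed trait Conv extends Property[TestNode] {
      override def definition(): PropertyDef[TestNode, _ <: Property[TestNode]] = ConvDef
      override def satisfies(other: Property[TestNode]): Boolean = other match {
        case AnyConv => true // any concrete property satisfies the "any" constraint
        case c: Conv => this == c
        case _ => false
      }
    }
    case object AnyConv extends Conv
    case object RowConv extends Conv

    object ConvDef extends PropertyDef[TestNode, Conv] {
      // Every node is row-based in this toy model.
      override def getProperty(plan: TestNode): Conv = RowConv
      // Children inherit the parent's constraint unchanged.
      override def getChildrenConstraints(constraint: Property[TestNode], plan: TestNode): Seq[Conv] =
        plan.children.map(_ => constraint.asInstanceOf[Conv])
    }

    object ToyPropertyModel extends PropertyModel[TestNode] {
      override def propertyDefs: Seq[PropertyDef[TestNode, _ <: Property[TestNode]]] = Seq(ConvDef)
      override def newEnforcerRuleFactory(propertyDef: PropertyDef[TestNode, _ <: Property[TestNode]])
          : EnforcerRuleFactory[TestNode] = new EnforcerRuleFactory[TestNode] {
        override def newEnforcerRules(constraint: Property[TestNode]): Seq[CboRule[TestNode]] =
          Seq.empty // this toy model never needs property enforcement
      }
    }
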
diff --git a/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/best/BestFinder.scala b/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/best/BestFinder.scala
new file mode 100644
index 000000000000..c7e39c538409
--- /dev/null
+++ b/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/best/BestFinder.scala
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.glutenproject.cbo.best
+
+import io.glutenproject.cbo._
+import io.glutenproject.cbo.Best.KnownCostPath
+import io.glutenproject.cbo.dp.DpGroupAlgo
+import io.glutenproject.cbo.memo.MemoState
+
+import scala.collection.mutable
+
+trait BestFinder[T <: AnyRef] {
+ def bestOf(groupId: Int): Best[T]
+}
+
+object BestFinder {
+ def apply[T <: AnyRef](cbo: Cbo[T], memoState: MemoState[T]): BestFinder[T] = {
+ unsafe(cbo, memoState, DpGroupAlgo.Adjustment.none())
+ }
+
+ def unsafe[T <: AnyRef](
+ cbo: Cbo[T],
+ memoState: MemoState[T],
+ adjustment: DpGroupAlgo.Adjustment[T]): BestFinder[T] = {
+ new GroupBasedBestFinder[T](cbo, memoState, adjustment)
+ }
+
+ case class KnownCostGroup[T <: AnyRef](
+ nodeToCost: Map[CanonicalNode[T], KnownCostPath[T]],
+ bestNode: CanonicalNode[T]) {
+ def best(): KnownCostPath[T] = nodeToCost(bestNode)
+ }
+
+ case class KnownCostCluster[T <: AnyRef](groupToCost: Map[Int, KnownCostGroup[T]])
+
+ private[best] def newBest[T <: AnyRef](
+ cbo: Cbo[T],
+ allGroups: Seq[CboGroup[T]],
+ group: CboGroup[T],
+ groupToCosts: Map[Int, KnownCostGroup[T]]): Best[T] = {
+ val bestPath = groupToCosts(group.id()).best()
+ val bestRoot = bestPath.cboPath.node()
+ val winnerNodes = groupToCosts.map { case (id, g) => InGroupNode(id, g.bestNode) }.toSeq
+ val costsMap = mutable.Map[InGroupNode[T], Cost]()
+ groupToCosts.foreach {
+ case (gid, g) =>
+ g.nodeToCost.foreach {
+ case (n, c) =>
+ costsMap += (InGroupNode(gid, n) -> c.cost)
+ }
+ }
+ Best(cbo, group.id(), bestPath, winnerNodes, costsMap.get)
+ }
+}
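A sketch of invoking the finder the way DpPlanner does later in this patch; `cbo`, a populated `memoTable`, and `rootGroupId` are assumed:

    val finder = BestFinder(cbo, memoTable.newState())
    val best = finder.bestOf(rootGroupId)
    // The cheapest known path, materialized back into a plan.
    val bestPlan = best.path().cboPath.plan()
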
diff --git a/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/best/GroupBasedBestFinder.scala b/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/best/GroupBasedBestFinder.scala
new file mode 100644
index 000000000000..1666d93fbd32
--- /dev/null
+++ b/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/best/GroupBasedBestFinder.scala
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.glutenproject.cbo.best
+
+import io.glutenproject.cbo._
+import io.glutenproject.cbo.Best.{BestNotFoundException, KnownCostPath}
+import io.glutenproject.cbo.best.BestFinder.KnownCostGroup
+import io.glutenproject.cbo.dp.{DpGroupAlgo, DpGroupAlgoDef}
+import io.glutenproject.cbo.memo.MemoState
+import io.glutenproject.cbo.path.{CboPath, PathKeySet}
+
+// Each sub-path of the best path is considered optimal in its own group.
+private class GroupBasedBestFinder[T <: AnyRef](
+ cbo: Cbo[T],
+ memoState: MemoState[T],
+ adjustment: DpGroupAlgo.Adjustment[T])
+ extends BestFinder[T] {
+ import GroupBasedBestFinder._
+
+ private val allGroups = memoState.allGroups()
+
+ override def bestOf(groupId: Int): Best[T] = {
+ val group = allGroups(groupId)
+ val groupToCosts = fillBests(group)
+ if (!groupToCosts.contains(groupId)) {
+ throw BestNotFoundException(
+ s"Best path not found. Memo state (Graphviz): \n" +
+ s"${memoState.formatGraphvizWithoutBest(groupId)}")
+ }
+ BestFinder.newBest(cbo, allGroups, group, groupToCosts)
+ }
+
+ private def fillBests(group: CboGroup[T]): Map[Int, KnownCostGroup[T]] = {
+ val algoDef = new AlgoDef(cbo, memoState)
+ val solution = DpGroupAlgo.resolve(memoState, algoDef, adjustment, group)
+ val bests = allGroups.flatMap {
+ group =>
+ if (solution.isYSolved(group)) {
+ solution.solutionOfY(group).flatMap(kcg => Some(group.id() -> kcg))
+ } else {
+ None
+ }
+ }.toMap
+ bests
+ }
+}
+
+private object GroupBasedBestFinder {
+ private[best] def algoDef[T <: AnyRef](cbo: Cbo[T], memoState: MemoState[T])
+ : DpGroupAlgoDef[T, Option[KnownCostPath[T]], Option[KnownCostGroup[T]]] = {
+ new AlgoDef(cbo, memoState)
+ }
+
+ private class AlgoDef[T <: AnyRef](cbo: Cbo[T], memoState: MemoState[T])
+ extends DpGroupAlgoDef[T, Option[KnownCostPath[T]], Option[KnownCostGroup[T]]] {
+ private val allGroups = memoState.allGroups()
+ private val costComparator = cbo.costModel.costComparator()
+
+ override def solveNode(
+ ign: InGroupNode[T],
+ childrenGroupsOutput: CboGroup[T] => Option[KnownCostGroup[T]])
+ : Option[KnownCostPath[T]] = {
+ val can = ign.can
+ if (can.isLeaf()) {
+ val path = CboPath.one(cbo, PathKeySet.trivial, allGroups, can)
+ return Some(KnownCostPath(cbo, path))
+ }
+ val childrenGroups = can.getChildrenGroups(allGroups).map(gn => allGroups(gn.groupId()))
+ val maybeBestChildrenPaths: Seq[Option[CboPath[T]]] = childrenGroups.map {
+ childGroup => childrenGroupsOutput(childGroup).map(kcg => kcg.best().cboPath)
+ }
+ if (maybeBestChildrenPaths.exists(_.isEmpty)) {
+ // Node should only be solved when all children outputs exist.
+ return None
+ }
+ val bestChildrenPaths = maybeBestChildrenPaths.map(_.get)
+ Some(KnownCostPath(cbo, path.CboPath(cbo, can, bestChildrenPaths).get))
+ }
+
+ override def solveGroup(
+ group: CboGroup[T],
+ nodesOutput: InGroupNode[T] => Option[KnownCostPath[T]]): Option[KnownCostGroup[T]] = {
+ val nodes = group.nodes(memoState)
+ // Allow unsolved children nodes while solving group.
+ val flatNodesOutput =
+ nodes.flatMap(n => nodesOutput(InGroupNode(group.id(), n)).map(kcp => n -> kcp)).toMap
+
+ if (flatNodesOutput.isEmpty) {
+ return None
+ }
+ val bestPath = flatNodesOutput.values.reduce {
+ (left, right) =>
+ Ordering
+ .by((cp: KnownCostPath[T]) => cp.cost)(costComparator)
+ .min(left, right)
+ }
+ Some(KnownCostGroup(flatNodesOutput, bestPath.cboPath.node().self().asCanonical()))
+ }
+
+ override def solveNodeOnCycle(node: InGroupNode[T]): Option[KnownCostPath[T]] =
+ None
+
+ override def solveGroupOnCycle(cluster: CboGroup[T]): Option[KnownCostGroup[T]] = {
+ None
+ }
+ }
+}
diff --git a/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/dp/DpClusterAlgo.scala b/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/dp/DpClusterAlgo.scala
new file mode 100644
index 000000000000..fa4d3d06c98a
--- /dev/null
+++ b/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/dp/DpClusterAlgo.scala
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.glutenproject.cbo.dp
+
+import io.glutenproject.cbo.{CboClusterKey, InClusterNode}
+import io.glutenproject.cbo.dp.DpZipperAlgo.Solution
+import io.glutenproject.cbo.memo.MemoTable
+
+// Dynamic programming algorithm that solves a problem against a single CBO cluster by
+// breaking it down into sub-problems for sub-clusters.
+//
+// FIXME: Code is very similar to DpGroupAlgo.
+trait DpClusterAlgoDef[T <: AnyRef, NodeOutput <: AnyRef, ClusterOutput <: AnyRef] {
+ def solveNode(
+ node: InClusterNode[T],
+ childrenClustersOutput: CboClusterKey => ClusterOutput): NodeOutput
+ def solveCluster(
+ cluster: CboClusterKey,
+ nodesOutput: InClusterNode[T] => NodeOutput): ClusterOutput
+
+ def solveNodeOnCycle(node: InClusterNode[T]): NodeOutput
+ def solveClusterOnCycle(cluster: CboClusterKey): ClusterOutput
+}
+
+object DpClusterAlgo {
+
+ trait Adjustment[T <: AnyRef] extends DpZipperAlgo.Adjustment[InClusterNode[T], CboClusterKey]
+
+ object Adjustment {
+ private class None[T <: AnyRef] extends Adjustment[T] {
+ override def exploreChildX(
+ panel: DpZipperAlgo.Adjustment.Panel[InClusterNode[T], CboClusterKey],
+ x: InClusterNode[T]): Unit = {}
+ override def exploreParentY(
+ panel: DpZipperAlgo.Adjustment.Panel[InClusterNode[T], CboClusterKey],
+ y: CboClusterKey): Unit = {}
+ override def exploreChildY(
+ panel: DpZipperAlgo.Adjustment.Panel[InClusterNode[T], CboClusterKey],
+ y: CboClusterKey): Unit = {}
+ override def exploreParentX(
+ panel: DpZipperAlgo.Adjustment.Panel[InClusterNode[T], CboClusterKey],
+ x: InClusterNode[T]): Unit = {}
+ }
+
+ def none[T <: AnyRef](): Adjustment[T] = new None[T]()
+ }
+
+ def resolve[T <: AnyRef, NodeOutput <: AnyRef, ClusterOutput <: AnyRef](
+ memoTable: MemoTable[T],
+ groupAlgoDef: DpClusterAlgoDef[T, NodeOutput, ClusterOutput],
+ adjustment: Adjustment[T],
+ cluster: CboClusterKey)
+ : Solution[InClusterNode[T], CboClusterKey, NodeOutput, ClusterOutput] = {
+ DpZipperAlgo.resolve(new ZipperAlgoDefImpl(memoTable, groupAlgoDef), adjustment, cluster)
+ }
+
+ private class ZipperAlgoDefImpl[T <: AnyRef, NodeOutput <: AnyRef, ClusterOutput <: AnyRef](
+ memoTable: MemoTable[T],
+ clusterAlgoDef: DpClusterAlgoDef[T, NodeOutput, ClusterOutput])
+ extends DpZipperAlgoDef[InClusterNode[T], CboClusterKey, NodeOutput, ClusterOutput] {
+ override def idOfX(x: InClusterNode[T]): Any = {
+ x
+ }
+
+ override def idOfY(y: CboClusterKey): Any = {
+ y
+ }
+
+ override def browseX(x: InClusterNode[T]): Iterable[CboClusterKey] = {
+ val allGroups = memoTable.allGroups()
+ x.can
+ .getChildrenGroups(allGroups)
+ .map(gn => allGroups(gn.groupId()).clusterKey())
+ }
+
+ override def browseY(y: CboClusterKey): Iterable[InClusterNode[T]] = {
+ memoTable.getCluster(y).nodes().map(n => InClusterNode(y, n))
+ }
+
+ override def solveX(x: InClusterNode[T], yOutput: CboClusterKey => ClusterOutput): NodeOutput =
+ clusterAlgoDef.solveNode(x, yOutput)
+
+ override def solveY(y: CboClusterKey, xOutput: InClusterNode[T] => NodeOutput): ClusterOutput =
+ clusterAlgoDef.solveCluster(y, xOutput)
+
+ override def solveXOnCycle(x: InClusterNode[T]): NodeOutput = clusterAlgoDef.solveNodeOnCycle(x)
+
+ override def solveYOnCycle(y: CboClusterKey): ClusterOutput =
+ clusterAlgoDef.solveClusterOnCycle(y)
+ }
+}
diff --git a/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/dp/DpGroupAlgo.scala b/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/dp/DpGroupAlgo.scala
new file mode 100644
index 000000000000..e85bc1c63d3c
--- /dev/null
+++ b/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/dp/DpGroupAlgo.scala
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.glutenproject.cbo.dp
+
+import io.glutenproject.cbo.{CboGroup, InGroupNode}
+import io.glutenproject.cbo.dp.DpZipperAlgo.Solution
+import io.glutenproject.cbo.memo.MemoState
+
+// Dynamic programming algorithm that solves a problem against a single CBO group by
+// breaking it down into sub-problems for sub-groups.
+trait DpGroupAlgoDef[T <: AnyRef, NodeOutput <: AnyRef, GroupOutput <: AnyRef] {
+ def solveNode(node: InGroupNode[T], childrenGroupsOutput: CboGroup[T] => GroupOutput): NodeOutput
+ def solveGroup(group: CboGroup[T], nodesOutput: InGroupNode[T] => NodeOutput): GroupOutput
+
+ def solveNodeOnCycle(node: InGroupNode[T]): NodeOutput
+ def solveGroupOnCycle(cluster: CboGroup[T]): GroupOutput
+}
+
+object DpGroupAlgo {
+
+ trait Adjustment[T <: AnyRef] extends DpZipperAlgo.Adjustment[InGroupNode[T], CboGroup[T]]
+
+ object Adjustment {
+ private class None[T <: AnyRef] extends Adjustment[T] {
+ override def exploreChildX(
+ panel: DpZipperAlgo.Adjustment.Panel[InGroupNode[T], CboGroup[T]],
+ x: InGroupNode[T]): Unit = {}
+ override def exploreParentY(
+ panel: DpZipperAlgo.Adjustment.Panel[InGroupNode[T], CboGroup[T]],
+ y: CboGroup[T]): Unit = {}
+ override def exploreChildY(
+ panel: DpZipperAlgo.Adjustment.Panel[InGroupNode[T], CboGroup[T]],
+ y: CboGroup[T]): Unit = {}
+ override def exploreParentX(
+ panel: DpZipperAlgo.Adjustment.Panel[InGroupNode[T], CboGroup[T]],
+ x: InGroupNode[T]): Unit = {}
+ }
+
+ def none[T <: AnyRef](): Adjustment[T] = new None[T]()
+ }
+
+ def resolve[T <: AnyRef, NodeOutput <: AnyRef, GroupOutput <: AnyRef](
+ memoState: MemoState[T],
+ groupAlgoDef: DpGroupAlgoDef[T, NodeOutput, GroupOutput],
+ adjustment: Adjustment[T],
+ group: CboGroup[T]): Solution[InGroupNode[T], CboGroup[T], NodeOutput, GroupOutput] = {
+ DpZipperAlgo.resolve(new ZipperAlgoDefImpl(memoState, groupAlgoDef), adjustment, group)
+ }
+
+ private class ZipperAlgoDefImpl[T <: AnyRef, NodeOutput <: AnyRef, GroupOutput <: AnyRef](
+ memoState: MemoState[T],
+ groupAlgoDef: DpGroupAlgoDef[T, NodeOutput, GroupOutput])
+ extends DpZipperAlgoDef[InGroupNode[T], CboGroup[T], NodeOutput, GroupOutput] {
+ override def idOfX(x: InGroupNode[T]): Any = {
+ x
+ }
+
+ override def idOfY(y: CboGroup[T]): Any = {
+ y.id()
+ }
+
+ override def browseX(x: InGroupNode[T]): Iterable[CboGroup[T]] = {
+ val allGroups = memoState.allGroups()
+ x.can.getChildrenGroups(allGroups).map(gn => allGroups(gn.groupId()))
+ }
+
+ override def browseY(y: CboGroup[T]): Iterable[InGroupNode[T]] = {
+ y.nodes(memoState).map(can => InGroupNode(y.id(), can))
+ }
+
+ override def solveX(x: InGroupNode[T], yOutput: CboGroup[T] => GroupOutput): NodeOutput =
+ groupAlgoDef.solveNode(x, yOutput)
+
+ override def solveY(y: CboGroup[T], xOutput: InGroupNode[T] => NodeOutput): GroupOutput =
+ groupAlgoDef.solveGroup(y, xOutput)
+
+ override def solveXOnCycle(x: InGroupNode[T]): NodeOutput = groupAlgoDef.solveNodeOnCycle(x)
+
+ override def solveYOnCycle(y: CboGroup[T]): GroupOutput = groupAlgoDef.solveGroupOnCycle(y)
+ }
+}
diff --git a/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/dp/DpPlanner.scala b/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/dp/DpPlanner.scala
new file mode 100644
index 000000000000..86eba84ee9cf
--- /dev/null
+++ b/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/dp/DpPlanner.scala
@@ -0,0 +1,200 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.glutenproject.cbo.dp
+
+import io.glutenproject.cbo._
+import io.glutenproject.cbo.Best.KnownCostPath
+import io.glutenproject.cbo.best.BestFinder
+import io.glutenproject.cbo.dp.DpZipperAlgo.Adjustment.Panel
+import io.glutenproject.cbo.memo.{Memo, MemoTable}
+import io.glutenproject.cbo.path.{CboPath, PathFinder}
+import io.glutenproject.cbo.property.PropertySet
+import io.glutenproject.cbo.rule.{EnforcerRuleSet, RuleApplier, Shape}
+
+// TODO: Branch and bound pruning.
+private class DpPlanner[T <: AnyRef] private (
+ cbo: Cbo[T],
+ altConstraintSets: Seq[PropertySet[T]],
+ constraintSet: PropertySet[T],
+ plan: T)
+ extends CboPlanner[T] {
+ import DpPlanner._
+
+ private val memo = Memo.unsafe(cbo)
+ private val rules = cbo.ruleFactory.create().map(rule => RuleApplier(cbo, memo, rule))
+ private val enforcerRuleSet = EnforcerRuleSet[T](cbo, memo)
+
+ private lazy val rootGroupId: Int = {
+ memo.memorize(plan, constraintSet).id()
+ }
+
+ private lazy val best: (Best[T], KnownCostPath[T]) = {
+ altConstraintSets.foreach(propSet => memo.memorize(plan, propSet))
+ val groupId = rootGroupId
+ val memoTable = memo.table()
+ val best = findBest(memoTable, groupId)
+ (best, best.path())
+ }
+
+ override def plan(): T = {
+ best._2.cboPath.plan()
+ }
+
+ override def newState(): PlannerState[T] = {
+ val foundBest = best._1
+ PlannerState(cbo, memo.newState(), rootGroupId, foundBest)
+ }
+
+ private def findBest(memoTable: MemoTable[T], groupId: Int): Best[T] = {
+ val cKey = memoTable.allGroups()(groupId).clusterKey()
+ val algoDef = new DpExploreAlgoDef[T]
+ val adjustment = new ExploreAdjustment(cbo, memoTable, rules, enforcerRuleSet)
+ DpClusterAlgo.resolve(memoTable, algoDef, adjustment, cKey)
+ val finder = BestFinder(cbo, memoTable.newState())
+ finder.bestOf(groupId)
+ }
+}
+
+object DpPlanner {
+ def apply[T <: AnyRef](
+ cbo: Cbo[T],
+ altConstraintSets: Seq[PropertySet[T]],
+ constraintSet: PropertySet[T],
+ plan: T): CboPlanner[T] = {
+ new DpPlanner(cbo, altConstraintSets, constraintSet, plan)
+ }
+
+ // Visited flag. The DP outputs are mere placeholders: the actual exploration work is
+ // done by the adjustment's rule applications during the traversal.
+ sealed private trait SolvedFlag
+ private case object Solved extends SolvedFlag
+
+ private class DpExploreAlgoDef[T <: AnyRef] extends DpClusterAlgoDef[T, SolvedFlag, SolvedFlag] {
+ override def solveNode(
+ node: InClusterNode[T],
+ childrenClustersOutput: CboClusterKey => SolvedFlag): SolvedFlag = Solved
+ override def solveCluster(
+ group: CboClusterKey,
+ nodesOutput: InClusterNode[T] => SolvedFlag): SolvedFlag = Solved
+ override def solveNodeOnCycle(node: InClusterNode[T]): SolvedFlag = Solved
+ override def solveClusterOnCycle(cluster: CboClusterKey): SolvedFlag = Solved
+ }
+
+ private class ExploreAdjustment[T <: AnyRef](
+ cbo: Cbo[T],
+ memoTable: MemoTable[T],
+ rules: Seq[RuleApplier[T]],
+ enforcerRuleSet: EnforcerRuleSet[T])
+ extends DpClusterAlgo.Adjustment[T] {
+ private val allGroups = memoTable.allGroups()
+ private val clusterLookup = cKey => memoTable.getCluster(cKey)
+
+ override def exploreChildX(
+ panel: Panel[InClusterNode[T], CboClusterKey],
+ x: InClusterNode[T]): Unit = {}
+ override def exploreChildY(
+ panel: Panel[InClusterNode[T], CboClusterKey],
+ y: CboClusterKey): Unit = {}
+ override def exploreParentX(
+ panel: Panel[InClusterNode[T], CboClusterKey],
+ x: InClusterNode[T]): Unit = {}
+
+ override def exploreParentY(
+ panel: Panel[InClusterNode[T], CboClusterKey],
+ cKey: CboClusterKey): Unit = {
+ memoTable.doExhaustively {
+ applyEnforcerRules(panel, cKey)
+ applyRules(panel, cKey)
+ }
+ }
+
+ private def applyRules(
+ panel: Panel[InClusterNode[T], CboClusterKey],
+ cKey: CboClusterKey): Unit = {
+ if (rules.isEmpty) {
+ return
+ }
+ val cluster = clusterLookup(cKey)
+ cluster.nodes().foreach {
+ node =>
+ val shapes = rules.map(_.shape())
+ findPaths(node, shapes)(path => rules.foreach(rule => applyRule(panel, cKey, rule, path)))
+ }
+ }
+
+ private def applyEnforcerRules(
+ panel: Panel[InClusterNode[T], CboClusterKey],
+ cKey: CboClusterKey): Unit = {
+ val cluster = clusterLookup(cKey)
+ cKey.propSets(memoTable).foreach {
+ constraintSet =>
+ val enforcerRules = enforcerRuleSet.rulesOf(constraintSet)
+ if (enforcerRules.nonEmpty) {
+ val shapes = enforcerRules.map(_.shape())
+ cluster.nodes().foreach {
+ node =>
+ findPaths(node, shapes)(
+ path => enforcerRules.foreach(rule => applyRule(panel, cKey, rule, path)))
+ }
+ }
+ }
+ }
+
+ private def findPaths(canonical: CanonicalNode[T], shapes: Seq[Shape[T]])(
+ onFound: CboPath[T] => Unit): Unit = {
+ val finder = shapes
+ .foldLeft(
+ PathFinder
+ .builder(cbo, memoTable)) {
+ case (builder, shape) =>
+ builder.output(shape.wizard())
+ }
+ .build()
+ finder.find(canonical).foreach(path => onFound(path))
+ }
+
+ private def applyRule(
+ panel: Panel[InClusterNode[T], CboClusterKey],
+ thisClusterKey: CboClusterKey,
+ rule: RuleApplier[T],
+ path: CboPath[T]): Unit = {
+ val probe = memoTable.probe()
+ rule.apply(path)
+ val diff = probe.toDiff()
+ val changedClusters = diff.changedClusters()
+ if (changedClusters.isEmpty) {
+ return
+ }
+
+ // One or more clusters changed. If a changed cluster is not the current cluster, we
+ // should withdraw its DP results to trigger re-computation, since a changed cluster
+ // (which may have created new groups or added new nodes) could expand the search
+ // space again.
+
+ changedClusters.foreach {
+ case cKey if cKey == thisClusterKey =>
+ // This cluster changed, but it is the one currently being solved, so we don't
+ // have to invalidate anything.
+ case cKey =>
+ // Changes happened on another cluster. Invalidate that cluster's solution to
+ // trigger re-computation.
+ panel.invalidateYSolution(cKey)
+ }
+ }
+ }
+
+ private object ExploreAdjustment {}
+}
diff --git a/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/dp/DpZipperAlgo.scala b/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/dp/DpZipperAlgo.scala
new file mode 100644
index 000000000000..b1b0cd23111f
--- /dev/null
+++ b/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/dp/DpZipperAlgo.scala
@@ -0,0 +1,656 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.glutenproject.cbo.dp
+
+import io.glutenproject.cbo.util.CycleDetector
+
+import scala.collection.mutable
+
+/**
+ * Dynamic programming algorithm to solve problems that can be broken down into sub-problems
+ * on two distinct element types.
+ *
+ * The element types here are X and Y. Programming starts from Y, then alternately traverses
+ * down through X, Y, X..., until reaching a leaf.
+ *
+ * Two major issues are handled by the base algo internally:
+ *
+ * 1. Cycle exclusion:
+ *
+ * The algo withdraws the recursive call once a cycle is found. Cycles are detected by
+ * comparing the element IDs provided by DpZipperAlgoDef#idOfX and DpZipperAlgoDef#idOfY.
+ * When a cycle is found, the element that just created the cycle (assume it is A) is forced
+ * to return a CycleMemory(A), and the nodes on the whole recursive tree then return their
+ * results together with CycleMemory(A). This means their results are incomplete, as the
+ * cyclic paths are excluded. Whether a path is "cyclic" is subjective: a child path can be
+ * cyclic for some parent nodes but not for others. So the incomplete results are not
+ * memorized to the solution builder.
+ *
+ * However, once CycleMemory(A) is returned back to element A, A can be safely removed from
+ * the cycle memory. This means the cycle has been successfully enclosed, and as the call
+ * tree continues returning there will be no cycles. The subsequent results can then be
+ * cached to the solution builder.
+ *
+ * The above is a simplified example. The real cycle memory consists of a set of all cyclic
+ * nodes. Only when the set gets cleared can the current call be considered cycle-free.
+ *
+ * 2. Branch invalidation:
+ *
+ * Since an algo implementation may need to re-compute already-solved elements, an
+ * #invalidate API is provided in the adjustment panel.
+ *
+ * The invalidation is implemented this way: each element logs its parent as its
+ * back-dependency once it gets solved. For example, say A has 3 children (B, C, D). After B
+ * and C are solved, A is added to B's and C's back-dependency lists, so the solution
+ * builder is aware that A depends on B, as well as that A depends on C. After this
+ * operation, the algorithm calls the user-defined adjustment to allow the caller to
+ * invalidate some elements. If B gets invalidated, the algo removes the saved solution of
+ * B, then finds all back-dependencies of B and removes their saved results (if they exist),
+ * then finds all back-dependencies of those back-dependencies, and so on. In this case we
+ * have just one layer of recursion, so only the relation (B -> A) gets removed. After A
+ * successfully solves D, the algo backtracks over all the already-solved children (B, C, D)
+ * to see whether the previously registered back-dependencies (B -> A, C -> A, D -> A) are
+ * still alive. Since only (C -> A, D -> A) remain, the algo tries to recompute B. If C, D,
+ * or some of their children get invalidated again during the procedure, the algo keeps
+ * looping until all the children are successfully solved and all the back-dependencies
+ * survive.
+ *
+ * One possible corner case is, for example, when B has just been solved and is being
+ * adjusted, during which one of B's sub-trees gets invalidated. Since we apply the
+ * adjustment right after the back-dependency (B -> A) is established, the algo can still
+ * recognize (B -> A)'s removal and recompute B. So this corner case is handled correctly,
+ * too. The above is again a simplified example; the real program handles invalidation for
+ * any depth of recursion.
+ */
+trait DpZipperAlgoDef[X <: AnyRef, Y <: AnyRef, XOutput <: AnyRef, YOutput <: AnyRef] {
+ def idOfX(x: X): Any
+ def idOfY(y: Y): Any
+
+ def browseX(x: X): Iterable[Y]
+ def browseY(y: Y): Iterable[X]
+
+ def solveX(x: X, yOutput: Y => YOutput): XOutput
+ def solveY(y: Y, xOutput: X => XOutput): YOutput
+
+ def solveXOnCycle(x: X): XOutput
+ def solveYOnCycle(y: Y): YOutput
+}
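+
+// A hedged usage sketch (illustrative only; "myAlgoDef" and "rootY" are assumed to come
+// from the caller's context):
+//
+//   val solution = DpZipperAlgo.resolve(myAlgoDef, DpZipperAlgo.Adjustment.none(), rootY)
+//   if (solution.isYSolved(rootY)) {
+//     val rootOutput = solution.solutionOfY(rootY)
+//   }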
+
+object DpZipperAlgo {
+ def resolve[X <: AnyRef, Y <: AnyRef, XOutput <: AnyRef, YOutput <: AnyRef](
+ algoDef: DpZipperAlgoDef[X, Y, XOutput, YOutput],
+ adjustment: Adjustment[X, Y],
+ root: Y): Solution[X, Y, XOutput, YOutput] = {
+ val algo = new DpZipperAlgoResolver(algoDef, adjustment)
+ algo.resolve(root)
+ }
+
+ trait Adjustment[X <: AnyRef, Y <: AnyRef] {
+ import Adjustment._
+ def exploreChildX(panel: Panel[X, Y], x: X): Unit
+ def exploreParentY(panel: Panel[X, Y], y: Y): Unit
+ def exploreChildY(panel: Panel[X, Y], y: Y): Unit
+ def exploreParentX(panel: Panel[X, Y], x: X): Unit
+ }
+
+ object Adjustment {
+ trait Panel[X <: AnyRef, Y <: AnyRef] {
+ def invalidateXSolution(x: X): Unit
+ def invalidateYSolution(y: Y): Unit
+ }
+
+ object Panel {
+ def apply[X <: AnyRef, Y <: AnyRef, XOutput <: AnyRef, YOutput <: AnyRef](
+ sBuilder: Solution.Builder[X, Y, XOutput, YOutput]): Panel[X, Y] =
+ new PanelImpl[X, Y, XOutput, YOutput](sBuilder)
+
+ private class PanelImpl[X <: AnyRef, Y <: AnyRef, XOutput <: AnyRef, YOutput <: AnyRef](
+ sBuilder: Solution.Builder[X, Y, XOutput, YOutput])
+ extends Panel[X, Y] {
+ override def invalidateXSolution(x: X): Unit = {
+ if (!sBuilder.isXResolved(x)) {
+ return
+ }
+ sBuilder.invalidateXSolution(x)
+ }
+
+ override def invalidateYSolution(y: Y): Unit = {
+ if (!sBuilder.isYResolved(y)) {
+ return
+ }
+ sBuilder.invalidateYSolution(y)
+ }
+ }
+ }
+
+ private class None[X <: AnyRef, Y <: AnyRef] extends Adjustment[X, Y] {
+ override def exploreChildX(panel: Panel[X, Y], x: X): Unit = {}
+ override def exploreParentY(panel: Panel[X, Y], y: Y): Unit = {}
+ override def exploreChildY(panel: Panel[X, Y], y: Y): Unit = {}
+ override def exploreParentX(panel: Panel[X, Y], x: X): Unit = {}
+ }
+ def none[X <: AnyRef, Y <: AnyRef](): Adjustment[X, Y] = new None()
+ }
+
+ private class DpZipperAlgoResolver[
+ X <: AnyRef,
+ Y <: AnyRef,
+ XOutput <: AnyRef,
+ YOutput <: AnyRef](
+ algoDef: DpZipperAlgoDef[X, Y, XOutput, YOutput],
+ adjustment: Adjustment[X, Y]) {
+ import DpZipperAlgoResolver._
+
+ private val sBuilder: Solution.Builder[X, Y, XOutput, YOutput] =
+ Solution.builder[X, Y, XOutput, YOutput](algoDef)
+ private val adjustmentPanel = Adjustment.Panel[X, Y, XOutput, YOutput](sBuilder)
+
+ def resolve(root: Y): Solution[X, Y, XOutput, YOutput] = {
+ val xCycleDetector =
+ CycleDetector[X]((one, other) => algoDef.idOfX(one) == algoDef.idOfX(other))
+ val yCycleDetector =
+ CycleDetector[Y]((one, other) => algoDef.idOfY(one) == algoDef.idOfY(other))
+ solveYRec(root, xCycleDetector, yCycleDetector)
+ sBuilder.build()
+ }
+
+ private def solveYRec(
+ thisY: Y,
+ xCycleDetector: CycleDetector[X],
+ yCycleDetector: CycleDetector[Y]): CycleAwareYOutput[X, Y, XOutput, YOutput] = {
+ if (yCycleDetector.contains(thisY)) {
+ return CycleAwareYOutput(algoDef.solveYOnCycle(thisY), CycleMemory(algoDef).addY(thisY))
+ }
+ val newYCycleDetector = yCycleDetector.append(thisY)
+ if (sBuilder.isYResolved(thisY)) {
+ // The same Y was already solved by a previous traversal before we reached this
+ // position.
+ return CycleAwareYOutput(sBuilder.getYSolution(thisY), CycleMemory(algoDef))
+ }
+
+ val cyclicXs: mutable.Set[XKey[X, Y, XOutput, YOutput]] = mutable.Set()
+ val cyclicYs: mutable.Set[YKey[X, Y, XOutput, YOutput]] = mutable.Set()
+
+ val xSolutions: mutable.Map[XKey[X, Y, XOutput, YOutput], XOutput] = mutable.Map()
+
+ def loop(): Unit = {
+ while (true) {
+ val xKeys: Set[XKey[X, Y, XOutput, YOutput]] =
+ algoDef.browseY(thisY).map(algoDef.keyOfX(_)).toSet
+
+ val xCount = xKeys.size
+ if (xCount == xSolutions.size) {
+ // We have gathered all children solutions.
+ return
+ }
+
+ xKeys.filterNot(xKey => xSolutions.contains(xKey)).foreach {
+ childXKey =>
+ val xOutputs = solveXRec(childXKey.x, xCycleDetector, newYCycleDetector)
+ val cm = xOutputs.cycleMemory()
+ cyclicXs ++= cm.cyclicXs
+ cyclicYs ++= cm.cyclicYs
+ sBuilder.addYAsBackDependencyOfX(thisY, childXKey.x)
+ xSolutions += childXKey -> xOutputs.output()
+ // Try applying the adjustment to see whether the algo caller would like to add
+ // some Xs or to invalidate some of the registered solutions.
+ adjustment.exploreChildX(adjustmentPanel, childXKey.x)
+ }
+ adjustment.exploreParentY(adjustmentPanel, thisY)
+ // If an adjustment (this one or a child's) just invalidated one or more of this
+ // element's children's solutions, the children's keys would have been removed from
+ // the back-dependency list. We test for that here to trigger re-computation for any
+ // children that did get invalidated.
+ xSolutions.keySet.foreach {
+ childXKey =>
+ if (!sBuilder.yHasDependency(thisY, childXKey.x)) {
+ xSolutions -= childXKey
+ }
+ }
+ }
+ }
+
+ loop()
+
+ // Remove this element from cycle memory, if it's in it.
+ cyclicYs -= algoDef.keyOfY(thisY)
+
+ val cycleMemory = CycleMemory(algoDef, cyclicXs.toSet, cyclicYs.toSet)
+
+ val ySolution =
+ algoDef.solveY(thisY, x => xSolutions(XKey(algoDef, x)))
+
+ val cycleAware = CycleAwareYOutput(ySolution, cycleMemory)
+ if (!cycleMemory.isOnCycle()) {
+ // We only cache the solution if this element is not on any cycles.
+ sBuilder.addYSolution(thisY, ySolution)
+ }
+ cycleAware
+ }
+
+ private def solveXRec(
+ thisX: X,
+ xCycleDetector: CycleDetector[X],
+ yCycleDetector: CycleDetector[Y]): CycleAwareXOutput[X, Y, XOutput, YOutput] = {
+ if (xCycleDetector.contains(thisX)) {
+ return CycleAwareXOutput(algoDef.solveXOnCycle(thisX), CycleMemory(algoDef).addX(thisX))
+ }
+ val newXCycleDetector = xCycleDetector.append(thisX)
+ if (sBuilder.isXResolved(thisX)) {
+ // The same X was already solved by a previous traversal before we reached this
+ // position.
+ return CycleAwareXOutput(sBuilder.getXSolution(thisX), CycleMemory(algoDef))
+ }
+
+ val cyclicXs: mutable.Set[XKey[X, Y, XOutput, YOutput]] = mutable.Set()
+ val cyclicYs: mutable.Set[YKey[X, Y, XOutput, YOutput]] = mutable.Set()
+
+ val ySolutions: mutable.Map[YKey[X, Y, XOutput, YOutput], YOutput] = mutable.Map()
+
+ def loop(): Unit = {
+ while (true) {
+ val yKeys: Set[YKey[X, Y, XOutput, YOutput]] =
+ algoDef.browseX(thisX).map(algoDef.keyOfY(_)).toSet
+
+ val yCount = yKeys.size
+ if (yCount == ySolutions.size) {
+ // We have gathered all children solutions.
+ return
+ }
+
+ yKeys.filterNot(yKey => ySolutions.contains(yKey)).foreach {
+ childYKey =>
+ val yOutputs = solveYRec(childYKey.y, newXCycleDetector, yCycleDetector)
+ val cm = yOutputs.cycleMemory()
+ cyclicXs ++= cm.cyclicXs
+ cyclicYs ++= cm.cyclicYs
+ sBuilder.addXAsBackDependencyOfY(thisX, childYKey.y)
+ ySolutions += childYKey -> yOutputs.output()
+ // Try applying the adjustment to see whether the algo caller would like to add
+ // some Ys or to invalidate some of the registered solutions.
+ adjustment.exploreChildY(adjustmentPanel, childYKey.y)
+ }
+ adjustment.exploreParentX(adjustmentPanel, thisX)
+ // If an adjustment (this one or a child's) just invalidated one or more of this
+ // element's children's solutions, the children's keys would have been removed from
+ // the back-dependency list. We test for that here to trigger re-computation for any
+ // children that did get invalidated.
+ ySolutions.keySet.foreach {
+ childYKey =>
+ if (!sBuilder.xHasDependency(thisX, childYKey.y)) {
+ ySolutions -= childYKey
+ }
+ }
+ }
+ }
+
+ loop()
+
+ // Remove this element from cycle memory, if it's in it.
+ cyclicXs -= algoDef.keyOfX(thisX)
+
+ val cycleMemory = CycleMemory(algoDef, cyclicXs.toSet, cyclicYs.toSet)
+
+ val xSolution =
+ algoDef.solveX(thisX, y => ySolutions(YKey(algoDef, y)))
+
+ val cycleAware = CycleAwareXOutput(xSolution, cycleMemory)
+ if (!cycleMemory.isOnCycle()) {
+ // We only cache the solution if this element is not on any cycles.
+ sBuilder.addXSolution(thisX, xSolution)
+ }
+ cycleAware
+ }
+
+ }
+
+ private object DpZipperAlgoResolver {
+ private trait CycleAwareXOutput[
+ X <: AnyRef,
+ Y <: AnyRef,
+ XOutput <: AnyRef,
+ YOutput <: AnyRef] {
+ def output(): XOutput
+ def cycleMemory(): CycleMemory[X, Y, XOutput, YOutput]
+ }
+
+ private object CycleAwareXOutput {
+ def apply[X <: AnyRef, Y <: AnyRef, XOutput <: AnyRef, YOutput <: AnyRef](
+ output: XOutput,
+ cycleMemory: CycleMemory[X, Y, XOutput, YOutput])
+ : CycleAwareXOutput[X, Y, XOutput, YOutput] = {
+ new CycleAwareXOutputImpl(output, cycleMemory)
+ }
+
+ private class CycleAwareXOutputImpl[
+ X <: AnyRef,
+ Y <: AnyRef,
+ XOutput <: AnyRef,
+ YOutput <: AnyRef](
+ override val output: XOutput,
+ override val cycleMemory: CycleMemory[X, Y, XOutput, YOutput])
+ extends CycleAwareXOutput[X, Y, XOutput, YOutput]
+ }
+
+ private trait CycleAwareYOutput[
+ X <: AnyRef,
+ Y <: AnyRef,
+ XOutput <: AnyRef,
+ YOutput <: AnyRef] {
+ def output(): YOutput
+ def cycleMemory(): CycleMemory[X, Y, XOutput, YOutput]
+ }
+
+ private object CycleAwareYOutput {
+ def apply[X <: AnyRef, Y <: AnyRef, XOutput <: AnyRef, YOutput <: AnyRef](
+ output: YOutput,
+ cycleMemory: CycleMemory[X, Y, XOutput, YOutput])
+ : CycleAwareYOutput[X, Y, XOutput, YOutput] = {
+ new CycleAwareYOutputImpl(output, cycleMemory)
+ }
+
+ private class CycleAwareYOutputImpl[
+ X <: AnyRef,
+ Y <: AnyRef,
+ XOutput <: AnyRef,
+ YOutput <: AnyRef](
+ override val output: YOutput,
+ override val cycleMemory: CycleMemory[X, Y, XOutput, YOutput])
+ extends CycleAwareYOutput[X, Y, XOutput, YOutput]
+ }
+
+ private trait CycleMemory[X <: AnyRef, Y <: AnyRef, XOutput <: AnyRef, YOutput <: AnyRef] {
+ def cyclicXs: Set[XKey[X, Y, XOutput, YOutput]]
+ def cyclicYs: Set[YKey[X, Y, XOutput, YOutput]]
+ def addX(x: X): CycleMemory[X, Y, XOutput, YOutput]
+ def addY(y: Y): CycleMemory[X, Y, XOutput, YOutput]
+ def removeX(x: X): CycleMemory[X, Y, XOutput, YOutput]
+ def removeY(y: Y): CycleMemory[X, Y, XOutput, YOutput]
+ def isOnCycle(): Boolean
+ }
+
+ private object CycleMemory {
+ def apply[X <: AnyRef, Y <: AnyRef, XOutput <: AnyRef, YOutput <: AnyRef](
+ algoDef: DpZipperAlgoDef[X, Y, XOutput, YOutput]): CycleMemory[X, Y, XOutput, YOutput] = {
+ new CycleMemoryImpl(algoDef, Set(), Set())
+ }
+
+ def apply[X <: AnyRef, Y <: AnyRef, XOutput <: AnyRef, YOutput <: AnyRef](
+ algoDef: DpZipperAlgoDef[X, Y, XOutput, YOutput],
+ cyclicXs: Set[XKey[X, Y, XOutput, YOutput]],
+ cyclicYs: Set[YKey[X, Y, XOutput, YOutput]]): CycleMemory[X, Y, XOutput, YOutput] = {
+ new CycleMemoryImpl(algoDef, cyclicXs, cyclicYs)
+ }
+
+ private class CycleMemoryImpl[X <: AnyRef, Y <: AnyRef, XOutput <: AnyRef, YOutput <: AnyRef](
+ algoDef: DpZipperAlgoDef[X, Y, XOutput, YOutput],
+ override val cyclicXs: Set[XKey[X, Y, XOutput, YOutput]],
+ override val cyclicYs: Set[YKey[X, Y, XOutput, YOutput]])
+ extends CycleMemory[X, Y, XOutput, YOutput] {
+ override def addX(x: X): CycleMemory[X, Y, XOutput, YOutput] = new CycleMemoryImpl(
+ algoDef,
+ cyclicXs + algoDef.keyOfX(x),
+ cyclicYs
+ )
+ override def addY(y: Y): CycleMemory[X, Y, XOutput, YOutput] = new CycleMemoryImpl(
+ algoDef,
+ cyclicXs,
+ cyclicYs + algoDef.keyOfY(y)
+ )
+ override def removeX(x: X): CycleMemory[X, Y, XOutput, YOutput] = new CycleMemoryImpl(
+ algoDef,
+ cyclicXs - algoDef.keyOfX(x),
+ cyclicYs
+ )
+ override def removeY(y: Y): CycleMemory[X, Y, XOutput, YOutput] = new CycleMemoryImpl(
+ algoDef,
+ cyclicXs,
+ cyclicYs - algoDef.keyOfY(y)
+ )
+ override def isOnCycle(): Boolean = cyclicXs.nonEmpty || cyclicYs.nonEmpty
+ }
+ }
+ }
+
+ trait Solution[X <: AnyRef, Y <: AnyRef, XOutput <: AnyRef, YOutput <: AnyRef] {
+ def isXSolved(x: X): Boolean
+ def isYSolved(y: Y): Boolean
+ def solutionOfX(x: X): XOutput
+ def solutionOfY(y: Y): YOutput
+ }
+
+ private object Solution {
+ private case class SolutionImpl[X <: AnyRef, Y <: AnyRef, XOutput <: AnyRef, YOutput <: AnyRef](
+ algoDef: DpZipperAlgoDef[X, Y, XOutput, YOutput],
+ xSolutions: Map[XKey[X, Y, XOutput, YOutput], XOutput],
+ ySolutions: Map[YKey[X, Y, XOutput, YOutput], YOutput])
+ extends Solution[X, Y, XOutput, YOutput] {
+ override def isXSolved(x: X): Boolean = xSolutions.contains(algoDef.keyOfX(x))
+ override def isYSolved(y: Y): Boolean = ySolutions.contains(algoDef.keyOfY(y))
+ override def solutionOfX(x: X): XOutput = xSolutions(algoDef.keyOfX(x))
+ override def solutionOfY(y: Y): YOutput = ySolutions(algoDef.keyOfY(y))
+ }
+
+ def builder[X <: AnyRef, Y <: AnyRef, XOutput <: AnyRef, YOutput <: AnyRef](
+ algoDef: DpZipperAlgoDef[X, Y, XOutput, YOutput]): Builder[X, Y, XOutput, YOutput] = {
+ Builder[X, Y, XOutput, YOutput](algoDef)
+ }
+
+ class Builder[X <: AnyRef, Y <: AnyRef, XOutput <: AnyRef, YOutput <: AnyRef] private (
+ algoDef: DpZipperAlgoDef[X, Y, XOutput, YOutput]) {
+
+ // Stores the persisted solved elements. A solution is stored here only if it doesn't
+ // pertain to any cycles.
+ private val xSolutions = mutable.Map[XKey[X, Y, XOutput, YOutput], XOutput]()
+ private val ySolutions = mutable.Map[YKey[X, Y, XOutput, YOutput], YOutput]()
+
+ private val xBackDependencies =
+ mutable.Map[XKey[X, Y, XOutput, YOutput], mutable.Set[YKey[X, Y, XOutput, YOutput]]]()
+ private val yBackDependencies =
+ mutable.Map[YKey[X, Y, XOutput, YOutput], mutable.Set[XKey[X, Y, XOutput, YOutput]]]()
+
+ def invalidateXSolution(x: X): Unit = {
+ val xKey = algoDef.keyOfX(x)
+ invalidateXSolution0(xKey)
+ }
+
+ private def invalidateXSolution0(xKey: XKey[X, Y, XOutput, YOutput]): Unit = {
+ assert(xSolutions.contains(xKey))
+ xSolutions -= xKey
+ if (!xBackDependencies.contains(xKey)) {
+ return
+ }
+ val backYs = xBackDependencies(xKey)
+ backYs.toList.foreach {
+ y =>
+ if (isYResolved0(y)) {
+ invalidateYSolution0(y)
+ }
+ backYs -= y
+ }
+ // Clear the x-key from the back-dependency table. This helps the algorithm drive
+ // re-computation after this x gets invalidated.
+ xBackDependencies -= xKey
+ }
+
+ def invalidateYSolution(y: Y): Unit = {
+ val yKey = algoDef.keyOfY(y)
+ invalidateYSolution0(yKey)
+ }
+
+ private def invalidateYSolution0(yKey: YKey[X, Y, XOutput, YOutput]): Unit = {
+ assert(ySolutions.contains(yKey))
+ ySolutions -= yKey
+ if (!yBackDependencies.contains(yKey)) {
+ return
+ }
+ val backXs = yBackDependencies(yKey)
+ backXs.toList.foreach {
+ x =>
+ if (isXResolved0(x)) {
+ invalidateXSolution0(x)
+ }
+ backXs -= x
+ }
+ // Clear the y-key from the back-dependency table. This helps the algorithm drive
+ // re-computation after this y gets invalidated.
+ yBackDependencies -= yKey
+ }
+
+ def isXResolved(x: X): Boolean = {
+ val xKey = algoDef.keyOfX(x)
+ isXResolved0(xKey)
+ }
+
+ private def isXResolved0(xKey: XKey[X, Y, XOutput, YOutput]): Boolean = {
+ xSolutions.contains(xKey)
+ }
+
+ def isYResolved(y: Y): Boolean = {
+ val yKey = algoDef.keyOfY(y)
+ ySolutions.contains(yKey)
+ }
+
+ private def isYResolved0(yKey: YKey[X, Y, XOutput, YOutput]): Boolean = {
+ ySolutions.contains(yKey)
+ }
+
+ def getXSolution(x: X): XOutput = {
+ val xKey = algoDef.keyOfX(x)
+ assert(xSolutions.contains(xKey))
+ xSolutions(xKey)
+ }
+
+ def getYSolution(y: Y): YOutput = {
+ val yKey = algoDef.keyOfY(y)
+ assert(ySolutions.contains(yKey))
+ ySolutions(yKey)
+ }
+
+ def addXSolution(x: X, xSolution: XOutput): Unit = {
+ val xKey = algoDef.keyOfX(x)
+ assert(!xSolutions.contains(xKey))
+ xSolutions += xKey -> xSolution
+ }
+
+ def addYSolution(y: Y, ySolution: YOutput): Unit = {
+ val yKey = algoDef.keyOfY(y)
+ assert(!ySolutions.contains(yKey))
+ ySolutions += yKey -> ySolution
+ }
+
+ def addXAsBackDependencyOfY(x: X, dependency: Y): Unit = {
+ val xKey = algoDef.keyOfX(x)
+ val yKey = algoDef.keyOfY(dependency)
+ yBackDependencies.getOrElseUpdate(yKey, mutable.Set()) += xKey
+ }
+
+ def addYAsBackDependencyOfX(y: Y, dependency: X): Unit = {
+ val yKey = algoDef.keyOfY(y)
+ val xKey = algoDef.keyOfX(dependency)
+ xBackDependencies.getOrElseUpdate(xKey, mutable.Set()) += yKey
+ }
+
+ def xHasDependency(x: X, y: Y): Boolean = {
+ val xKey = algoDef.keyOfX(x)
+ val yKey = algoDef.keyOfY(y)
+ yBackDependencies.get(yKey).exists(_.contains(xKey))
+ }
+
+ def yHasDependency(y: Y, x: X): Boolean = {
+ val yKey = algoDef.keyOfY(y)
+ val xKey = algoDef.keyOfX(x)
+ xBackDependencies.get(xKey).exists(_.contains(yKey))
+ }
+
+ def build(): Solution[X, Y, XOutput, YOutput] = {
+ SolutionImpl(
+ algoDef,
+ xSolutions.toMap,
+ ySolutions.toMap
+ )
+ }
+ }
+
+ private object Builder {
+ def apply[X <: AnyRef, Y <: AnyRef, XOutput <: AnyRef, YOutput <: AnyRef](
+ algoDef: DpZipperAlgoDef[X, Y, XOutput, YOutput]): Builder[X, Y, XOutput, YOutput] = {
+ new Builder[X, Y, XOutput, YOutput](algoDef)
+ }
+ }
+ }
+
+ class XKey[X <: AnyRef, Y <: AnyRef, XOutput <: AnyRef, YOutput <: AnyRef] private (
+ algoDef: DpZipperAlgoDef[X, Y, XOutput, YOutput],
+ val x: X) {
+ private val id = algoDef.idOfX(x)
+ override def hashCode(): Int = id.hashCode()
+ override def equals(obj: Any): Boolean = {
+ obj match {
+ case other: XKey[X, Y, XOutput, YOutput] => id == other.id
+ case _ => false
+ }
+ }
+ override def toString: String = x.toString
+ }
+
+ private object XKey {
+ // The "algoDef" argument supplies the ID function that defines the key's equality.
+ def apply[X <: AnyRef, Y <: AnyRef, XOutput <: AnyRef, YOutput <: AnyRef](
+ algoDef: DpZipperAlgoDef[X, Y, XOutput, YOutput],
+ x: X): XKey[X, Y, XOutput, YOutput] = {
+ new XKey[X, Y, XOutput, YOutput](algoDef, x)
+ }
+ }
+
+ class YKey[X <: AnyRef, Y <: AnyRef, XOutput <: AnyRef, YOutput <: AnyRef] private (
+ algoDef: DpZipperAlgoDef[X, Y, XOutput, YOutput],
+ val y: Y) {
+ private val id = algoDef.idOfY(y)
+ override def hashCode(): Int = id.hashCode()
+ override def equals(obj: Any): Boolean = {
+ obj match {
+ case other: YKey[X, Y, XOutput, YOutput] => id == other.id
+ case _ => false
+ }
+ }
+ override def toString: String = y.toString
+ }
+
+ private object YKey {
+ // The "algoDef" argument supplies the ID function that defines the key's equality.
+ def apply[X <: AnyRef, Y <: AnyRef, XOutput <: AnyRef, YOutput <: AnyRef](
+ algoDef: DpZipperAlgoDef[X, Y, XOutput, YOutput],
+ y: Y): YKey[X, Y, XOutput, YOutput] = {
+ new YKey[X, Y, XOutput, YOutput](algoDef, y)
+ }
+ }
+
+ implicit class DpZipperAlgoDefImplicits[
+ X <: AnyRef,
+ Y <: AnyRef,
+ XOutput <: AnyRef,
+ YOutput <: AnyRef](algoDef: DpZipperAlgoDef[X, Y, XOutput, YOutput]) {
+
+ def keyOfX(x: X): XKey[X, Y, XOutput, YOutput] = {
+ XKey(algoDef, x)
+ }
+
+ def keyOfY(y: Y): YKey[X, Y, XOutput, YOutput] = {
+ YKey(algoDef, y)
+ }
+ }
+}
diff --git a/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/exaustive/ExhaustivePlanner.scala b/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/exaustive/ExhaustivePlanner.scala
new file mode 100644
index 000000000000..29666fb40ed1
--- /dev/null
+++ b/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/exaustive/ExhaustivePlanner.scala
@@ -0,0 +1,141 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.glutenproject.cbo.exaustive
+
+import io.glutenproject.cbo._
+import io.glutenproject.cbo.Best.KnownCostPath
+import io.glutenproject.cbo.best.BestFinder
+import io.glutenproject.cbo.exaustive.ExhaustivePlanner.ExhaustiveExplorer
+import io.glutenproject.cbo.memo.{Memo, MemoState}
+import io.glutenproject.cbo.path._
+import io.glutenproject.cbo.property.PropertySet
+import io.glutenproject.cbo.rule.{EnforcerRuleSet, RuleApplier, Shape}
+
+private class ExhaustivePlanner[T <: AnyRef] private (
+ cbo: Cbo[T],
+ altConstraintSets: Seq[PropertySet[T]],
+ constraintSet: PropertySet[T],
+ plan: T)
+ extends CboPlanner[T] {
+ private val memo = Memo(cbo)
+ private val rules = cbo.ruleFactory.create().map(rule => RuleApplier(cbo, memo, rule))
+ private val enforcerRuleSet = EnforcerRuleSet[T](cbo, memo)
+
+ private lazy val rootGroupId: Int = {
+ memo.memorize(plan, constraintSet).id()
+ }
+
+ private lazy val best: (Best[T], KnownCostPath[T]) = {
+ altConstraintSets.foreach(propSet => memo.memorize(plan, propSet))
+ val groupId = rootGroupId
+ explore()
+ val memoState = memo.newState()
+ val best = findBest(memoState, groupId)
+ (best, best.path())
+ }
+
+ override def plan(): T = {
+ best._2.cboPath.plan()
+ }
+
+ override def newState(): PlannerState[T] = {
+ val foundBest = best._1
+ PlannerState(cbo, memo.newState(), rootGroupId, foundBest)
+ }
+
+ private def explore(): Unit = {
+ // TODO1: Prune paths within cost threshold
+ // ~~ TODO2: Use partial-canonical paths to reduce search space ~~
+ memo.doExhaustively {
+ val explorer = new ExhaustiveExplorer(cbo, memo.newState(), rules, enforcerRuleSet)
+ explorer.explore()
+ }
+ }
+
+ private def findBest(memoState: MemoState[T], groupId: Int): Best[T] = {
+ BestFinder(cbo, memoState).bestOf(groupId)
+ }
+}
+
+object ExhaustivePlanner {
+ def apply[T <: AnyRef](
+ cbo: Cbo[T],
+ altConstraintSets: Seq[PropertySet[T]],
+ constraintSet: PropertySet[T],
+ plan: T): CboPlanner[T] = {
+ new ExhaustivePlanner(cbo, altConstraintSets, constraintSet, plan)
+ }
+
+ private class ExhaustiveExplorer[T <: AnyRef](
+ cbo: Cbo[T],
+ memoState: MemoState[T],
+ rules: Seq[RuleApplier[T]],
+ enforcerRuleSet: EnforcerRuleSet[T]) {
+ private val allClusters = memoState.allClusters()
+ private val allGroups = memoState.allGroups()
+
+ def explore(): Unit = {
+ // TODO: ONLY APPLY RULES ON ALTERED GROUPS (and close parents)
+ applyEnforcerRules()
+ applyRules()
+ }
+
+ private def findPaths(canonical: CanonicalNode[T], shapes: Seq[Shape[T]])(
+ onFound: CboPath[T] => Unit): Unit = {
+ val finder = shapes
+ .foldLeft(
+ PathFinder
+ .builder(cbo, memoState)) {
+ case (builder, shape) =>
+ builder.output(shape.wizard())
+ }
+ .build()
+ finder.find(canonical).foreach(path => onFound(path))
+ }
+
+ private def applyRule(rule: RuleApplier[T], path: CboPath[T]): Unit = {
+ rule.apply(path)
+ }
+
+ private def applyRules(): Unit = {
+ if (rules.isEmpty) {
+ return
+ }
+ val shapes = rules.map(_.shape())
+ allClusters
+ .flatMap(c => c.nodes())
+ .foreach(
+ node => findPaths(node, shapes)(path => rules.foreach(rule => applyRule(rule, path))))
+ }
+
+ private def applyEnforcerRules(): Unit = {
+ allGroups.foreach {
+ group =>
+ val constraintSet = group.propSet()
+ val enforcerRules = enforcerRuleSet.rulesOf(constraintSet)
+ if (enforcerRules.nonEmpty) {
+ val shapes = enforcerRules.map(_.shape())
+ memoState.clusterLookup()(group.clusterKey()).nodes().foreach {
+ node =>
+ findPaths(node, shapes)(
+ path => enforcerRules.foreach(rule => applyRule(rule, path)))
+ }
+ }
+ }
+ }
+ }
+}
diff --git a/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/memo/ForwardMemoTable.scala b/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/memo/ForwardMemoTable.scala
new file mode 100644
index 000000000000..50d27a6795e9
--- /dev/null
+++ b/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/memo/ForwardMemoTable.scala
@@ -0,0 +1,201 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.glutenproject.cbo.memo
+
+import io.glutenproject.cbo._
+import io.glutenproject.cbo.CboCluster.MutableCboCluster
+import io.glutenproject.cbo.memo.MemoTable.Probe
+import io.glutenproject.cbo.property.PropertySet
+import io.glutenproject.cbo.util.IndexDisjointSet
+
+import scala.collection.mutable
+
+class ForwardMemoTable[T <: AnyRef] private (override val cbo: Cbo[T])
+ extends MemoTable.Writable[T] {
+ import ForwardMemoTable._
+
+ private val groupBuffer: mutable.ArrayBuffer[CboGroup[T]] = mutable.ArrayBuffer()
+
+ private val clusterKeyBuffer: mutable.ArrayBuffer[IntClusterKey] = mutable.ArrayBuffer()
+ private val clusterBuffer: mutable.ArrayBuffer[MutableCboCluster[T]] = mutable.ArrayBuffer()
+ private val clusterDisjointSet: IndexDisjointSet = IndexDisjointSet()
+ private val groupLookup: mutable.ArrayBuffer[mutable.Map[PropertySet[T], CboGroup[T]]] =
+ mutable.ArrayBuffer()
+
+ private val clusterMergeLog: mutable.ArrayBuffer[(Int, Int)] = mutable.ArrayBuffer()
+ private var memoWriteCount: Int = 0
+
+ override def getCluster(key: CboClusterKey): MutableCboCluster[T] = {
+ val ancestor = ancestorClusterIdOf(key)
+ clusterBuffer(ancestor)
+ }
+
+ override def newCluster(): CboClusterKey = {
+ checkBufferSizes()
+ val key = IntClusterKey(clusterBuffer.size)
+ clusterKeyBuffer += key
+ clusterBuffer += MutableCboCluster(cbo)
+ clusterDisjointSet.grow()
+ groupLookup += mutable.Map()
+ key
+ }
+
+ override def groupOf(key: CboClusterKey, propSet: PropertySet[T]): CboGroup[T] = {
+ val ancestor = ancestorClusterIdOf(key)
+ val lookup = groupLookup(ancestor)
+ if (lookup.contains(propSet)) {
+ return lookup(propSet)
+ }
+ val gid = groupBuffer.size
+ val newGroup =
+ CboGroup(cbo, IntClusterKey(ancestor), gid, propSet)
+ lookup += propSet -> newGroup
+ groupBuffer += newGroup
+ memoWriteCount += 1
+ newGroup
+ }
+
+ override def getClusterPropSets(key: CboClusterKey): Set[PropertySet[T]] = {
+ val ancestor = ancestorClusterIdOf(key)
+ groupLookup(ancestor).keySet.toSet
+ }
+
+ override def addToCluster(key: CboClusterKey, node: CanonicalNode[T]): Unit = {
+ getCluster(key).add(node)
+ memoWriteCount += 1
+ }
+
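+ // A hedged illustration: assuming two clusters whose buffer indices are 1 and 3,
+ // mergeClusters(key(3), key(1)) forwards index 3 to index 1 in the disjoint set, so
+ // subsequent lookups through either key resolve to cluster 1's node and group buffers.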
+ override def mergeClusters(one: CboClusterKey, other: CboClusterKey): Unit = {
+ val oneAncestor = ancestorClusterIdOf(one)
+ val otherAncestor = ancestorClusterIdOf(other)
+
+ if (oneAncestor == otherAncestor) {
+ // Already merged.
+ return
+ }
+
+ case class Merge(from: Int, to: Int)
+
+ val merge = if (oneAncestor > otherAncestor) {
+ Merge(oneAncestor, otherAncestor)
+ } else {
+ Merge(otherAncestor, oneAncestor)
+ }
+
+ val fromKey = IntClusterKey(merge.from)
+ val toKey = IntClusterKey(merge.to)
+
+ val fromCluster = clusterBuffer(merge.from)
+ val toCluster = clusterBuffer(merge.to)
+
+ // Add absent nodes.
+ fromCluster.nodes().foreach {
+ fromNode =>
+ if (!toCluster.contains(fromNode)) {
+ toCluster.add(fromNode)
+ }
+ }
+
+ // Add absent groups.
+ val fromGroups = groupLookup(merge.from)
+ val toGroups = groupLookup(merge.to)
+ fromGroups.foreach {
+ case (fromPropSet, _) =>
+ if (!toGroups.contains(fromPropSet)) {
+ groupOf(toKey, fromPropSet)
+ }
+ }
+
+ // Forward the "from" element to "to" in the disjoint set.
+ clusterDisjointSet.forward(merge.from, merge.to)
+ clusterMergeLog += (merge.from -> merge.to)
+ memoWriteCount += 1
+ }
+
+ override def getGroup(id: Int): CboGroup[T] = groupBuffer(id)
+
+ override def allClusters(): Seq[CboClusterKey] = clusterKeyBuffer
+
+ override def allGroups(): Seq[CboGroup[T]] = groupBuffer
+
+ private def ancestorClusterIdOf(key: CboClusterKey): Int = {
+ clusterDisjointSet.find(key.id())
+ }
+
+ private def checkBufferSizes(): Unit = {
+ assert(clusterKeyBuffer.size == clusterBuffer.size)
+ assert(clusterKeyBuffer.size == clusterDisjointSet.size)
+ assert(clusterKeyBuffer.size == groupLookup.size)
+ }
+
+ override def probe(): MemoTable.Probe[T] = new ForwardMemoTable.Probe[T](this)
+
+ override def writeCount(): Int = memoWriteCount
+}
+
+object ForwardMemoTable {
+ def apply[T <: AnyRef](cbo: Cbo[T]): MemoTable.Writable[T] = new ForwardMemoTable[T](cbo)
+
+ private case class IntClusterKey(id: Int) extends CboClusterKey
+
+ private class Probe[T <: AnyRef](table: ForwardMemoTable[T]) extends MemoTable.Probe[T] {
+ private val probedClusterCount: Int = table.clusterKeyBuffer.size
+ private val probedGroupCount: Int = table.groupBuffer.size
+ private val probedMergeLogSize: Int = table.clusterMergeLog.size
+
+ override def toDiff(): Probe.Diff[T] = {
+ val newClusterCount = table.clusterKeyBuffer.size
+ val newGroupCount = table.groupBuffer.size
+ val newMergeLogSize = table.clusterMergeLog.size
+
+ assert(newClusterCount >= probedClusterCount)
+ assert(newGroupCount >= probedGroupCount)
+ assert(newMergeLogSize >= probedMergeLogSize)
+
+ // Find new clusters.
+ val newClusters = table.clusterKeyBuffer.slice(probedClusterCount, newClusterCount)
+
+ // Find resident clusters of the new groups.
+ val newGroups = table.groupBuffer.slice(probedGroupCount, newGroupCount)
+ val clustersOfNewGroups = newGroups.map(g => g.clusterKey())
+
+ // Find all the affected clusters, if cluster-merge happened.
+ val newMergeLogs = table.clusterMergeLog.slice(probedMergeLogSize, newMergeLogSize)
+ val affectedClustersDuringMerging = newMergeLogs
+ .flatMap {
+ case (from, to) =>
+ table.clusterDisjointSet.setOf(to)
+ }
+ .map(index => table.clusterKeyBuffer(index))
+
+ val changedClusters =
+ (clustersOfNewGroups.toSet ++ affectedClustersDuringMerging) -- newClusters
+ // We consider an existing cluster that gained new groups as changed.
+ Probe.Diff(changedClusters)
+ }
+ }
+
+ implicit class CboClusterKeyImplicits(key: CboClusterKey) {
+ def id(): Int = {
+ asIntKey().id
+ }
+
+ private def asIntKey(): IntClusterKey = {
+ key.asInstanceOf[IntClusterKey]
+ }
+ }
+}
diff --git a/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/memo/Memo.scala b/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/memo/Memo.scala
new file mode 100644
index 000000000000..7cbee7c39e8f
--- /dev/null
+++ b/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/memo/Memo.scala
@@ -0,0 +1,226 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.glutenproject.cbo.memo
+
+import io.glutenproject.cbo._
+import io.glutenproject.cbo.CboCluster.ImmutableCboCluster
+import io.glutenproject.cbo.property.PropertySet
+import io.glutenproject.cbo.util.CanonicalNodeMap
+import io.glutenproject.cbo.vis.GraphvizVisualizer
+
+trait MemoLike[T <: AnyRef] {
+ def memorize(node: T, constraintSet: PropertySet[T]): CboGroup[T]
+}
+
+trait Closure[T <: AnyRef] {
+ def openFor(node: CanonicalNode[T]): MemoLike[T]
+}
+
+trait Memo[T <: AnyRef] extends Closure[T] with MemoLike[T] {
+ def newState(): MemoState[T]
+ def doExhaustively(func: => Unit): Unit
+}
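+
+// A hedged usage sketch ("cbo", "plan", and "constraintSet" are assumed from the caller's
+// context, mirroring how the planners in this patch use the memo):
+//
+//   val memo = Memo(cbo)
+//   val rootGroup = memo.memorize(plan, constraintSet)
+//   val state = memo.newState()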
+
+trait UnsafeMemo[T <: AnyRef] extends Memo[T] {
+ def table(): MemoTable[T]
+}
+
+object Memo {
+ def apply[T <: AnyRef](cbo: Cbo[T]): Memo[T] = {
+ new CboMemo[T](cbo)
+ }
+
+ def unsafe[T <: AnyRef](cbo: Cbo[T]): UnsafeMemo[T] = {
+ new CboMemo[T](cbo)
+ }
+
+ private class CboMemo[T <: AnyRef](val cbo: Cbo[T]) extends UnsafeMemo[T] {
+ import CboMemo._
+ private val memoTable: MemoTable.Writable[T] = MemoTable.create(cbo)
+ private val cache: NodeToClusterMap[T] = new NodeToClusterMap(cbo)
+
+ private def newCluster(): CboClusterKey = {
+ memoTable.newCluster()
+ }
+
+ private def addToCluster(clusterKey: CboClusterKey, can: CanonicalNode[T]): Unit = {
+ assert(!cache.contains(can))
+ cache.put(can, clusterKey)
+ memoTable.addToCluster(clusterKey, can)
+ }
+
+ // Replace the node's children with node groups. When a group doesn't exist, create it.
+ private def canonizeUnsafe(node: T, constraintSet: PropertySet[T], depth: Int): T = {
+ assert(depth >= 1)
+ if (depth > 1) {
+ return cbo.withNewChildren(
+ node,
+ cbo.planModel
+ .childrenOf(node)
+ .zip(cbo.propertySetFactory().childrenConstraintSets(constraintSet, node))
+ .map {
+ case (child, constraintSet) =>
+ canonizeUnsafe(child, constraintSet, depth - 1)
+ }
+ )
+ }
+ assert(depth == 1)
+ val childrenGroups: Seq[CboGroup[T]] = cbo.planModel
+ .childrenOf(node)
+ .zip(cbo.propertySetFactory().childrenConstraintSets(constraintSet, node))
+ .map {
+ case (child, childConstraintSet) =>
+ memorize(child, childConstraintSet)
+ }
+ val newNode =
+ cbo.withNewChildren(node, childrenGroups.map(group => group.self()))
+ newNode
+ }
+
+ private def canonize(node: T, constraintSet: PropertySet[T]): CanonicalNode[T] = {
+ CanonicalNode(cbo, canonizeUnsafe(node, constraintSet, 1))
+ }
+
+ private def insert(n: T, constraintSet: PropertySet[T]): CboClusterKey = {
+ if (cbo.planModel.isGroupLeaf(n)) {
+ val plainGroup = memoTable.allGroups()(cbo.planModel.getGroupId(n))
+ return plainGroup.clusterKey()
+ }
+
+ val node = canonize(n, constraintSet)
+
+ if (cache.contains(node)) {
+ cache.get(node)
+ } else {
+ // Node not yet added to cluster.
+ val clusterKey = newCluster()
+ addToCluster(clusterKey, node)
+ clusterKey
+ }
+ }
+
+ override def memorize(node: T, constraintSet: PropertySet[T]): CboGroup[T] = {
+ val clusterKey = insert(node, constraintSet)
+ val prevGroupCount = memoTable.allGroups().size
+ val out = memoTable.groupOf(clusterKey, constraintSet)
+ val newGroupCount = memoTable.allGroups().size
+ assert(newGroupCount >= prevGroupCount)
+ out
+ }
+
+ override def openFor(node: CanonicalNode[T]): MemoLike[T] = {
+ assert(cache.contains(node))
+ val targetCluster = cache.get(node)
+ new InClusterMemo[T](this, targetCluster)
+ }
+
+ override def newState(): MemoState[T] = {
+ memoTable.newState()
+ }
+
+ override def table(): MemoTable[T] = memoTable
+
+ override def doExhaustively(func: => Unit): Unit = {
+ memoTable.doExhaustively(func)
+ }
+ }
+
+ private object CboMemo {
+ private class InClusterMemo[T <: AnyRef](parent: CboMemo[T], preparedCluster: CboClusterKey)
+ extends MemoLike[T] {
+
+ private def insert(node: T, constraintSet: PropertySet[T]): Unit = {
+ val can = parent.canonize(node, constraintSet)
+ if (parent.cache.contains(can)) {
+ val cachedCluster = parent.cache.get(can)
+ if (cachedCluster == preparedCluster) {
+ return
+ }
+ // The new node has already been memorized, but into a different cluster than the
+ // input node's. Merge the two clusters.
+ //
+ // TODO: Traverse up the tree to do more merges.
+ parent.memoTable.mergeClusters(cachedCluster, preparedCluster)
+ // Since the new node is already memorized, we don't have to add it to either of the
+ // clusters anymore.
+ return
+ }
+ parent.addToCluster(preparedCluster, can)
+ }
+
+ override def memorize(node: T, constraintSet: PropertySet[T]): CboGroup[T] = {
+ insert(node, constraintSet)
+ parent.memoTable.groupOf(preparedCluster, constraintSet)
+ }
+ }
+ }
+
+ private class NodeToClusterMap[T <: AnyRef](cbo: Cbo[T])
+ extends CanonicalNodeMap[T, CboClusterKey](cbo)
+}
+
+trait MemoStore[T <: AnyRef] {
+ def getCluster(key: CboClusterKey): CboCluster[T]
+ def getGroup(id: Int): CboGroup[T]
+}
+
+object MemoStore {
+ implicit class MemoStoreImplicits[T <: AnyRef](store: MemoStore[T]) {
+ def asGroupSupplier(): Int => CboGroup[T] = {
+ store.getGroup
+ }
+ }
+}
+
+trait MemoState[T <: AnyRef] extends MemoStore[T] {
+ def cbo(): Cbo[T]
+ def clusterLookup(): Map[CboClusterKey, CboCluster[T]]
+ def allClusters(): Iterable[CboCluster[T]]
+ def allGroups(): Seq[CboGroup[T]]
+}
+
+object MemoState {
+ def apply[T <: AnyRef](
+ cbo: Cbo[T],
+ clusterLookup: Map[CboClusterKey, ImmutableCboCluster[T]],
+ allGroups: Seq[CboGroup[T]]): MemoState[T] = {
+ MemoStateImpl(cbo, clusterLookup, allGroups)
+ }
+
+ private case class MemoStateImpl[T <: AnyRef](
+ override val cbo: Cbo[T],
+ override val clusterLookup: Map[CboClusterKey, ImmutableCboCluster[T]],
+ override val allGroups: Seq[CboGroup[T]])
+ extends MemoState[T] {
+ private val allClustersCopy = clusterLookup.values
+
+ override def getCluster(key: CboClusterKey): CboCluster[T] = clusterLookup(key)
+ override def getGroup(id: Int): CboGroup[T] = allGroups(id)
+ override def allClusters(): Iterable[CboCluster[T]] = allClustersCopy
+ }
+
+ implicit class MemoStateImplicits[T <: AnyRef](state: MemoState[T]) {
+
+ def formatGraphvizWithBest(best: Best[T]): String = {
+ GraphvizVisualizer(state.cbo(), state, best).format()
+ }
+
+ def formatGraphvizWithoutBest(rootGroupId: Int): String = {
+ GraphvizVisualizer(state.cbo(), state, rootGroupId).format()
+ }
+ }
+}
diff --git a/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/memo/MemoTable.scala b/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/memo/MemoTable.scala
new file mode 100644
index 000000000000..755d2c15d95a
--- /dev/null
+++ b/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/memo/MemoTable.scala
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.glutenproject.cbo.memo
+
+import io.glutenproject.cbo._
+import io.glutenproject.cbo.CboCluster.ImmutableCboCluster
+import io.glutenproject.cbo.property.PropertySet
+
+sealed trait MemoTable[T <: AnyRef] extends MemoStore[T] {
+ import MemoTable._
+
+ def cbo: Cbo[T]
+
+ override def getCluster(key: CboClusterKey): CboCluster[T]
+ override def getGroup(id: Int): CboGroup[T]
+
+ def allClusters(): Seq[CboClusterKey]
+ def allGroups(): Seq[CboGroup[T]]
+
+ def getClusterPropSets(key: CboClusterKey): Set[PropertySet[T]]
+
+ def probe(): Probe[T]
+
+ def writeCount(): Int
+}
+
+object MemoTable {
+ def create[T <: AnyRef](cbo: Cbo[T]): Writable[T] = ForwardMemoTable(cbo)
+
+ trait Writable[T <: AnyRef] extends MemoTable[T] {
+ def newCluster(): CboClusterKey
+ def groupOf(key: CboClusterKey, propertySet: PropertySet[T]): CboGroup[T]
+
+ def addToCluster(key: CboClusterKey, node: CanonicalNode[T]): Unit
+ def mergeClusters(one: CboClusterKey, other: CboClusterKey): Unit
+ }
+
+ trait Probe[T <: AnyRef] {
+ import Probe._
+ def toDiff(): Diff[T]
+ }
+
+ object Probe {
+ trait Diff[T <: AnyRef] {
+ def changedClusters(): Set[CboClusterKey]
+ }
+
+ object Diff {
+ def apply[T <: AnyRef](changedClusters: Set[CboClusterKey]): Diff[T] = DiffImpl(
+ changedClusters)
+ private case class DiffImpl[T <: AnyRef](override val changedClusters: Set[CboClusterKey])
+ extends Diff[T]
+ }
+ }
+
+ implicit class MemoTableImplicits[T <: AnyRef](table: MemoTable[T]) {
+ def newState(): MemoState[T] = {
+ val immutableClusters = table
+ .allClusters()
+ .map(key => key -> ImmutableCboCluster(table.cbo, table.getCluster(key)))
+ .toMap
+ MemoState(table.cbo, immutableClusters, table.allGroups())
+ }
+
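+ // A hedged usage sketch: drive an arbitrary write operation to a fixpoint. Here
+ // "applyAllRules()" is an assumed callback that may add groups or nodes to the table;
+ // the loop below re-runs it until a full pass leaves the write count unchanged.
+ //
+ //   table.doExhaustively {
+ //     applyAllRules()
+ //   }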
+ def doExhaustively(func: => Unit): Unit = {
+ while (true) {
+ val prevWriteCount = table.writeCount()
+ func
+ val writeCount = table.writeCount()
+ assert(writeCount >= prevWriteCount)
+ if (writeCount == prevWriteCount) {
+ return
+ }
+ }
+ }
+ }
+}
diff --git a/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/path/CboPath.scala b/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/path/CboPath.scala
new file mode 100644
index 000000000000..5b623bd4a5ee
--- /dev/null
+++ b/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/path/CboPath.scala
@@ -0,0 +1,195 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.glutenproject.cbo.path
+
+import io.glutenproject.cbo._
+import io.glutenproject.cbo.memo.MemoStore
+
+trait CboPath[T <: AnyRef] {
+ def cbo(): Cbo[T]
+ def keys(): PathKeySet
+ def height(): Int
+ def node(): CboPath.PathNode[T]
+ def plan(): T
+}
+
+object CboPath {
+ val INF_DEPTH: Int = Int.MaxValue
+
+ trait PathNode[T <: AnyRef] {
+ def self(): CboNode[T]
+ def children(): Seq[PathNode[T]]
+ }
+
+ object PathNode {
+ def apply[T <: AnyRef](node: CboNode[T], children: Seq[PathNode[T]]): PathNode[T] = {
+ PathNodeImpl(node, children)
+ }
+ }
+
+ implicit class PathNodeImplicits[T <: AnyRef](pNode: CboPath.PathNode[T]) {
+ def zipChildrenWithGroupIds(): Seq[(CboPath.PathNode[T], Int)] = {
+ pNode
+ .children()
+ .zip(pNode.self().asCanonical().getChildrenGroupIds())
+ }
+
+ def zipChildrenWithGroups(
+ allGroups: Int => CboGroup[T]): Seq[(CboPath.PathNode[T], CboGroup[T])] = {
+ pNode
+ .children()
+ .zip(pNode.self().asCanonical().getChildrenGroups(allGroups).map(_.group(allGroups)))
+ }
+ }
+
+ private def apply[T <: AnyRef](
+ cbo: Cbo[T],
+ keys: PathKeySet,
+ height: Int,
+ node: CboPath.PathNode[T]): CboPath[T] = {
+ CboPathImpl(cbo, keys, height, node)
+ }
+
+ // Returns None if the children don't share at least one path key.
+ def apply[T <: AnyRef](
+ cbo: Cbo[T],
+ node: CboNode[T],
+ children: Seq[CboPath[T]]): Option[CboPath[T]] = {
+ assert(children.forall(_.cbo() eq cbo))
+
+ val newKeysUnsafe = children.map(_.keys().keys()).reduce[Set[PathKey]] {
+ case (one, other) =>
+ one.intersect(other)
+ }
+ if (newKeysUnsafe.isEmpty) {
+ return None
+ }
+ val newKeys = PathKeySet(newKeysUnsafe)
+ Some(
+ CboPath(
+ cbo,
+ newKeys,
+ 1 + children.map(_.height()).reduceOption(_ max _).getOrElse(0),
+ PathNode(node, children.map(_.node()))))
+ }
+
+ def zero[T <: AnyRef](cbo: Cbo[T], keys: PathKeySet, group: GroupNode[T]): CboPath[T] = {
+ CboPath(cbo, keys, 0, PathNode(group, List.empty))
+ }
+
+ def one[T <: AnyRef](
+ cbo: Cbo[T],
+ keys: PathKeySet,
+ allGroups: Int => CboGroup[T],
+ canonical: CanonicalNode[T]): CboPath[T] = {
+ CboPath(
+ cbo,
+ keys,
+ 1,
+ PathNode(canonical, canonical.getChildrenGroups(allGroups).map(g => PathNode(g, List.empty))))
+ }
+
+ // Aggregates paths that have the same shape but different keys.
+ // Currently not in use because of poor performance.
+ def aggregate[T <: AnyRef](cbo: Cbo[T], paths: Iterable[CboPath[T]]): Iterable[CboPath[T]] = {
+ // Scala has specialized optimizations for group-by on small inputs,
+ // so it's better to pass only small inputs to this method when possible.
+ val grouped = paths.groupBy(_.node())
+ grouped.map {
+ case (node, paths) =>
+ val heights = paths.map(_.height()).toSeq.distinct
+ assert(heights.size == 1)
+ val height = heights.head
+ val keys = paths.map(_.keys().keys()).reduce[Set[PathKey]] {
+ case (one, other) =>
+ one.union(other)
+ }
+ CboPath(cbo, PathKeySet(keys), height, node)
+ }
+ }
+
+ def cartesianProduct[T <: AnyRef](
+ cbo: Cbo[T],
+ canonical: CanonicalNode[T],
+ children: Seq[Iterable[CboPath[T]]]): Iterable[CboPath[T]] = {
+ // Apply cartesian product across all children to get an enumeration
+ // of all possible choices of parent and children.
+ //
+ // Example:
+ //
+ // Root:
+ // n0(group1, group2)
+ // Children Input:
+ // (group1, group2)
+ // = ([n1 || n2], [n3 || n4 || n5])
+ // = ([p1 || p2.1 || p2.2], [p3 || p4 || p5]) (expanded)
+ // Children Output:
+ // [(p1, p3), (p1, p4), (p1, p5), (p2.1, p3), (p2.1, p4), (p2.1, p5),
+ //   (p2.2, p3), (p2.2, p4), (p2.2, p5)] (choices)
+ // Path enumerated:
+ // [p0.1(p1, p3), p0.2(p1, p4), p0.3(p1, p5), p0.4(p2.1, p3), p0.5(p2.1, p4),
+ // p0.6(p2.1, p5), p0.7(p2.2, p3), p0.8(p2.2, p4), p0.9(p2.2, p5)] (output)
+ //
+ // TODO: Make inner builder list mutable to reduce memory usage
+ val choicesBuilder: Iterable[Seq[CboPath[T]]] = List(List.empty)
+ val choices: Iterable[Seq[CboPath[T]]] = children
+ .foldLeft(choicesBuilder) {
+ (choicesBuilder: Iterable[Seq[CboPath[T]]], child: Iterable[CboPath[T]]) =>
+ for (left <- choicesBuilder; right <- child) yield left :+ right
+ }
+
+ choices.flatMap { children: Seq[CboPath[T]] => CboPath(cbo, canonical, children) }
+ }
+
+ implicit class CboPathImplicits[T <: AnyRef](path: CboPath[T]) {
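+ // Extends this path deeper into the memo by up to `extraDepth` levels
+ // (INF_DEPTH for unbounded), returning all paths found.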
+ def dive(memoStore: MemoStore[T], extraDepth: Int): Iterable[CboPath[T]] = {
+ val accumulatedDepth = extraDepth match {
+ case CboPath.INF_DEPTH => CboPath.INF_DEPTH
+ case _ =>
+ Math.addExact(path.height(), extraDepth)
+ }
+
+ val finder = PathFinder
+ .builder(path.cbo(), memoStore)
+ .depth(accumulatedDepth)
+ .build()
+ finder.find(path)
+ }
+ }
+
+ private case class PathNodeImpl[T <: AnyRef](
+ override val self: CboNode[T],
+ override val children: Seq[PathNode[T]])
+ extends PathNode[T]
+
+ private case class CboPathImpl[T <: AnyRef](
+ override val cbo: Cbo[T],
+ override val keys: PathKeySet,
+ override val height: Int,
+ override val node: CboPath.PathNode[T])
+ extends CboPath[T] {
+ assert(height >= 0)
+ private lazy val built: T = {
+ def dfs(node: CboPath.PathNode[T]): T = {
+ cbo.withNewChildren(node.self().self(), node.children().map(c => dfs(c)))
+ }
+ dfs(node)
+ }
+
+ override def plan(): T = built
+ }
+}
diff --git a/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/path/OutputFilter.scala b/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/path/OutputFilter.scala
new file mode 100644
index 000000000000..f6bc818c2b27
--- /dev/null
+++ b/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/path/OutputFilter.scala
@@ -0,0 +1,116 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.glutenproject.cbo.path
+
+import io.glutenproject.cbo.{CanonicalNode, GroupNode}
+import io.glutenproject.cbo.path.FilterWizard.FilterAction
+import io.glutenproject.cbo.path.OutputWizard.OutputAction
+import io.glutenproject.cbo.util.CycleDetector
+
+trait FilterWizard[T <: AnyRef] {
+ import FilterWizard._
+ def omit(can: CanonicalNode[T]): FilterAction[T]
+ def omit(group: GroupNode[T], offset: Int, count: Int): FilterAction[T]
+}
+
+object FilterWizard {
+ sealed trait FilterAction[T <: AnyRef]
+ object FilterAction {
+ case class Omit[T <: AnyRef] private () extends FilterAction[T]
+ object Omit {
+ val INSTANCE: Omit[Null] = Omit[Null]()
+ // Hide the default constructor.
+ private def apply[T <: AnyRef](): Omit[T] = new Omit()
+ }
+ def omit[T <: AnyRef]: Omit[T] = Omit.INSTANCE.asInstanceOf[Omit[T]]
+
+ case class Continue[T <: AnyRef](newWizard: FilterWizard[T]) extends FilterAction[T]
+ }
+}
+
+object FilterWizards {
+ def omitCycles[T <: AnyRef](): FilterWizard[T] = {
+ // Compares against group ID to identify cycles.
+ OmitCycles[T](CycleDetector[GroupNode[T]]((one, other) => one.groupId() == other.groupId()))
+ }
+
+ // Cycle detection starts from the first visited group in the input path.
+ private class OmitCycles[T <: AnyRef] private (detector: CycleDetector[GroupNode[T]])
+ extends FilterWizard[T] {
+ override def omit(can: CanonicalNode[T]): FilterAction[T] = {
+ FilterAction.Continue(this)
+ }
+
+ override def omit(group: GroupNode[T], offset: Int, count: Int): FilterAction[T] = {
+ if (detector.contains(group)) {
+ return FilterAction.omit
+ }
+ FilterAction.Continue(new OmitCycles(detector.append(group)))
+ }
+ }
+
+ private object OmitCycles {
+ def apply[T <: AnyRef](detector: CycleDetector[GroupNode[T]]): OmitCycles[T] = {
+ new OmitCycles(detector)
+ }
+ }
+}
+
+object OutputFilter {
+ def apply[T <: AnyRef](
+ outputWizard: OutputWizard[T],
+ filterWizard: FilterWizard[T]): OutputWizard[T] = {
+ new OutputFilterImpl[T](outputWizard, filterWizard)
+ }
+
+ // Composite wizard works within "and" basis, to filter out
+ // the unwanted emitted paths from a certain specified output wizard
+ // by another filter wizard.
+ private class OutputFilterImpl[T <: AnyRef](
+ outputWizard: OutputWizard[T],
+ filterWizard: FilterWizard[T])
+ extends OutputWizard[T] {
+
+ override def visit(can: CanonicalNode[T]): OutputAction[T] = {
+ filterWizard.omit(can) match {
+ case FilterAction.Omit() => OutputAction.stop
+ case FilterAction.Continue(newFilterWizard) =>
+ outputWizard.visit(can) match {
+ case stop @ OutputAction.Stop(_) =>
+ stop
+ case OutputAction.Continue(drain, newOutputWizard) =>
+ OutputAction.Continue(drain, new OutputFilterImpl(newOutputWizard, newFilterWizard))
+ }
+ }
+ }
+
+ override def advance(group: GroupNode[T], offset: Int, count: Int): OutputAction[T] = {
+ filterWizard.omit(group, offset, count) match {
+ case FilterAction.Omit() => OutputAction.stop
+ case FilterAction.Continue(newFilterWizard) =>
+ outputWizard.advance(group, offset, count) match {
+ case stop @ OutputAction.Stop(_) => stop
+ case OutputAction.Continue(drain, newOutputWizard) =>
+ OutputAction.Continue(drain, new OutputFilterImpl(newOutputWizard, newFilterWizard))
+ }
+ }
+ }
+
+ override def withPathKey(newKey: PathKey): OutputWizard[T] =
+ new OutputFilterImpl[T](outputWizard.withPathKey(newKey), filterWizard)
+ }
+}
diff --git a/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/path/OutputWizard.scala b/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/path/OutputWizard.scala
new file mode 100644
index 000000000000..5ac2bb881111
--- /dev/null
+++ b/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/path/OutputWizard.scala
@@ -0,0 +1,429 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.glutenproject.cbo.path
+
+import io.glutenproject.cbo.{CanonicalNode, Cbo, CboGroup, GroupNode}
+import io.glutenproject.cbo.path.OutputWizard.{OutputAction, PathDrain}
+
+import scala.collection.{mutable, Seq}
+
+trait OutputWizard[T <: AnyRef] {
+ import OutputWizard._
+ // Visit a new node.
+ def visit(can: CanonicalNode[T]): OutputAction[T]
+ // Advances to one of the node's children. The returned wizard is for the
+ // child at the given offset among all `count` children.
+ def advance(group: GroupNode[T], offset: Int, count: Int): OutputAction[T]
+ // The returned wizard is the same as this one, except that it
+ // drains paths with the given path key.
+ def withPathKey(newKey: PathKey): OutputWizard[T]
+}
+
+object OutputWizard {
+ sealed trait OutputAction[T <: AnyRef] {
+ def drain(): PathDrain
+ }
+ object OutputAction {
+ case class Stop[T <: AnyRef] private (override val drain: PathDrain) extends OutputAction[T]
+ object Stop {
+ val INSTANCE: Stop[Null] = Stop[Null]()
+ // Hide the default constructor.
+ private def apply[T <: AnyRef](): Stop[T] = new Stop(PathDrain.none)
+ }
+ def stop[T <: AnyRef]: Stop[T] = Stop.INSTANCE.asInstanceOf[Stop[T]]
+
+ case class Continue[T <: AnyRef](override val drain: PathDrain, newWizard: OutputWizard[T])
+ extends OutputAction[T]
+ }
+
+ // A path drain makes it possible to lazily materialize the yielded paths using path keys.
+ // Otherwise, if each wizard emitted its own paths during visiting, a de-dup operation
+ // would be required and could cause serious performance issues.
+ sealed trait PathDrain {
+ def isEmpty(): Boolean
+ def keysUnsafe(): Seq[PathKey]
+ }
+
+ object PathDrain {
+ private case class None[T <: AnyRef] private () extends PathDrain {
+ override def isEmpty(): Boolean = true
+ override def keysUnsafe(): Seq[PathKey] = List.empty
+ }
+ private object None {
+ val INSTANCE: None[Null] = None[Null]()
+ private def apply[T <: AnyRef](): None[T] = new None[T]()
+ }
+ def none[T <: AnyRef]: PathDrain = None.INSTANCE.asInstanceOf[None[T]]
+ case class Specific[T <: AnyRef](override val keysUnsafe: Seq[PathKey]) extends PathDrain {
+ override def isEmpty(): Boolean = keysUnsafe.isEmpty
+ }
+ private case class Trivial[T <: AnyRef] private () extends PathDrain {
+ private val k: Seq[PathKey] = List(PathKey.Trivial)
+ override def isEmpty(): Boolean = k.isEmpty
+ override def keysUnsafe(): Seq[PathKey] = k
+ }
+ private object Trivial {
+ val INSTANCE: Trivial[Null] = Trivial[Null]()
+ private def apply[T <: AnyRef](): Trivial[T] = new Trivial[T]()
+ }
+ def trivial[T <: AnyRef]: PathDrain = Trivial.INSTANCE.asInstanceOf[Trivial[T]]
+ }
+
+ implicit class OutputWizardImplicits[T <: AnyRef](wizard: OutputWizard[T]) {
+ import OutputWizardImplicits._
+
+ def filterBy(filterWizard: FilterWizard[T]): OutputWizard[T] = {
+ OutputFilter(wizard, filterWizard)
+ }
+
+ def prepareForNode(
+ cbo: Cbo[T],
+ allGroups: Int => CboGroup[T],
+ can: CanonicalNode[T]): NodePrepare[T] = {
+ new NodePrepareImpl[T](cbo, wizard, allGroups, can)
+ }
+
+ def prepareForGroup(
+ cbo: Cbo[T],
+ group: GroupNode[T],
+ offset: Int,
+ count: Int): GroupPrepare[T] = {
+ new GroupPrepareImpl[T](cbo, wizard, group, offset, count)
+ }
+ }
+
+ object OutputWizardImplicits {
+ sealed trait NodePrepare[T <: AnyRef] {
+ def visit(): Terminate[T]
+ }
+
+ sealed trait GroupPrepare[T <: AnyRef] {
+ def advance(): Terminate[T]
+ }
+
+ sealed trait Terminate[T <: AnyRef] {
+ def onContinue(extra: OutputWizard[T] => Iterable[CboPath[T]]): Iterable[CboPath[T]]
+ }
+
+ private class DrainedTerminate[T <: AnyRef](
+ action: OutputAction[T],
+ drained: Iterable[CboPath[T]])
+ extends Terminate[T] {
+ override def onContinue(
+ extra: OutputWizard[T] => Iterable[CboPath[T]]): Iterable[CboPath[T]] = {
+ action match {
+ case OutputAction.Stop(_) =>
+ drained.view
+ case OutputAction.Continue(_, newWizard) =>
+ drained.view ++ extra(newWizard)
+ }
+ }
+ }
+
+ private class NodePrepareImpl[T <: AnyRef](
+ cbo: Cbo[T],
+ wizard: OutputWizard[T],
+ allGroups: Int => CboGroup[T],
+ can: CanonicalNode[T])
+ extends NodePrepare[T] {
+ override def visit(): Terminate[T] = {
+ val action = wizard.visit(can)
+ val drained = if (action.drain().isEmpty()) {
+ List.empty
+ } else {
+ List(CboPath.one(cbo, PathKeySet(action.drain().keysUnsafe().toSet), allGroups, can))
+ }
+ new DrainedTerminate[T](action, drained)
+ }
+ }
+
+ private class GroupPrepareImpl[T <: AnyRef](
+ cbo: Cbo[T],
+ wizard: OutputWizard[T],
+ group: GroupNode[T],
+ offset: Int,
+ count: Int)
+ extends GroupPrepare[T] {
+ override def advance(): Terminate[T] = {
+ val action = wizard.advance(group, offset, count)
+ val drained = if (action.drain().isEmpty()) {
+ List.empty
+ } else {
+ List(CboPath.zero(cbo, PathKeySet(action.drain().keysUnsafe().toSet), group))
+ }
+ new DrainedTerminate[T](action, drained)
+ }
+ }
+ }
+}
+
+object OutputWizards {
+ def none[T <: AnyRef](): OutputWizard[T] = {
+ None()
+ }
+
+ def emit[T <: AnyRef](): OutputWizard[T] = {
+ Emit()
+ }
+
+ def union[T <: AnyRef](wizards: Seq[OutputWizard[T]]): OutputWizard[T] = {
+ Union[T](wizards)
+ }
+
+ def withMask[T <: AnyRef](mask: PathMask): OutputWizard[T] = {
+ WithMask[T](mask, 0)
+ }
+
+ def withPattern[T <: AnyRef](pattern: Pattern[T]): OutputWizard[T] = {
+ WithPattern[T](pattern)
+ }
+
+ def withMaxDepth[T <: AnyRef](depth: Int): OutputWizard[T] = {
+ WithMaxDepth[T](depth)
+ }
+
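+ // Illustrative sketch (assuming `OutputWizard`'s implicits are imported and
+ // some `pattern: Pattern[T]` is at hand): the factories above compose through
+ // union and filtering, the same way PathFinder.Builder assembles its wizards.
+ //
+ //   val wizard: OutputWizard[T] = OutputWizards.union(
+ //     Seq(OutputWizards.withMaxDepth[T](2), OutputWizards.withPattern[T](pattern)))
+ //     .filterBy(FilterWizards.omitCycles())
+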
+ private class None[T <: AnyRef]() extends OutputWizard[T] {
+ override def visit(can: CanonicalNode[T]): OutputAction[T] = {
+ OutputAction.Stop(PathDrain.none)
+ }
+
+ override def advance(group: GroupNode[T], offset: Int, count: Int): OutputAction[T] =
+ OutputAction.Stop(PathDrain.none)
+
+ override def withPathKey(newKey: PathKey): OutputWizard[T] = this
+ }
+
+ private object None {
+ def apply[T <: AnyRef](): None[T] = new None[T]()
+ }
+
+ private class Emit[T <: AnyRef](drain: PathDrain) extends OutputWizard[T] {
+ override def visit(can: CanonicalNode[T]): OutputAction[T] = {
+ if (can.isLeaf()) {
+ return OutputAction.Stop(drain)
+ }
+ OutputAction.Continue(PathDrain.none, this)
+ }
+
+ override def advance(group: GroupNode[T], offset: Int, count: Int): OutputAction[T] =
+ OutputAction.Continue(PathDrain.none, this)
+
+ override def withPathKey(newKey: PathKey): OutputWizard[T] =
+ new Emit[T](PathDrain.Specific(List(newKey)))
+ }
+
+ private object Emit {
+ def apply[T <: AnyRef](): Emit[T] = new Emit[T](PathDrain.trivial)
+ }
+
+ // Composite wizard works within "or" basis, which means,
+ // when one of the sub-wizards yield "continue",
+ // then itself yields continue.
+ private class Union[T <: AnyRef] private (wizards: Seq[OutputWizard[T]]) extends OutputWizard[T] {
+ import Union._
+ assert(wizards.nonEmpty)
+
+ private def act(actions: Seq[OutputAction[T]]): OutputAction[T] = {
+ val drainBuffer = mutable.ListBuffer[PathDrain]()
+ val newWizardBuffer = mutable.ListBuffer[OutputWizard[T]]()
+
+ val state: State = actions.foldLeft[State](ContinueNotFound) {
+ case (_, OutputAction.Continue(drain, newWizard)) =>
+ drainBuffer += drain
+ newWizardBuffer += newWizard
+ ContinueFound
+ case (s, OutputAction.Stop(drain)) =>
+ drainBuffer += drain
+ s
+ }
+
+ val newWizards = newWizardBuffer
+ val newDrain = PathDrain.Specific(drainBuffer.flatMap(_.keysUnsafe()))
+ state match {
+ // All sub-wizards stopped.
+ case ContinueNotFound => OutputAction.Stop(newDrain)
+ // At least one sub-wizard continues.
+ case ContinueFound => OutputAction.Continue(newDrain, new Union(newWizards))
+ }
+ }
+
+ override def visit(can: CanonicalNode[T]): OutputAction[T] = {
+ val actions = wizards
+ .map(_.visit(can))
+ act(actions)
+ }
+
+ override def advance(group: GroupNode[T], offset: Int, count: Int): OutputAction[T] = {
+ val actions = wizards
+ .map(_.advance(group, offset, count))
+ act(actions)
+ }
+
+ override def withPathKey(newKey: PathKey): OutputWizard[T] =
+ new Union[T](wizards.map(w => w.withPathKey(newKey)))
+ }
+
+ private object Union {
+ def apply[T <: AnyRef](wizards: Seq[OutputWizard[T]]): Union[T] = {
+ new Union(wizards)
+ }
+
+ sealed private trait State
+ private case object ContinueNotFound extends State
+ private case object ContinueFound extends State
+ }
+
+ // Prunes paths using the path mask.
+ //
+ // Example:
+ //
+ // The Tree:
+ //
+ // A
+ // |- B
+ // |- C
+ // |- D
+ // \- E
+ // \- F
+ //
+ // Mask 1:
+ // [3, 0, 0, 0]
+ //
+ // Mask 1 output:
+ //
+ // A
+ // |- B
+ // |- C
+ // \- F
+ //
+ // Mask 2:
+ // [3, 0, 2, 0, 0, 0]
+ //
+ // Mask 2 output:
+ //
+ // A
+ // |- B
+ // |- C
+ // |- D
+ // \- E
+ // \- F
+ private class WithMask[T <: AnyRef] private (drain: PathDrain, mask: PathMask, ele: Int)
+ extends OutputWizard[T] {
+
+ override def visit(can: CanonicalNode[T]): OutputAction[T] = {
+ if (can.isLeaf()) {
+ return OutputAction.Stop(drain)
+ }
+ OutputAction.Continue(PathDrain.none, this)
+ }
+
+ override def advance(group: GroupNode[T], offset: Int, count: Int): OutputAction[T] = {
+ var skipCursor = ele + 1
+ (0 until offset).foreach(_ => skipCursor = mask.skip(skipCursor))
+ if (mask.isAny(skipCursor)) {
+ return OutputAction.Stop(drain)
+ }
+ OutputAction.Continue(PathDrain.none, new WithMask[T](drain, mask, skipCursor))
+ }
+
+ override def withPathKey(newKey: PathKey): OutputWizard[T] =
+ new WithMask[T](PathDrain.Specific(List(newKey)), mask, ele)
+ }
+
+ private object WithMask {
+ def apply[T <: AnyRef](mask: PathMask, cursor: Int): WithMask[T] = {
+ new WithMask(PathDrain.Specific(List(PathKey.random())), mask, cursor)
+ }
+ }
+
+ // TODO: Document
+ private class WithPattern[T <: AnyRef] private (
+ drain: PathDrain,
+ pattern: Pattern[T],
+ pNode: Pattern.Node[T])
+ extends OutputWizard[T] {
+
+ override def visit(can: CanonicalNode[T]): OutputAction[T] = {
+ // Skipping should have been handled in #advance.
+ assert(!pNode.skip())
+ if (pNode.abort(can)) {
+ return OutputAction.stop
+ }
+ if (!pNode.matches(can)) {
+ return OutputAction.stop
+ }
+ if (can.isLeaf()) {
+ return OutputAction.Stop(drain)
+ }
+ OutputAction.Continue(PathDrain.none, this)
+ }
+
+ override def advance(group: GroupNode[T], offset: Int, count: Int): OutputAction[T] = {
+ // Skipping is handled here in #advance; see the assertion in #visit.
+ val child = pNode.children(count)(offset)
+ if (child.skip()) {
+ return OutputAction.Stop(drain)
+ }
+ OutputAction.Continue(PathDrain.none, new WithPattern(drain, pattern, child))
+ }
+
+ override def withPathKey(newKey: PathKey): OutputWizard[T] =
+ new WithPattern[T](PathDrain.Specific(List(newKey)), pattern, pNode)
+ }
+
+ private object WithPattern {
+ def apply[T <: AnyRef](pattern: Pattern[T]): WithPattern[T] = {
+ new WithPattern(PathDrain.Specific(List(PathKey.random())), pattern, pattern.root())
+ }
+ }
+
+ // "Depth" is similar to path's "height" but it mainly describes about the
+ // distance between pathfinder from the root node.
+ private class WithMaxDepth[T <: AnyRef] private (drain: PathDrain, depth: Int, currentDepth: Int)
+ extends OutputWizard[T] {
+
+ override def visit(can: CanonicalNode[T]): OutputAction[T] = {
+ assert(
+ currentDepth <= depth,
+ "Current depth already larger than the maximum depth to prune. " +
+ "It probably because a zero depth was specified for path finding."
+ )
+ if (can.isLeaf()) {
+ return OutputAction.Stop(drain)
+ }
+ OutputAction.Continue(PathDrain.none, this)
+ }
+
+ override def advance(group: GroupNode[T], offset: Int, count: Int): OutputAction[T] = {
+ assert(currentDepth <= depth)
+ val nextDepth = currentDepth + 1
+ if (nextDepth > depth) {
+ return OutputAction.Stop(drain)
+ }
+ OutputAction.Continue(PathDrain.none, new WithMaxDepth(drain, depth, nextDepth))
+ }
+
+ override def withPathKey(newKey: PathKey): OutputWizard[T] =
+ new WithMaxDepth[T](PathDrain.Specific(List(newKey)), depth, currentDepth)
+ }
+
+ private object WithMaxDepth {
+ def apply[T <: AnyRef](depth: Int): WithMaxDepth[T] = {
+ new WithMaxDepth(PathDrain.Specific(List(PathKey.random())), depth, 1)
+ }
+ }
+}
diff --git a/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/path/PathFinder.scala b/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/path/PathFinder.scala
new file mode 100644
index 000000000000..86494dacfc92
--- /dev/null
+++ b/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/path/PathFinder.scala
@@ -0,0 +1,191 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.glutenproject.cbo.path
+
+import io.glutenproject.cbo.{CanonicalNode, Cbo, GroupNode}
+import io.glutenproject.cbo.memo.MemoStore
+
+import scala.collection.mutable
+
+trait PathFinder[T <: AnyRef] {
+ def find(base: CanonicalNode[T]): Iterable[CboPath[T]]
+ def find(base: CboPath[T]): Iterable[CboPath[T]]
+}
+
+object PathFinder {
+ def apply[T <: AnyRef](cbo: Cbo[T], memoStore: MemoStore[T]): PathFinder[T] = {
+ builder(cbo, memoStore).build()
+ }
+
+ def builder[T <: AnyRef](cbo: Cbo[T], memoStore: MemoStore[T]): Builder[T] = {
+ Builder[T](cbo, memoStore)
+ }
+
+ class Builder[T <: AnyRef] private (cbo: Cbo[T], memoStore: MemoStore[T]) {
+ private val filterWizards = mutable.ListBuffer[FilterWizard[T]](FilterWizards.omitCycles())
+ private val outputWizards = mutable.ListBuffer[OutputWizard[T]]()
+
+ def depth(depth: Int): Builder[T] = {
+ outputWizards += OutputWizards.withMaxDepth(depth)
+ this
+ }
+
+ def filter(wizard: FilterWizard[T]): Builder[T] = {
+ filterWizards += wizard
+ this
+ }
+
+ def output(wizard: OutputWizard[T]): Builder[T] = {
+ outputWizards += wizard
+ this
+ }
+
+ def build(): PathFinder[T] = {
+ if (outputWizards.isEmpty) {
+ outputWizards += OutputWizards.emit()
+ }
+ val allOutputs = OutputWizards.union(outputWizards)
+ val wizard = filterWizards.foldLeft(allOutputs) {
+ (outputWizard, filterWizard) => outputWizard.filterBy(filterWizard)
+ }
+ PathEnumerator(cbo, memoStore, wizard)
+ }
+ }
+
+ private object Builder {
+ def apply[T <: AnyRef](cbo: Cbo[T], memoStore: MemoStore[T]): Builder[T] = {
+ new Builder(cbo, memoStore)
+ }
+ }
+
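+ // Illustrative usage sketch (assuming `cbo`, `memoStore` and a canonical node
+ // `can` are at hand):
+ //
+ //   val paths: Iterable[CboPath[T]] = PathFinder
+ //     .builder(cbo, memoStore)
+ //     .depth(2)
+ //     .build()
+ //     .find(can)
+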
+ // Recursively uses the children's enumerated paths to enumerate the paths of the
+ // current node. This works bottom-up to assemble all possible paths.
+ private class PathEnumerator[T <: AnyRef] private (
+ cbo: Cbo[T],
+ memoStore: MemoStore[T],
+ wizard: OutputWizard[T])
+ extends PathFinder[T] {
+
+ override def find(canonical: CanonicalNode[T]): Iterable[CboPath[T]] = {
+ val all =
+ wizard.prepareForNode(cbo, memoStore.asGroupSupplier(), canonical).visit().onContinue {
+ newWizard => enumerateFromNode(canonical, newWizard)
+ }
+ all
+ }
+
+ override def find(base: CboPath[T]): Iterable[CboPath[T]] = {
+ val can = base.node().self().asCanonical()
+ val all = wizard.prepareForNode(cbo, memoStore.asGroupSupplier(), can).visit().onContinue {
+ newWizard => diveFromNode(base.height(), base.node(), newWizard)
+ }
+ all
+ }
+
+ private def enumerateFromGroup(
+ group: GroupNode[T],
+ wizard: OutputWizard[T]): Iterable[CboPath[T]] = {
+ group
+ .group(memoStore.asGroupSupplier())
+ .nodes(memoStore)
+ .flatMap(
+ can => {
+ wizard.prepareForNode(cbo, memoStore.asGroupSupplier(), can).visit().onContinue {
+ newWizard => enumerateFromNode(can, newWizard)
+ }
+ })
+ }
+
+ private def enumerateFromNode(
+ canonical: CanonicalNode[T],
+ wizard: OutputWizard[T]): Iterable[CboPath[T]] = {
+ val childrenGroups = canonical.getChildrenGroups(memoStore.asGroupSupplier())
+ if (childrenGroups.isEmpty) {
+ // It's a canonical leaf node.
+ return List.empty
+ }
+ // It's a canonical branch node.
+ val expandedChildren: Seq[Iterable[CboPath[T]]] =
+ childrenGroups.zipWithIndex.map {
+ case (childGroup, index) =>
+ wizard
+ .prepareForGroup(cbo, childGroup, index, childrenGroups.size)
+ .advance()
+ .onContinue(newWizard => enumerateFromGroup(childGroup, newWizard))
+ }
+ CboPath.cartesianProduct(cbo, canonical, expandedChildren)
+ }
+
+ private def diveFromGroup(
+ depth: Int,
+ gpn: GroupedPathNode[T],
+ wizard: OutputWizard[T]): Iterable[CboPath[T]] = {
+ assert(depth >= 0)
+ if (depth == 0) {
+ assert(gpn.node.self().isGroup)
+ return enumerateFromGroup(gpn.group, wizard)
+ }
+
+ assert(gpn.node.self().isCanonical)
+ val canonical = gpn.node.self().asCanonical()
+
+ wizard.prepareForNode(cbo, memoStore.asGroupSupplier(), canonical).visit().onContinue {
+ newWizard => diveFromNode(depth, gpn.node, newWizard)
+ }
+ }
+
+ private def diveFromNode(
+ depth: Int,
+ node: CboPath.PathNode[T],
+ wizard: OutputWizard[T]): Iterable[CboPath[T]] = {
+ assert(depth >= 1)
+ assert(node.self().isCanonical)
+ val canonical = node.self().asCanonical()
+ val children = node.children()
+ if (children.isEmpty) {
+ // It's a canonical leaf node.
+ return List.empty
+ }
+
+ val childrenGroups = canonical.getChildrenGroups(memoStore.asGroupSupplier())
+ CboPath.cartesianProduct(
+ cbo,
+ canonical,
+ children.zip(childrenGroups).zipWithIndex.map {
+ case ((child, childGroup), index) =>
+ wizard
+ .prepareForGroup(cbo, childGroup, index, childrenGroups.size)
+ .advance()
+ .onContinue {
+ newWizard => diveFromGroup(depth - 1, GroupedPathNode(childGroup, child), newWizard)
+ }
+ }
+ )
+ }
+ }
+
+ private object PathEnumerator {
+ def apply[T <: AnyRef](
+ cbo: Cbo[T],
+ memoStore: MemoStore[T],
+ wizard: OutputWizard[T]): PathEnumerator[T] = {
+ new PathEnumerator(cbo, memoStore, wizard)
+ }
+ }
+
+ private case class GroupedPathNode[T <: AnyRef](group: GroupNode[T], node: CboPath.PathNode[T])
+}
diff --git a/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/path/PathKey.scala b/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/path/PathKey.scala
new file mode 100644
index 000000000000..bde983b3acf3
--- /dev/null
+++ b/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/path/PathKey.scala
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.glutenproject.cbo.path
+
+import java.util.concurrent.atomic.AtomicLong
+
+ // A path key is used to identify the corresponding children and parent nodes
+ // during path finding. One path can have multiple path keys tagged to it,
+ // which makes it possible to fetch all paths of interest from different
+ // wizards within a single path-finding request.
+trait PathKey {}
+
+object PathKey {
+ case object Trivial extends PathKey
+
+ def random(): PathKey = RandomKey()
+
+ private case class RandomKey private (id: Long) extends PathKey
+
+ private object RandomKey {
+ private val nextId = new AtomicLong(0)
+ def apply(): PathKey = {
+ RandomKey(nextId.getAndIncrement())
+ }
+ }
+}
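+
+ // Illustrative sketch: a wizard can be tagged with a fresh key so that the
+ // paths it drains remain distinguishable from other wizards' paths. Assuming
+ // some `wizard: OutputWizard[T]` (see OutputWizard.scala):
+ //
+ //   val key = PathKey.random()
+ //   val tagged = wizard.withPathKey(key)
+ //   // Paths drained by `tagged` then carry `key` in their `keys()` set.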
+
+sealed trait PathKeySet {
+ def keys(): Set[PathKey]
+}
+
+object PathKeySet {
+ private val TRIVIAL = PathKeySetImpl(Set(PathKey.Trivial))
+
+ def apply(keys: Set[PathKey]): PathKeySet = {
+ PathKeySetImpl(keys)
+ }
+
+ def trivial: PathKeySet = {
+ TRIVIAL
+ }
+
+ private case class PathKeySetImpl(override val keys: Set[PathKey]) extends PathKeySet {
+ assert(keys.nonEmpty, "Path should at least have one key")
+ }
+}
diff --git a/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/path/PathMask.scala b/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/path/PathMask.scala
new file mode 100644
index 000000000000..cff370d2724d
--- /dev/null
+++ b/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/path/PathMask.scala
@@ -0,0 +1,192 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.glutenproject.cbo.path
+
+import scala.collection.mutable
+
+// Mask is an integer array (pre-order DFS).
+//
+ // FIXME: This is not currently in use. Use pattern instead.
+ // We may consider opening up some APIs based on this in case pattern-match's
+ // performance doesn't meet expectations.
+trait PathMask {
+ import PathMask._
+ def get(index: Int): Digit
+ def all(): Seq[Digit]
+}
+
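+ // Illustrative encoding sketch: each digit is a node's children count in
+ // pre-order, or ANY (-1) for a wildcard subtree. For example, the tree
+ //
+ //   A
+ //   |- B
+ //   |- C
+ //   |  |- D
+ //   |  \- E
+ //   \- F
+ //
+ // is encoded as PathMask(Seq(3, 0, 2, 0, 0, 0)).
+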
+object PathMask {
+ type Digit = Int
+ val ANY: Int = -1
+
+ def apply(mask: Seq[Digit]): PathMask = {
+ // Validate the mask.
+ validate(mask)
+ PathMaskImpl(mask)
+ }
+
+ private def validate(mask: Seq[Digit]): Unit = {
+ // FIXME: This is a rough validation.
+ val m = mask
+ assert(m.forall(digit => digit == ANY || digit >= 0))
+ assert(m.size == m.map(_.max(0)).sum + 1)
+ }
+
+ def union(masks: Seq[PathMask]): PathMask = {
+ unionUnsafe(masks).get
+ }
+
+ // Unions the masks. Any path that satisfies one of the
+ // input masks would satisfy the output mask too.
+ //
+ // Returns None if the masks are not union-able.
+ private def unionUnsafe(masks: Seq[PathMask]): Option[PathMask] = {
+ assert(masks.nonEmpty)
+ val out = masks.reduce[PathMask] {
+ case (left: PathMask, right: PathMask) =>
+ val buffer = mutable.ArrayBuffer[Digit]()
+
+ def dfs(depth: Int, lCursor: Int, rCursor: Int): Boolean = {
+ // lcc: left children count
+ // rcc: right children count
+ val lcc = left.get(lCursor)
+ val rcc = right.get(rCursor)
+
+ if (lcc == ANY || rcc == ANY) {
+ buffer += ANY
+ return true
+ }
+
+ if (lcc != rcc) {
+ return false
+ }
+
+ // cc: children count
+ val cc = lcc
+ buffer += cc
+
+ var lChildCursor = lCursor + 1
+ var rChildCursor = rCursor + 1
+ (0 until cc).foreach {
+ _ =>
+ if (!dfs(depth + 1, lChildCursor, rChildCursor)) {
+ return false
+ }
+ lChildCursor = left.skip(lChildCursor)
+ rChildCursor = right.skip(rChildCursor)
+ }
+ true
+ }
+
+ if (!dfs(0, 0, 0)) {
+ return None
+ }
+
+ PathMask(buffer)
+ }
+
+ Some(out)
+ }
+
+ private case class PathMaskImpl(mask: Seq[Digit]) extends PathMask {
+ override def get(index: Int): Digit = mask(index)
+
+ override def all(): Seq[Digit] = mask
+ }
+
+ implicit class PathMaskImplicits(mask: PathMask) {
+ def isAny(index: Int): Boolean = {
+ mask.get(index) == ANY
+ }
+
+ // Takes the index of a node element and returns the index of its next
+ // sibling node element. If no siblings remain, recursively returns the
+ // parent's next sibling's index in pre-order.
+ def skip(ele: Int): Int = {
+ val childrenCount = mask.get(ele)
+ if (childrenCount == ANY) {
+ return ele + 1
+ }
+ var accumulatedSkips = childrenCount + 1
+ var skipCursor = ele
+
+ def loop(): Unit = {
+ while (true) {
+ skipCursor += 1
+ accumulatedSkips -= 1
+ if (accumulatedSkips == 0) {
+ return
+ }
+ accumulatedSkips += (if (mask.isAny(skipCursor)) 0 else mask.get(skipCursor))
+ }
+ }
+
+ loop()
+ assert(accumulatedSkips == 0)
+ skipCursor
+ }
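+
+ // Illustrative sketch: for mask [3, 0, 2, 0, 0, 0] (the tree A(B, C(D, E), F)
+ // above), skip(1) == 2 (B's next sibling C) and skip(2) == 5 (C's next
+ // sibling F).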
+
+ // Truncates the mask to a fixed maximum depth.
+ // Nodes deeper than the depth are normalized into their
+ // ancestor at that depth, whose children count becomes ANY.
+ def fold(maxDepth: Int): PathMask = {
+ val buffer = mutable.ArrayBuffer[Digit]()
+
+ def dfs(depth: Int, cursor: Int): Unit = {
+ if (depth == maxDepth) {
+ buffer += ANY
+ return
+ }
+ assert(depth < maxDepth)
+ if (mask.isAny(cursor)) {
+ buffer += ANY
+ return
+ }
+ val childrenCount = mask.get(cursor)
+ buffer += childrenCount
+ var childCursor = cursor + 1
+ (0 until childrenCount).foreach {
+ _ =>
+ dfs(depth + 1, childCursor)
+ childCursor = mask.skip(childCursor)
+ }
+ }
+
+ dfs(0, 0)
+
+ PathMask(buffer)
+ }
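+
+ // Illustrative sketch: folding mask [3, 0, 2, 0, 0, 0] with maxDepth = 1
+ // yields [3, ANY, ANY, ANY]: the root keeps its 3 children while each child
+ // subtree is normalized into an ANY wildcard.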
+
+ // Returns the sub-mask whose root is the node at the given index
+ // of this mask.
+ def subMaskAt(index: Int): PathMask = {
+ PathMask(mask.all().slice(index, mask.skip(index)))
+ }
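+
+ // Illustrative sketch: subMaskAt(2) of mask [3, 0, 2, 0, 0, 0] slices out
+ // [2, 0, 0], i.e. the subtree rooted at C in the example above.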
+
+ // Tests whether this mask satisfies another.
+ //
+ // 'Satisfy' here means that if a path is not omitted (though it may be pruned)
+ // by this mask, the output of the pruning procedure must not be omitted by the
+ // 'other' mask.
+ def satisfies(other: PathMask): Boolean = {
+ unionUnsafe(List(mask, other)) match {
+ case Some(union) if union == other => true
+ case _ => false
+ }
+ }
+ }
+}
diff --git a/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/path/Pattern.scala b/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/path/Pattern.scala
new file mode 100644
index 000000000000..772ea5dcb101
--- /dev/null
+++ b/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/path/Pattern.scala
@@ -0,0 +1,133 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.glutenproject.cbo.path
+
+import io.glutenproject.cbo.CanonicalNode
+import io.glutenproject.cbo.path.CboPath.PathNode
+
+trait Pattern[T <: AnyRef] {
+ def matches(path: CboPath[T], depth: Int): Boolean
+ def root(): Pattern.Node[T]
+}
+
+object Pattern {
+ trait Matcher[T <: AnyRef] extends (T => Boolean)
+
+ trait Node[T <: AnyRef] {
+ // If #abort returns true, the caller should not call any further methods.
+ // This provides a way to fail the matching fast, before actually jumping
+ // into the #matches call.
+ def skip(): Boolean
+ def abort(node: CanonicalNode[T]): Boolean
+ def matches(node: CanonicalNode[T]): Boolean
+ def children(count: Int): Seq[Node[T]]
+ }
+
+ private case class Any[T <: AnyRef]() extends Node[Null] {
+ override def skip(): Boolean = false
+ override def abort(node: CanonicalNode[Null]): Boolean = false
+ override def matches(node: CanonicalNode[Null]): Boolean = true
+ override def children(count: Int): Seq[Node[Null]] = (0 until count).map(_ => ignore[Null])
+ }
+
+ private object Any {
+ val INSTANCE: Any[Null] = Any[Null]()
+ // Hide the default constructor.
+ private def apply[T <: AnyRef](): Any[T] = new Any()
+ }
+
+ private case class Ignore[T <: AnyRef]() extends Node[Null] {
+ override def skip(): Boolean = true
+ override def abort(node: CanonicalNode[Null]): Boolean = false
+ override def matches(node: CanonicalNode[Null]): Boolean =
+ throw new UnsupportedOperationException()
+ override def children(count: Int): Seq[Node[Null]] = throw new UnsupportedOperationException()
+ }
+
+ private object Ignore {
+ val INSTANCE: Ignore[Null] = Ignore[Null]()
+
+ // Hide the default constructor.
+ private def apply[T <: AnyRef](): Ignore[T] = new Ignore()
+ }
+
+ private case class Branch[T <: AnyRef](matcher: Matcher[T], children: Seq[Node[T]])
+ extends Node[T] {
+ override def skip(): Boolean = false
+ override def abort(node: CanonicalNode[T]): Boolean = node.childrenCount != children.size
+ override def matches(node: CanonicalNode[T]): Boolean = matcher(node.self())
+ override def children(count: Int): Seq[Node[T]] = {
+ assert(count == children.size)
+ children
+ }
+ }
+
+ def any[T <: AnyRef]: Node[T] = Any.INSTANCE.asInstanceOf[Node[T]]
+ def ignore[T <: AnyRef]: Node[T] = Ignore.INSTANCE.asInstanceOf[Node[T]]
+ def node[T <: AnyRef](matcher: Matcher[T], children: Node[T]*): Node[T] =
+ Branch(matcher, children.toSeq)
+ def leaf[T <: AnyRef](matcher: Matcher[T]): Node[T] = Branch(matcher, List.empty)
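+
+ // Illustrative usage sketch (assuming T = String and SAM conversion from
+ // lambdas to Matcher): match a two-child "join" whose left child is a
+ // "scan" leaf, ignoring the right child:
+ //
+ //   val pattern: Pattern[String] = Pattern
+ //     .node[String](_ == "join", Pattern.leaf[String](_ == "scan"), Pattern.ignore[String])
+ //     .build()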
+
+ implicit class NodeImplicits[T <: AnyRef](node: Node[T]) {
+ def build(): Pattern[T] = {
+ PatternImpl(node)
+ }
+ }
+
+ private case class PatternImpl[T <: AnyRef](root: Node[T]) extends Pattern[T] {
+ override def matches(path: CboPath[T], depth: Int): Boolean = {
+ assert(depth >= 1)
+ assert(depth <= path.height())
+ def dfs(remainingDepth: Int, patternN: Node[T], n: PathNode[T]): Boolean = {
+ assert(remainingDepth >= 0)
+ assert(n.self().isCanonical)
+ if (remainingDepth == 0) {
+ return true
+ }
+ val can = n.self().asCanonical()
+ if (patternN.abort(can)) {
+ return false
+ }
+ if (patternN.skip()) {
+ return true
+ }
+ if (!patternN.matches(can)) {
+ return false
+ }
+ // Pattern matches the current node.
+ val nc = n.children()
+ val patternNc = patternN.children(nc.size)
+ assert(
+ patternNc.size == nc.size,
+ "A node in pattern doesn't match the node in input path's children size. " +
+ "This might because the input path is not inferred by this pattern. " +
+ "It's currently not a valid use case by design."
+ )
+ if (
+ patternNc.zip(nc).exists {
+ case (cPatternN, cN) =>
+ !dfs(remainingDepth - 1, cPatternN, cN)
+ }
+ ) {
+ return false
+ }
+ true
+ }
+ dfs(depth, root, path.node())
+ }
+ }
+}
diff --git a/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/property/PropertySet.scala b/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/property/PropertySet.scala
new file mode 100644
index 000000000000..6acb038090ec
--- /dev/null
+++ b/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/property/PropertySet.scala
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.glutenproject.cbo.property
+
+import io.glutenproject.cbo.{Property, PropertyDef}
+
+trait PropertySet[T <: AnyRef] {
+ def get[P <: Property[T]](property: PropertyDef[T, P]): P
+ def getMap: Map[PropertyDef[T, _ <: Property[T]], Property[T]]
+ def satisfies(other: PropertySet[T]): Boolean
+}
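+
+ // Illustrative sketch (with hypothetical properties `hashDist` and `anyDist`
+ // sharing one definition):
+ //
+ //   val built = PropertySet(Seq(hashDist))
+ //   val required = PropertySet(Seq(anyDist))
+ //   built.satisfies(required) // true iff hashDist.satisfies(anyDist)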
+
+object PropertySet {
+ def apply[T <: AnyRef](properties: Seq[Property[T]]): PropertySet[T] = {
+ val map: Map[PropertyDef[T, _ <: Property[T]], Property[T]] =
+ properties.map(p => (p.definition(), p)).toMap
+ assert(map.size == properties.size)
+ ImmutablePropertySet[T](map)
+ }
+
+ def apply[T <: AnyRef](
+ map: Map[PropertyDef[T, _ <: Property[T]], Property[T]]): PropertySet[T] = {
+ ImmutablePropertySet[T](map)
+ }
+
+ implicit class PropertySetImplicits[T <: AnyRef](propSet: PropertySet[T]) {
+ def withProp(property: Property[T]): PropertySet[T] = {
+ val before = propSet.getMap
+ val after = before + (property.definition() -> property)
+ assert(after.size == before.size)
+ ImmutablePropertySet[T](after)
+ }
+ }
+
+ private case class ImmutablePropertySet[T <: AnyRef](
+ map: Map[PropertyDef[T, _ <: Property[T]], Property[T]])
+ extends PropertySet[T] {
+ override def getMap: Map[PropertyDef[T, _ <: Property[T]], Property[T]] = map
+ override def satisfies(other: PropertySet[T]): Boolean = {
+ assert(map.size == other.getMap.size)
+ map.forall {
+ case (propDef, prop) =>
+ prop.satisfies(other.getMap(propDef))
+ }
+ }
+
+ override def get[P <: Property[T]](propDef: PropertyDef[T, P]): P = {
+ assert(map.contains(propDef))
+ map(propDef).asInstanceOf[P]
+ }
+
+ override def toString: String = map.values.toVector.toString()
+ }
+}
diff --git a/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/rule/CboRule.scala b/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/rule/CboRule.scala
new file mode 100644
index 000000000000..0ccd7aee576b
--- /dev/null
+++ b/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/rule/CboRule.scala
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.glutenproject.cbo.rule
+
+trait CboRule[T <: AnyRef] {
+ def shift(node: T): Iterable[T]
+ def shape(): Shape[T]
+}
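+
+ // Illustrative sketch of a rule, over a hypothetical `MyPlan` ADT with a
+ // commutative `Join` node; #shape declares the input height the rule accepts:
+ //
+ //   object JoinCommute extends CboRule[MyPlan] {
+ //     override def shift(node: MyPlan): Iterable[MyPlan] = node match {
+ //       case Join(l, r) => List(Join(r, l))
+ //       case _ => List.empty
+ //     }
+ //     override def shape(): Shape[MyPlan] = Shapes.fixedHeight(1)
+ //   }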
+
+object CboRule {
+ trait Factory[T <: AnyRef] {
+ def create(): Seq[CboRule[T]]
+ }
+
+ object Factory {
+ def reuse[T <: AnyRef](rules: Seq[CboRule[T]]): Factory[T] = new SimpleReuse(rules)
+
+ def none[T <: AnyRef](): Factory[T] = new SimpleReuse[T](List.empty)
+
+ private class SimpleReuse[T <: AnyRef](rules: Seq[CboRule[T]]) extends Factory[T] {
+ override def create(): Seq[CboRule[T]] = rules
+ }
+ }
+
+}
diff --git a/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/rule/EnforcerRule.scala b/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/rule/EnforcerRule.scala
new file mode 100644
index 000000000000..17e7eef38da3
--- /dev/null
+++ b/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/rule/EnforcerRule.scala
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.glutenproject.cbo.rule
+
+import io.glutenproject.cbo.{Cbo, EnforcerRuleFactory, Property, PropertyDef}
+import io.glutenproject.cbo.memo.Closure
+import io.glutenproject.cbo.property.PropertySet
+
+import scala.collection.mutable
+
+trait EnforcerRule[T <: AnyRef] {
+ def shift(node: T): Iterable[T]
+ def shape(): Shape[T]
+ def constraint(): Property[T]
+}
+
+object EnforcerRule {
+ def apply[T <: AnyRef](rule: CboRule[T], constraint: Property[T]): EnforcerRule[T] = {
+ new EnforcerRuleImpl(rule, constraint)
+ }
+
+ def builtin[T <: AnyRef](constraint: Property[T]): EnforcerRule[T] = {
+ new BuiltinEnforcerRule(constraint)
+ }
+
+ private class EnforcerRuleImpl[T <: AnyRef](
+ rule: CboRule[T],
+ override val constraint: Property[T])
+ extends EnforcerRule[T] {
+ override def shift(node: T): Iterable[T] = rule.shift(node)
+ override def shape(): Shape[T] = rule.shape()
+ }
+
+ private class BuiltinEnforcerRule[T <: AnyRef](override val constraint: Property[T])
+ extends EnforcerRule[T] {
+ override def shift(node: T): Iterable[T] = List(node)
+ override def shape(): Shape[T] = Shapes.fixedHeight(1)
+ }
+}
+
+trait EnforcerRuleSet[T <: AnyRef] {
+ def rulesOf(constraintSet: PropertySet[T]): Seq[RuleApplier[T]]
+}
+
+object EnforcerRuleSet {
+ def apply[T <: AnyRef](cbo: Cbo[T], closure: Closure[T]): EnforcerRuleSet[T] = {
+ new EnforcerRuleSetImpl(cbo, closure)
+ }
+
+ private def newEnforcerRuleFactory[T <: AnyRef](
+ cbo: Cbo[T],
+ propertyDef: PropertyDef[T, _ <: Property[T]]): EnforcerRuleFactory[T] = {
+ cbo.propertyModel.newEnforcerRuleFactory(propertyDef)
+ }
+
+ private class EnforcerRuleSetImpl[T <: AnyRef](cbo: Cbo[T], closure: Closure[T])
+ extends EnforcerRuleSet[T] {
+ private val factoryBuffer =
+ mutable.Map[PropertyDef[T, _ <: Property[T]], EnforcerRuleFactory[T]]()
+ private val buffer = mutable.Map[Property[T], Seq[RuleApplier[T]]]()
+
+ override def rulesOf(constraintSet: PropertySet[T]): Seq[RuleApplier[T]] = {
+ constraintSet.getMap.flatMap {
+ case (constraintDef, constraint) =>
+ buffer.getOrElseUpdate(
+ constraint, {
+ val factory =
+ factoryBuffer.getOrElseUpdate(
+ constraintDef,
+ newEnforcerRuleFactory(cbo, constraintDef))
+ RuleApplier(cbo, closure, EnforcerRule.builtin(constraint)) +: factory
+ .newEnforcerRules(constraint)
+ .map(rule => RuleApplier(cbo, closure, EnforcerRule(rule, constraint)))
+ }
+ )
+ }.toSeq
+ }
+ }
+}
diff --git a/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/rule/RuleApplier.scala b/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/rule/RuleApplier.scala
new file mode 100644
index 000000000000..4df1deac41a7
--- /dev/null
+++ b/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/rule/RuleApplier.scala
@@ -0,0 +1,122 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.glutenproject.cbo.rule
+
+import io.glutenproject.cbo._
+import io.glutenproject.cbo.memo.Closure
+import io.glutenproject.cbo.path.CboPath
+import io.glutenproject.cbo.util.CanonicalNodeMap
+
+import scala.collection.mutable
+
+trait RuleApplier[T <: AnyRef] {
+ def apply(path: CboPath[T]): Unit
+ def shape(): Shape[T]
+}
+
+object RuleApplier {
+ def apply[T <: AnyRef](cbo: Cbo[T], closure: Closure[T], rule: CboRule[T]): RuleApplier[T] = {
+ new ShapeAwareRuleApplier[T](cbo, new RegularRuleApplier(cbo, closure, rule))
+ }
+
+ def apply[T <: AnyRef](
+ cbo: Cbo[T],
+ closure: Closure[T],
+ rule: EnforcerRule[T]): RuleApplier[T] = {
+ new ShapeAwareRuleApplier[T](cbo, new EnforcerRuleApplier[T](cbo, closure, rule))
+ }
+
+ private class RegularRuleApplier[T <: AnyRef](cbo: Cbo[T], closure: Closure[T], rule: CboRule[T])
+ extends RuleApplier[T] {
+ private val cache = new CanonicalNodeMap[T, mutable.Set[T]](cbo)
+
+ override def apply(path: CboPath[T]): Unit = {
+ val can = path.node().self().asCanonical()
+ val plan = path.plan()
+ val appliedPlans = cache.getOrElseUpdate(can, mutable.Set())
+ if (appliedPlans.contains(plan)) {
+ return
+ }
+ apply0(can, plan)
+ appliedPlans += plan
+ }
+
+ private def apply0(can: CanonicalNode[T], plan: T): Unit = {
+ val equivalents = rule.shift(plan)
+ equivalents.foreach {
+ equiv =>
+ closure
+ .openFor(can)
+ .memorize(equiv, cbo.propertySetFactory().get(equiv))
+ }
+ }
+
+ override def shape(): Shape[T] = rule.shape()
+ }
+
+ private class EnforcerRuleApplier[T <: AnyRef](
+ cbo: Cbo[T],
+ closure: Closure[T],
+ rule: EnforcerRule[T])
+ extends RuleApplier[T] {
+ private val cache = new CanonicalNodeMap[T, mutable.Set[T]](cbo)
+ private val constraint = rule.constraint()
+ private val constraintDef = constraint.definition()
+
+ override def apply(path: CboPath[T]): Unit = {
+ val can = path.node().self().asCanonical()
+ if (can.propSet().get(constraintDef).satisfies(constraint)) {
+ return
+ }
+ val plan = path.plan()
+ val appliedPlans = cache.getOrElseUpdate(can, mutable.Set())
+ if (appliedPlans.contains(plan)) {
+ return
+ }
+ apply0(can, plan)
+ appliedPlans += plan
+ }
+
+ private def apply0(can: CanonicalNode[T], plan: T): Unit = {
+ val propSet = cbo.propertySetFactory().get(plan)
+ val constraintSet = propSet.withProp(constraint)
+ val equivalents = rule.shift(plan)
+ equivalents.foreach {
+ equiv =>
+ closure
+ .openFor(can)
+ .memorize(equiv, constraintSet)
+ }
+ }
+
+ override def shape(): Shape[T] = rule.shape()
+ }
+
+ private class ShapeAwareRuleApplier[T <: AnyRef](cbo: Cbo[T], rule: RuleApplier[T])
+ extends RuleApplier[T] {
+ private val ruleShape = rule.shape()
+
+ override def apply(path: CboPath[T]): Unit = {
+ if (!ruleShape.identify(path)) {
+ return
+ }
+ rule.apply(path)
+ }
+
+ override def shape(): Shape[T] = ruleShape
+ }
+}
diff --git a/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/rule/Shape.scala b/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/rule/Shape.scala
new file mode 100644
index 000000000000..fc1bd3118159
--- /dev/null
+++ b/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/rule/Shape.scala
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.glutenproject.cbo.rule
+
+import io.glutenproject.cbo.path.{CboPath, OutputWizard, OutputWizards}
+
+// Shape is an abstraction for all inputs the rule can accept.
+// Shape can be specification on pattern, height, or mask
+// to represent fuzzy, or precise structure of acceptable inputs.
+trait Shape[T <: AnyRef] {
+ def wizard(): OutputWizard[T]
+ def identify(path: CboPath[T]): Boolean
+}
+
+object Shape {}
+
+object Shapes {
+ def fixedHeight[T <: AnyRef](height: Int): Shape[T] = {
+ new FixedHeight[T](height)
+ }
+
+ def none[T <: AnyRef](): Shape[T] = {
+ new None()
+ }
+
+ private class FixedHeight[T <: AnyRef](height: Int) extends Shape[T] {
+ override def wizard(): OutputWizard[T] = OutputWizards.withMaxDepth(height)
+ override def identify(path: CboPath[T]): Boolean = path.height() == height
+ }
+
+ private class None[T <: AnyRef]() extends Shape[T] {
+ override def wizard(): OutputWizard[T] = OutputWizards.none()
+ override def identify(path: CboPath[T]): Boolean = false
+ }
+}
diff --git a/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/util/CycleDetector.scala b/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/util/CycleDetector.scala
new file mode 100644
index 000000000000..add1265d49af
--- /dev/null
+++ b/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/util/CycleDetector.scala
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.glutenproject.cbo.util
+
+trait CycleDetector[T <: Any] {
+ def append(obj: T): CycleDetector[T]
+ def contains(obj: T): Boolean
+}
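+
+ // Illustrative usage sketch: detectors are persistent, so each #append returns
+ // a new detector while the original stays usable.
+ //
+ //   val d0 = CycleDetector[String]((a, b) => a == b)
+ //   val d1 = d0.append("a").append("b")
+ //   d1.contains("a") // true
+ //   d0.contains("a") // false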
+
+object CycleDetector {
+ def apply[T <: Any](equalizer: Equalizer[T]): CycleDetector[T] = {
+ new LinkedCycleDetector[T](equalizer, null.asInstanceOf[T], null)
+ }
+
+ def noop[T <: Any](): CycleDetector[T] = new NoopCycleDetector[T]()
+
+ private case class NoopCycleDetector[T <: Any]() extends CycleDetector[T] {
+ override def append(obj: T): CycleDetector[T] = this
+ override def contains(obj: T): Boolean = false
+ }
+
+ // Immutable, append-only linked list for detecting cycles during path finding.
+ // The code compares elements through the passed equality function.
+ private case class LinkedCycleDetector[T <: Any](
+ equalizer: Equalizer[T],
+ obj: T,
+ last: LinkedCycleDetector[T])
+ extends CycleDetector[T] {
+
+ override def append(obj: T): CycleDetector[T] = {
+ LinkedCycleDetector(equalizer, obj, this)
+ }
+
+ override def contains(obj: T): Boolean = {
+ // Backtrack the linked list to find a cycle.
+ assert(obj != null)
+ var cursor = this
+ while (cursor.obj != null) {
+ if (equalizer(obj, cursor.obj)) {
+ return true
+ }
+ cursor = cursor.last
+ }
+ false
+ }
+ }
+
+ type Equalizer[T <: Any] = (T, T) => Boolean
+}
diff --git a/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/util/IndexDisjointSet.scala b/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/util/IndexDisjointSet.scala
new file mode 100644
index 000000000000..64e83d0dc741
--- /dev/null
+++ b/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/util/IndexDisjointSet.scala
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.glutenproject.cbo.util
+
+import scala.collection.mutable
+
+trait IndexDisjointSet {
+ def grow(): Unit
+ def forward(from: Int, to: Int): Unit
+ def find(ele: Int): Int
+ def setOf(ele: Int): Set[Int]
+ def size(): Int
+}
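+
+ // Illustrative usage sketch: elements are dense indices; #forward merges the
+ // set rooted at `from` into the set rooted at `to`.
+ //
+ //   val set = IndexDisjointSet()
+ //   (0 until 3).foreach(_ => set.grow()) // elements 0, 1, 2
+ //   set.forward(0, 1)
+ //   set.forward(1, 2)
+ //   set.find(0) // 2
+ //   set.setOf(0) // Set(0, 1, 2)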
+
+object IndexDisjointSet {
+ def apply[T <: Any](): IndexDisjointSet = new IndexDisjointSetImpl()
+
+ private class IndexDisjointSetImpl extends IndexDisjointSet {
+ import IndexDisjointSetImpl._
+
+ private val nodeBuffer: mutable.ArrayBuffer[Node] = mutable.ArrayBuffer()
+
+ override def grow(): Unit = nodeBuffer += new Node(nodeBuffer.size)
+
+ override def forward(from: Int, to: Int): Unit = {
+ if (from == to) {
+ // Already in one set.
+ return
+ }
+
+ val fromNode = nodeBuffer(from)
+ val toNode = nodeBuffer(to)
+ assert(fromNode.parent.isEmpty, "Only root element is allowed to forward")
+ assert(toNode.parent.isEmpty, "Only root element is allowed to forward")
+
+ fromNode.parent = Some(to)
+ toNode.children += from
+ }
+
+    private def find0(ele: Int): Node = {
+      checkBound(ele)
+      var cursor = nodeBuffer(ele)
+ while (cursor.parent.nonEmpty) {
+ cursor = nodeBuffer(cursor.parent.get)
+ }
+ cursor
+ }
+
+ override def find(ele: Int): Int = {
+ find0(ele).index
+ }
+
+ override def setOf(ele: Int): Set[Int] = {
+ val rootNode = find0(ele)
+ val buffer = mutable.ListBuffer[Int]()
+ dfsAdd(rootNode, buffer)
+ buffer.toSet
+ }
+
+ private def dfsAdd(node: Node, buffer: mutable.ListBuffer[Int]): Unit = {
+ buffer += node.index
+ node.children.foreach(child => dfsAdd(nodeBuffer(child), buffer))
+ }
+
+ override def size(): Int = {
+ nodeBuffer.size
+ }
+
+    private def checkBound(ele: Int): Unit = {
+ assert(ele < nodeBuffer.size, "Grow the disjoint set first")
+ }
+ }
+
+ private object IndexDisjointSetImpl {
+ private class Node(val index: Int) {
+ var parent: Option[Int] = None
+ val children: mutable.ListBuffer[Int] = mutable.ListBuffer()
+ }
+ }
+}
diff --git a/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/util/NodeMap.scala b/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/util/NodeMap.scala
new file mode 100644
index 000000000000..8e8e483e6c7d
--- /dev/null
+++ b/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/util/NodeMap.scala
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.glutenproject.cbo.util
+
+import io.glutenproject.cbo.{CanonicalNode, Cbo}
+
+import scala.collection.mutable
+
+// Wraps an arbitrary plan node as a hash-map key, delegating hashCode / equals
+// to the CBO plan model.
+class NodeKey[T <: AnyRef](cbo: Cbo[T], val node: T) {
+ override def hashCode(): Int = cbo.planModel.hashCode(node)
+
+ override def equals(obj: Any): Boolean = {
+ obj match {
+ case other: NodeKey[T] => cbo.planModel.equals(node, other.node)
+ case _ => false
+ }
+ }
+
+ override def toString(): String = s"NodeKey($node)"
+}
+
+// Map keyed by canonical nodes, compared through the plan model's
+// hashCode / equals.
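+// Example (illustrative), with `cbo: Cbo[T]` and `node: CanonicalNode[T]` in scope:
+//
+//   val visitCounts = new CanonicalNodeMap[T, Int](cbo)
+//   visitCounts.put(node, visitCounts.getOrElseUpdate(node, 0) + 1)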
+class CanonicalNodeMap[T <: AnyRef, V](cbo: Cbo[T]) {
+ private val map: mutable.Map[NodeKey[T], V] = mutable.Map()
+
+ def contains(node: CanonicalNode[T]): Boolean = {
+ map.contains(keyOf(node))
+ }
+
+ def put(node: CanonicalNode[T], value: V): Unit = {
+ map.put(keyOf(node), value)
+ }
+
+ def get(node: CanonicalNode[T]): V = {
+ map(keyOf(node))
+ }
+
+ def getOrElseUpdate(node: CanonicalNode[T], op: => V): V = {
+ map.getOrElseUpdate(keyOf(node), op)
+ }
+
+ private def keyOf(node: CanonicalNode[T]): NodeKey[T] = {
+ new NodeKey(cbo, node.self())
+ }
+}
diff --git a/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/vis/GraphvizVisualizer.scala b/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/vis/GraphvizVisualizer.scala
new file mode 100644
index 000000000000..bdaf70e5da09
--- /dev/null
+++ b/gluten-cbo/common/src/main/scala/io/glutenproject/cbo/vis/GraphvizVisualizer.scala
@@ -0,0 +1,214 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.glutenproject.cbo.vis
+
+import io.glutenproject.cbo._
+import io.glutenproject.cbo.best.BestFinder
+import io.glutenproject.cbo.memo.MemoState
+import io.glutenproject.cbo.path._
+
+import scala.collection.mutable
+
+// Visualizes the planning procedure in Graphviz DOT language.
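+// Example (illustrative), with `cbo`, `memoState` and a root group ID in scope:
+//
+//   val dot = GraphvizVisualizer(cbo, memoState, rootGroupId).format()
+//   println(dot) // render with Graphviz, e.g. `dot -Tpng`, or an online viewer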
+class GraphvizVisualizer[T <: AnyRef](cbo: Cbo[T], memoState: MemoState[T], best: Best[T]) {
+
+ private val allGroups = memoState.allGroups()
+ private val allClusters = memoState.clusterLookup()
+
+ def format(): String = {
+ val rootGroupId = best.rootGroupId()
+ val bestPath = best.path()
+ val winnerNodes = best.winnerNodes()
+ val bestNodes = best.bestNodes()
+ val costs = best.costs()
+ val rootGroup = allGroups(rootGroupId)
+
+ val buf = new StringBuilder()
+ buf.append("digraph G {\n")
+ buf.append(" compound=true;\n")
+
+ object IsBestNode {
+ def unapply(nodeAndGroupToTest: (CanonicalNode[T], CboGroup[T])): Boolean = {
+ bestNodes.contains(InGroupNode(nodeAndGroupToTest._2.id(), nodeAndGroupToTest._1))
+ }
+ }
+
+ object IsWinnerNode {
+ def unapply(nodeAndGroupToTest: (CanonicalNode[T], CboGroup[T])): Boolean = {
+ winnerNodes.contains(InGroupNode(nodeAndGroupToTest._2.id(), nodeAndGroupToTest._1))
+ }
+ }
+
+ val clusterToGroups: mutable.Map[CboClusterKey, mutable.Set[Int]] = mutable.Map()
+
+ allGroups.foreach {
+ group => clusterToGroups.getOrElseUpdate(group.clusterKey(), mutable.Set()).add(group.id())
+ }
+
+ val groupToDotClusterId: mutable.Map[Int, Int] = mutable.Map()
+ var dotClusterId = 0
+ allClusters.foreach {
+ case (clusterKey, cluster) =>
+ buf.append(s" subgraph cluster$dotClusterId {\n")
+ dotClusterId = dotClusterId + 1
+ buf.append(s" label=${'"'}${describeCluster(clusterKey)}${'"'}\n")
+ clusterToGroups(clusterKey).map(allGroups(_)).foreach {
+ group =>
+ buf.append(s" subgraph cluster$dotClusterId {\n")
+ groupToDotClusterId += group.id() -> dotClusterId
+ dotClusterId = dotClusterId + 1
+ buf.append(s" label=${'"'}${describeGroupVerbose(group)}${'"'}\n")
+ group.nodes(memoState).foreach {
+ node =>
+ {
+ buf.append(s" ${'"'}${describeNode(costs, group, node)}${'"'}")
+ (node, group) match {
+ case IsBestNode() =>
+ buf.append(" [style=filled, fillcolor=green] ")
+ case IsWinnerNode() =>
+ buf.append(" [style=filled, fillcolor=grey] ")
+ case _ =>
+ }
+ buf.append("\n")
+ }
+ }
+ buf.append(" }\n")
+ }
+ buf.append(" }\n")
+ }
+
+ allGroups.foreach {
+ group =>
+ group.nodes(memoState).foreach {
+ node =>
+ node.getChildrenGroups(allGroups).map(_.group(allGroups)).foreach {
+ childGroup =>
+ val childGroupNodes = childGroup.nodes(memoState)
+ if (childGroupNodes.nonEmpty) {
+ val randomChild = childGroupNodes.head
+ buf.append(
+ s" ${'"'}${describeNode(costs, group, node)}${'"'} -> " +
+ s"${'"'}${describeNode(costs, childGroup, randomChild)}${'"'} " +
+ s"[lhead=${'"'}cluster${groupToDotClusterId(childGroup.id())}${'"'}]\n")
+ }
+ }
+ }
+ }
+
+ def drawBest(bestNode: CboPath.PathNode[T], bestGroup: CboGroup[T]): Unit = {
+ val canonical = bestNode.self().asCanonical()
+ bestNode
+ .zipChildrenWithGroups(allGroups)
+ .foreach {
+ case (child, childGroup) =>
+ val childCanonical = child.self().asCanonical()
+ buf.append(s" ${'"'}${describeNode(costs, bestGroup, canonical)}${'"'} -> ")
+ buf.append(s" ${'"'}${describeNode(costs, childGroup, childCanonical)}${'"'}")
+ buf.append(s" [penwidth=${'"'}3.0${'"'} color=${'"'}green${'"'}]")
+ buf.append("\n")
+ drawBest(child, childGroup)
+ }
+ }
+
+    // Usually the best path is a valid path that doesn't end with a group leaf.
+    // However, the best path may not have been found for some reason while the
+    // user still needs the graph for debugging, so we loosen the restriction on
+    // the best path here by filtering out the illegal ones.
+ val rootNode = bestPath.cboPath.node()
+ if (rootNode.self().isCanonical) {
+ drawBest(rootNode, rootGroup)
+ }
+
+ buf.append("}\n")
+ buf.toString()
+ }
+
+ private def describeCluster(cluster: CboClusterKey): String = {
+ s"[Cluster $cluster]"
+ }
+
+ private def describeGroup(group: CboGroup[T]): String = {
+ s"[Group ${group.id()}]"
+ }
+
+ private def describeGroupVerbose(group: CboGroup[T]): String = {
+ s"[Group ${group.id()}: ${group.propSet().getMap.values.toIndexedSeq}]"
+ }
+
+ private def describeNode(
+ costs: InGroupNode[T] => Option[Cost],
+ group: CboGroup[T],
+ node: CanonicalNode[T]): String = {
+ s"${describeGroup(group)}[Cost ${costs(InGroupNode(group.id(), node))
+ .map {
+ case c if cbo.isInfCost(c) => ""
+ case other => other
+ }
+ .getOrElse("N/A")}]${cbo.explain.describeNode(node.self())}"
+ }
+}
+
+object GraphvizVisualizer {
+ private class FakeBestFinder[T <: AnyRef](cbo: Cbo[T], allGroups: Int => CboGroup[T])
+ extends BestFinder[T] {
+ import FakeBestFinder._
+ override def bestOf(groupId: Int): Best[T] = {
+ new FakeBest(cbo, allGroups, groupId)
+ }
+ }
+
+ private object FakeBestFinder {
+ private class FakeBest[T <: AnyRef](
+ cbo: Cbo[T],
+ allGroups: Int => CboGroup[T],
+ rootGroupId: Int)
+ extends Best[T] {
+ override def rootGroupId(): Int = {
+ rootGroupId
+ }
+ override def bestNodes(): Set[InGroupNode[T]] = {
+ Set()
+ }
+ override def winnerNodes(): Set[InGroupNode[T]] = {
+ Set()
+ }
+ override def costs(): InGroupNode[T] => Option[Cost] = { _ => None }
+
+ override def path(): Best.KnownCostPath[T] = {
+ Best.KnownCostPath(
+ CboPath.zero(cbo, PathKeySet.trivial, GroupNode(cbo, allGroups(rootGroupId))),
+ cbo.getInfCost())
+ }
+ }
+ }
+
+ def apply[T <: AnyRef](
+ cbo: Cbo[T],
+ memoState: MemoState[T],
+ rootGroupId: Int): GraphvizVisualizer[T] = {
+ val fakeBestFinder = new FakeBestFinder[T](cbo, memoState.allGroups())
+ val fakeBest = fakeBestFinder.bestOf(rootGroupId)
+ new GraphvizVisualizer(cbo, memoState, fakeBest)
+ }
+
+ def apply[T <: AnyRef](
+ cbo: Cbo[T],
+ memoState: MemoState[T],
+ best: Best[T]): GraphvizVisualizer[T] = {
+ new GraphvizVisualizer(cbo, memoState, best)
+ }
+}
diff --git a/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/CboOperationSuite.scala b/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/CboOperationSuite.scala
new file mode 100644
index 000000000000..2da8a1b989a8
--- /dev/null
+++ b/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/CboOperationSuite.scala
@@ -0,0 +1,461 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.glutenproject.cbo
+
+import io.glutenproject.cbo.path.CboPath
+import io.glutenproject.cbo.property.PropertySet
+import io.glutenproject.cbo.rule.{CboRule, Shape, Shapes}
+
+import org.scalatest.funsuite.AnyFunSuite
+
+class CboOperationSuite extends AnyFunSuite {
+ import CboOperationSuite._
+
+ test(s"Rule invocation count - depth 2, 1") {
+ val l2l2 = new LeafToLeaf2()
+ val u2u2 = new UnaryToUnary2()
+
+ object Unary2Unary2ToUnary3 extends CboRule[TestNode] {
+ var invocationCount: Int = 0
+ var effectiveInvocationCount: Int = 0
+ override def shift(node: TestNode): Iterable[TestNode] = {
+ invocationCount += 1
+ node match {
+ case Unary2(cost1, Unary2(cost2, child)) =>
+ effectiveInvocationCount += 1
+ List(Unary3(cost1 + cost2 - 1, child))
+ case other => List.empty
+ }
+ }
+
+ override def shape(): Shape[TestNode] = Shapes.fixedHeight(2)
+ }
+
+ val cbo =
+ Cbo[TestNode](
+ CostModelImpl,
+ PlanModelImpl,
+ PropertyModelImpl,
+ ExplainImpl,
+ CboRule.Factory.reuse(List(l2l2, u2u2, Unary2Unary2ToUnary3)))
+ val plan = Unary(50, Unary2(50, Unary2(50, Unary2(50, Leaf(30)))))
+ val planner = cbo.newPlanner(plan)
+ val optimized = planner.plan()
+
+ assert(Unary2Unary2ToUnary3.invocationCount == 14)
+ assert(Unary2Unary2ToUnary3.effectiveInvocationCount == 3)
+ assert(optimized == Unary3(98, Unary3(99, Leaf2(29))))
+ }
+
+ test("Rule invocation count - depth 2, 2") {
+ val l2l2 = new LeafToLeaf2()
+ val u2u2 = new UnaryToUnary2()
+ val u2u22u3 = new Unary2Unary2ToUnary3()
+
+ val cbo =
+ Cbo[TestNode](
+ CostModelImpl,
+ PlanModelImpl,
+ PropertyModelImpl,
+ ExplainImpl,
+ CboRule.Factory.reuse(List(l2l2, u2u2, u2u22u3)))
+ val plan = Unary(50, Unary2(50, Unary2(50, Unary2(50, Leaf(30)))))
+ val planner = cbo.newPlanner(plan)
+ val optimized = planner.plan()
+
+ assert(l2l2.invocationCount == 10)
+ assert(l2l2.effectiveInvocationCount == 1)
+ assert(u2u2.invocationCount == 10)
+ assert(u2u2.effectiveInvocationCount == 1)
+ assert(u2u22u3.invocationCount == 14)
+ assert(u2u22u3.effectiveInvocationCount == 3)
+ assert(optimized == Unary3(98, Unary3(99, Leaf2(29))))
+ }
+
+ test("Plan manipulation count - depth 1") {
+ val l2l2 = new LeafToLeaf2()
+ val u2u2 = new UnaryToUnary2()
+
+ val planModel = new PlanModelWithStats(PlanModelImpl)
+
+ val cbo =
+ Cbo[TestNode](
+ CostModelImpl,
+ planModel,
+ PropertyModelImpl,
+ ExplainImpl,
+ CboRule.Factory.reuse(List(l2l2, u2u2)))
+ val plan = Unary(50, Leaf(30))
+ val planner = cbo.newPlanner(plan)
+ val optimized = planner.plan()
+ assert(optimized == Unary2(49, Leaf2(29)))
+
+ planModel.assertPlanOpsLte((200, 50, 50, 50))
+
+ val state = planner.newState()
+ val allPaths = state.memoState().collectAllPaths(CboPath.INF_DEPTH).toSeq
+ val distinctPathCount = allPaths.distinct.size
+ val pathCount = allPaths.size
+ assert(distinctPathCount == pathCount)
+ assert(pathCount == 8)
+ }
+
+ test("Plan manipulation count - depth 2") {
+ val l2l2 = new LeafToLeaf2()
+ val u2u2 = new UnaryToUnary2()
+ val u2u22u3 = new Unary2Unary2ToUnary3()
+
+ val planModel = new PlanModelWithStats(PlanModelImpl)
+
+ val cbo =
+ Cbo[TestNode](
+ CostModelImpl,
+ planModel,
+ PropertyModelImpl,
+ ExplainImpl,
+ CboRule.Factory.reuse(List(l2l2, u2u2, u2u22u3)))
+ val plan = Unary(50, Unary2(50, Unary2(50, Unary2(50, Leaf(30)))))
+ val planner = cbo.newPlanner(plan)
+ val optimized = planner.plan()
+ assert(optimized == Unary3(98, Unary3(99, Leaf2(29))))
+
+ planModel.assertPlanOpsLte((800, 300, 300, 200))
+
+ val state = planner.newState()
+ val allPaths = state.memoState().collectAllPaths(CboPath.INF_DEPTH).toSeq
+ val distinctPathCount = allPaths.distinct.size
+ val pathCount = allPaths.size
+ assert(distinctPathCount == pathCount)
+ assert(pathCount == 58)
+ }
+
+ test("Plan manipulation count - depth 5") {
+ val rule = new CboRule[TestNode] {
+ override def shift(node: TestNode): Iterable[TestNode] = node match {
+ case Unary(c1, Unary(c2, Unary(c3, Unary(c4, Unary(c5, child))))) =>
+ List(Unary2(c1, Unary2(c2, Unary2(c3 - 6, Unary2(c4, Unary2(c5, child))))))
+ case other => List.empty
+ }
+
+ override def shape(): Shape[TestNode] = Shapes.fixedHeight(5)
+ }
+
+ val planModel = new PlanModelWithStats(PlanModelImpl)
+
+ val cbo =
+ Cbo[TestNode](
+ CostModelImpl,
+ planModel,
+ PropertyModelImpl,
+ ExplainImpl,
+ CboRule.Factory.reuse(List(new UnaryToUnary2(), new LeafToLeaf2(), rule)))
+ val plan = Unary(
+ 50,
+ Unary(
+ 50,
+ Unary(
+ 50,
+ Unary(50, Unary(50, Unary(50, Unary(50, Unary(50, Unary(50, Unary(50, Leaf(30)))))))))))
+ val planner = cbo.newPlanner(plan)
+ val optimized = planner.plan()
+ assert(
+ optimized == Unary2(
+ 50,
+ Unary2(
+ 50,
+ Unary2(
+ 44,
+ Unary2(
+ 50,
+ Unary2(50, Unary2(50, Unary2(50, Unary2(44, Unary2(50, Unary2(50, Leaf2(29))))))))))))
+
+ planModel.assertPlanOpsLte((20000, 10000, 3000, 3000))
+
+ val state = planner.newState()
+ val allPaths = state.memoState().collectAllPaths(CboPath.INF_DEPTH).toSeq
+ val distinctPathCount = allPaths.distinct.size
+ val pathCount = allPaths.size
+ assert(distinctPathCount == pathCount)
+ assert(pathCount == 10865)
+ }
+
+ test("Cost evaluation count - base") {
+ val costModel = new CostModelWithStats(CostModelImpl)
+
+ val cbo =
+ Cbo[TestNode](
+ costModel,
+ PlanModelImpl,
+ PropertyModelImpl,
+ ExplainImpl,
+ CboRule.Factory.reuse(List(new UnaryToUnary2, new Unary2ToUnary3)))
+ val plan = Unary(
+ 50,
+ Unary(
+ 50,
+ Unary(
+ 50,
+ Unary(50, Unary(50, Unary(50, Unary(50, Unary(50, Unary(50, Unary(50, Leaf(30)))))))))))
+ val planner = cbo.newPlanner(plan)
+ val optimized = planner.plan()
+ assert(
+ optimized == Unary3(
+ 48,
+ Unary3(
+ 48,
+ Unary3(
+ 48,
+ Unary3(
+ 48,
+ Unary3(48, Unary3(48, Unary3(48, Unary3(48, Unary3(48, Unary3(48, Leaf(30))))))))))))
+ assert(costModel.costOfCount == 32) // TODO reduce this for performance
+ assert(costModel.costCompareCount == 20) // TODO reduce this for performance
+ }
+
+ test("Cost evaluation count - max cost") {
+ val costModelPruned = new CostModel[TestNode] {
+
+ override def costOf(node: TestNode): Cost = {
+ node match {
+ case ll: LeafLike =>
+ CostModelImpl.costOf(ll)
+ case ul: UnaryLike if ul.child.isInstanceOf[LeafLike] =>
+ CostModelImpl.costOf(ul)
+ case u @ Unary(_, Unary(_, _)) =>
+ CostModelImpl.costOf(u)
+ case u @ Unary2(_, Unary2(_, _)) =>
+ CostModelImpl.costOf(u)
+ case u @ Unary3(_, Unary3(_, _)) =>
+ CostModelImpl.costOf(u)
+ case _ =>
+          // Return the maximum cost so that any pattern other than the ones
+          // accepted above gets pruned.
+ LongCost(Long.MaxValue)
+ }
+ }
+
+ override def costComparator(): Ordering[Cost] = {
+ CostModelImpl.costComparator()
+ }
+
+ override def makeInfCost(): Cost = CostModelImpl.makeInfCost()
+ }
+
+ val costModel = new CostModelWithStats(costModelPruned)
+
+ val cbo =
+ Cbo[TestNode](
+ costModel,
+ PlanModelImpl,
+ PropertyModelImpl,
+ ExplainImpl,
+ CboRule.Factory.reuse(List(new UnaryToUnary2, new Unary2ToUnary3)))
+ val plan = Unary(
+ 50,
+ Unary(
+ 50,
+ Unary(
+ 50,
+ Unary(50, Unary(50, Unary(50, Unary(50, Unary(50, Unary(50, Unary(50, Leaf(30)))))))))))
+ val planner = cbo.newPlanner(plan)
+ val optimized = planner.plan()
+ assert(
+ optimized == Unary3(
+ 48,
+ Unary3(
+ 48,
+ Unary3(
+ 48,
+ Unary3(
+ 48,
+ Unary3(48, Unary3(48, Unary3(48, Unary3(48, Unary3(48, Unary3(48, Leaf(30))))))))))))
+ assert(costModel.costOfCount == 32) // TODO reduce this for performance
+ assert(costModel.costCompareCount == 20) // TODO reduce this for performance
+ }
+}
+
+object CboOperationSuite extends CboSuiteBase {
+
+ case class Unary(override val selfCost: Long, override val child: TestNode) extends UnaryLike {
+ override def withNewChildren(child: TestNode): UnaryLike = copy(child = child)
+ }
+
+ case class Unary2(override val selfCost: Long, override val child: TestNode) extends UnaryLike {
+ override def withNewChildren(child: TestNode): UnaryLike = copy(child = child)
+ }
+
+ case class Unary3(override val selfCost: Long, override val child: TestNode) extends UnaryLike {
+ override def withNewChildren(child: TestNode): UnaryLike = copy(child = child)
+ }
+
+ case class Leaf(override val selfCost: Long) extends LeafLike {
+ override def makeCopy(): LeafLike = copy()
+ }
+
+ case class Leaf2(override val selfCost: Long) extends LeafLike {
+ override def makeCopy(): LeafLike = copy()
+ }
+
+ class LeafToLeaf2 extends CboRule[TestNode] {
+ var invocationCount: Int = 0
+ var effectiveInvocationCount: Int = 0
+
+ override def shift(node: TestNode): Iterable[TestNode] = {
+ invocationCount += 1
+ node match {
+ case Leaf(cost) =>
+ effectiveInvocationCount += 1
+ List(Leaf2(cost - 1))
+ case other => List.empty
+ }
+ }
+
+ override def shape(): Shape[TestNode] = Shapes.fixedHeight(1)
+ }
+
+ class UnaryToUnary2 extends CboRule[TestNode] {
+ var invocationCount: Int = 0
+ var effectiveInvocationCount: Int = 0
+
+ override def shift(node: TestNode): Iterable[TestNode] = {
+ invocationCount += 1
+ node match {
+ case Unary(cost, child) =>
+ effectiveInvocationCount += 1
+ List(Unary2(cost - 1, child))
+ case other => List.empty
+ }
+ }
+
+ override def shape(): Shape[TestNode] = Shapes.fixedHeight(1)
+ }
+
+ class Unary2ToUnary3 extends CboRule[TestNode] {
+ var invocationCount: Int = 0
+ var effectiveInvocationCount: Int = 0
+
+ override def shift(node: TestNode): Iterable[TestNode] = {
+ invocationCount += 1
+ node match {
+ case Unary2(cost, child) =>
+ effectiveInvocationCount += 1
+ List(Unary3(cost - 1, child))
+ case other => List.empty
+ }
+ }
+
+ override def shape(): Shape[TestNode] = Shapes.fixedHeight(1)
+ }
+
+ class Unary2Unary2ToUnary3 extends CboRule[TestNode] {
+ var invocationCount: Int = 0
+ var effectiveInvocationCount: Int = 0
+
+ override def shift(node: TestNode): Iterable[TestNode] = {
+ invocationCount += 1
+ node match {
+ case Unary2(cost1, Unary2(cost2, child)) =>
+ effectiveInvocationCount += 1
+ List(Unary3(cost1 + cost2 - 1, child))
+ case other => List.empty
+ }
+ }
+
+ override def shape(): Shape[TestNode] = Shapes.fixedHeight(2)
+ }
+
+ class PlanModelWithStats[T <: AnyRef](delegated: PlanModel[T]) extends PlanModel[T] {
+ var childrenOfCount = 0
+ var withNewChildrenCount = 0
+ var hashCodeCount = 0
+ var equalsCount = 0
+ var newGroupLeafCount = 0
+ var isGroupLeafCount = 0
+ var getGroupIdCount = 0
+
+ override def childrenOf(node: T): Seq[T] = {
+ childrenOfCount += 1
+ delegated.childrenOf(node)
+ }
+ override def withNewChildren(node: T, children: Seq[T]): T = {
+ withNewChildrenCount += 1
+ delegated.withNewChildren(node, children)
+ }
+ override def hashCode(node: T): Int = {
+ hashCodeCount += 1
+ delegated.hashCode(node)
+ }
+ override def equals(one: T, other: T): Boolean = {
+ equalsCount += 1
+ delegated.equals(one, other)
+ }
+ override def newGroupLeaf(groupId: Int, propSet: PropertySet[T]): T = {
+ newGroupLeafCount += 1
+ delegated.newGroupLeaf(groupId, propSet)
+ }
+ override def isGroupLeaf(node: T): Boolean = {
+ isGroupLeafCount += 1
+ delegated.isGroupLeaf(node)
+ }
+ override def getGroupId(node: T): Int = {
+ getGroupIdCount += 1
+ delegated.getGroupId(node)
+ }
+ }
+
+ private object PlanModelWithStats {
+ implicit class PlanModelWithStatsImplicits[T <: AnyRef](model: PlanModelWithStats[T]) {
+ def assertPlanOpsLte(bounds: (Int, Int, Int, Int)): Unit = {
+ val actual = (
+ model.childrenOfCount,
+ model.withNewChildrenCount,
+ model.hashCodeCount,
+ model.equalsCount)
+ assert(
+ List(actual._1, actual._2, actual._3, actual._4)
+ .zip(List(bounds._1, bounds._2, bounds._3, bounds._4))
+ .forall {
+ case (count, bound) =>
+ count <= bound
+ },
+ s"Assertion failed. The expected bounds: $bounds, actual: " +
+ s"$actual"
+ )
+ }
+ }
+ }
+
+ class CostModelWithStats[T <: AnyRef](delegated: CostModel[T]) extends CostModel[T] {
+ var costOfCount = 0
+ var costCompareCount = 0
+
+ override def costOf(node: T): Cost = {
+ costOfCount += 1
+ delegated.costOf(node)
+ }
+ override def costComparator(): Ordering[Cost] = {
+ new Ordering[Cost] {
+ override def compare(x: Cost, y: Cost): Int = {
+ costCompareCount += 1
+ delegated.costComparator().compare(x, y)
+ }
+ }
+ }
+
+ override def makeInfCost(): Cost = delegated.makeInfCost()
+ }
+}
diff --git a/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/CboPropertySuite.scala b/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/CboPropertySuite.scala
new file mode 100644
index 000000000000..0225601a253b
--- /dev/null
+++ b/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/CboPropertySuite.scala
@@ -0,0 +1,720 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.glutenproject.cbo
+
+import io.glutenproject.cbo.Best.BestNotFoundException
+import io.glutenproject.cbo.CboConfig.PlannerType
+import io.glutenproject.cbo.property.PropertySet
+import io.glutenproject.cbo.rule.{CboRule, Shape, Shapes}
+
+import org.scalatest.funsuite.AnyFunSuite
+
+class ExhaustivePlannerPropertySuite extends CboPropertySuite {
+ override protected def conf: CboConfig = CboConfig(plannerType = PlannerType.Exhaustive)
+}
+
+class DpPlannerPropertySuite extends CboPropertySuite {
+ override protected def conf: CboConfig = CboConfig(plannerType = PlannerType.Dp)
+}
+
+abstract class CboPropertySuite extends AnyFunSuite {
+ import CboPropertySuite._
+
+ protected def conf: CboConfig
+
+ test(s"Get property") {
+ val leaf = PLeaf(10, DummyProperty(0))
+ val unary = PUnary(5, DummyProperty(0), leaf)
+ val binary = PBinary(5, DummyProperty(0), leaf, unary)
+
+ val model = DummyPropertyModel
+ val propDefs = model.propertyDefs
+
+ assert(propDefs.size === 1)
+ assert(propDefs.head.getProperty(leaf) === DummyProperty(0))
+ assert(propDefs.head.getProperty(unary) === DummyProperty(0))
+ assert(propDefs.head.getProperty(binary) === DummyProperty(0))
+ assert(propDefs.head.getChildrenConstraints(DummyProperty(0), leaf) === Seq.empty)
+ assert(propDefs.head.getChildrenConstraints(DummyProperty(0), unary) === Seq(DummyProperty(0)))
+ assert(propDefs.head
+ .getChildrenConstraints(DummyProperty(0), binary) === Seq(DummyProperty(0), DummyProperty(0)))
+ }
+
+ test(s"Cannot enforce property") {
+ val cbo =
+ Cbo[TestNode](
+ CostModelImpl,
+ PlanModelImpl,
+ NodeTypePropertyModelWithOutEnforcerRules,
+ ExplainImpl,
+ CboRule.Factory.none())
+ .withNewConfig(_ => conf)
+ val plan = TypedUnary(TypeA, 10, TypedLeaf(TypeA, 10))
+ val planner = cbo.newPlanner(plan, PropertySet(Seq(TypeB)))
+ assertThrows[BestNotFoundException] {
+ planner.plan()
+ }
+ }
+
+ test(s"Property enforcement - A to B") {
+ val cbo =
+ Cbo[TestNode](
+ CostModelImpl,
+ PlanModelImpl,
+ NodeTypePropertyModel,
+ ExplainImpl,
+ CboRule.Factory.none())
+ .withNewConfig(_ => conf)
+
+ val plan =
+ TypedBinary(TypeA, 5, TypedUnary(TypeA, 10, TypedLeaf(TypeA, 10)), TypedLeaf(TypeA, 10))
+ val planner = cbo.newPlanner(plan, PropertySet(Seq(TypeB)))
+ val out = planner.plan()
+ assert(out == TypeEnforcer(TypeB, 1, plan))
+ }
+
+ test(s"Property convert - (A, B)") {
+ val cbo =
+ Cbo[TestNode](
+ CostModelImpl,
+ PlanModelImpl,
+ NodeTypePropertyModel,
+ ExplainImpl,
+ CboRule.Factory.reuse(List(ReplaceByTypeARule, ReplaceByTypeBRule)))
+ .withNewConfig(_ => conf)
+ val plan =
+ TypedBinary(TypeA, 5, TypedUnary(TypeA, 10, TypedLeaf(TypeA, 10)), TypedLeaf(TypeA, 10))
+ val planner = cbo.newPlanner(plan, PropertySet(Seq(TypeB)))
+ val out = planner.plan()
+ assert(
+ out == TypedBinary(
+ TypeB,
+ 5,
+ TypedUnary(TypeB, 10, TypedLeaf(TypeB, 10)),
+ TypedLeaf(TypeB, 10)))
+ }
+
+ ignore(s"Memo cache hit - (A, B)") {
+ object ReplaceLeafAByLeafBRule extends CboRule[TestNode] {
+ override def shift(node: TestNode): Iterable[TestNode] = {
+ node match {
+ case TypedLeaf(TypeA, cost) => List(TypedLeaf(TypeB, cost - 1))
+ case other => List.empty
+ }
+ }
+
+ override def shape(): Shape[TestNode] = Shapes.fixedHeight(1)
+ }
+
+ object HitCacheOp extends CboRule[TestNode] {
+ override def shift(node: TestNode): Iterable[TestNode] = {
+ node match {
+ case PassNodeType(10, TypedLeaf(TypeA, 10)) =>
+ List(TypedUnary(TypeB, 15, PassNodeType(10, TypedLeaf(TypeB, 9))))
+ case other => List.empty
+ }
+ }
+ override def shape(): Shape[TestNode] = Shapes.fixedHeight(2)
+ }
+
+ object FinalOp extends CboRule[TestNode] {
+
+ override def shift(node: TestNode): Iterable[TestNode] = {
+ node match {
+ case TypedUnary(TypeB, 15, PassNodeType(10, TypedLeaf(TypeB, 9))) =>
+ List(TypedLeaf(TypeA, 1))
+ case other => List.empty
+ }
+ }
+ override def shape(): Shape[TestNode] = Shapes.fixedHeight(3)
+ }
+
+ val cbo =
+ Cbo[TestNode](
+ CostModelImpl,
+ PlanModelImpl,
+ NodeTypePropertyModelWithOutEnforcerRules,
+ ExplainImpl,
+ CboRule.Factory.reuse(List(ReplaceLeafAByLeafBRule, HitCacheOp, FinalOp))
+ )
+ .withNewConfig(_ => conf)
+
+ val plan = PassNodeType(10, TypedLeaf(TypeA, 10))
+ val planner = cbo.newPlanner(plan, PropertySet(Seq(TypeA)))
+ val out = planner.plan()
+ assert(out == TypedLeaf(TypeA, 1))
+
+    // FIXME: Clusters 2 and 1 are currently able to merge, but it would be better
+    //  to identify them as the same cluster right after HitCacheOp is applied.
+ val clusterCount = planner.newState().memoState().allClusters().size
+ assert(clusterCount == 2)
+ }
+
+ test(s"Property propagation - (A, B)") {
+    // The propagation is expected to be done by the built-in enforcer rule.
+ object ReplaceLeafAByLeafBRule extends CboRule[TestNode] {
+ override def shift(node: TestNode): Iterable[TestNode] = {
+ node match {
+ case TypedLeaf(TypeA, cost) => List(TypedLeaf(TypeB, cost - 1))
+ case other => List.empty
+ }
+ }
+
+ override def shape(): Shape[TestNode] = Shapes.fixedHeight(1)
+ }
+
+ object ReplaceUnaryBByUnaryARule extends CboRule[TestNode] {
+ override def shift(node: TestNode): Iterable[TestNode] = {
+ node match {
+ case TypedUnary(TypeB, cost, child) => List(TypedUnary(TypeA, cost - 2, child))
+ case other => List.empty
+ }
+ }
+
+ override def shape(): Shape[TestNode] = Shapes.fixedHeight(1)
+ }
+
+ val cbo =
+ Cbo[TestNode](
+ CostModelImpl,
+ PlanModelImpl,
+ NodeTypePropertyModelWithOutEnforcerRules,
+ ExplainImpl,
+ CboRule.Factory.reuse(List(ReplaceLeafAByLeafBRule, ReplaceUnaryBByUnaryARule))
+ )
+ .withNewConfig(_ => conf)
+ val sub =
+ PassNodeType(5, TypedLeaf(TypeA, 10))
+ val plan = TypedUnary(TypeB, 10, sub)
+ val planner = cbo.newPlanner(plan, PropertySet(Seq(TypeAny)))
+ val out = planner.plan()
+
+ assert(out == TypedUnary(TypeA, 8, PassNodeType(5, TypedLeaf(TypeA, 10))))
+ }
+
+ test(s"Property convert - (A, B), alternative conventions") {
+ object ConvertEnforcerAndTypeAToTypeB extends CboRule[TestNode] {
+ override def shift(node: TestNode): Iterable[TestNode] = node match {
+ case TypeEnforcer(TypeB, _, TypedBinary(TypeA, 5, left, right)) =>
+ List(TypedBinary(TypeB, 0, left, right))
+ case _ => List.empty
+ }
+ override def shape(): Shape[TestNode] = Shapes.fixedHeight(2)
+ }
+
+ val cbo =
+ Cbo[TestNode](
+ CostModelImpl,
+ PlanModelImpl,
+ NodeTypePropertyModel,
+ ExplainImpl,
+ CboRule.Factory.reuse(List(ConvertEnforcerAndTypeAToTypeB)))
+ .withNewConfig(_ => conf)
+ val plan =
+ TypedBinary(TypeA, 5, TypedUnary(TypeA, 10, TypedLeaf(TypeA, 10)), TypedLeaf(TypeA, 10))
+ val planner = cbo.newPlanner(
+ plan,
+ PropertySet(Seq(TypeAny)),
+ List(PropertySet(Seq(TypeB)), PropertySet(Seq(TypeC))))
+ val out = planner.plan()
+ assert(
+ out == TypedBinary(
+ TypeB,
+ 0,
+ TypeEnforcer(TypeB, 1, TypedUnary(TypeA, 10, TypedLeaf(TypeA, 10))),
+ TypeEnforcer(TypeB, 1, TypedLeaf(TypeA, 10))))
+ assert(planner.newState().memoState().allGroups().size == 9)
+ }
+
+ test(s"Property convert - (A, B), Unary only has TypeA") {
+ object ReplaceNonUnaryByTypeBRule extends CboRule[TestNode] {
+ override def shift(node: TestNode): Iterable[TestNode] = {
+ node match {
+ case TypedLeaf(_, cost) => List(TypedLeaf(TypeB, cost))
+ case TypedBinary(_, cost, left, right) => List(TypedBinary(TypeB, cost, left, right))
+ case other => List.empty
+ }
+ }
+
+ override def shape(): Shape[TestNode] = Shapes.fixedHeight(1)
+ }
+
+ val cbo =
+ Cbo[TestNode](
+ CostModelImpl,
+ PlanModelImpl,
+ NodeTypePropertyModel,
+ ExplainImpl,
+ CboRule.Factory.reuse(List(ReplaceByTypeARule, ReplaceNonUnaryByTypeBRule)))
+ .withNewConfig(_ => conf)
+ val plan =
+ TypedBinary(TypeA, 5, TypedUnary(TypeA, 10, TypedLeaf(TypeA, 10)), TypedLeaf(TypeA, 10))
+ val planner = cbo.newPlanner(plan, PropertySet(Seq(TypeB)))
+ val out = planner.plan()
+ assert(
+ out == TypeEnforcer(
+ TypeB,
+ 1,
+ TypedBinary(
+ TypeA,
+ 5,
+ TypedUnary(TypeA, 10, TypedLeaf(TypeA, 10)),
+ TypedLeaf(TypeA, 10))) || out == TypedBinary(
+ TypeB,
+ 5,
+ TypeEnforcer(TypeB, 1, TypedUnary(TypeA, 10, TypedLeaf(TypeA, 10))),
+ TypedLeaf(TypeB, 10)))
+ }
+
+ test(s"Property convert - (A, B, C), TypeC has lowest cost") {
+ object ReduceTypeBCost extends CboRule[TestNode] {
+ override def shift(node: TestNode): Iterable[TestNode] = {
+ node match {
+ case TypedLeaf(TypeB, _) => List(TypedLeaf(TypeB, 5))
+ case TypedUnary(TypeB, _, child) => List(TypedUnary(TypeB, 5, child))
+ case TypedBinary(TypeB, _, left, right) => List(TypedBinary(TypeB, 5, left, right))
+ case other => List.empty
+ }
+ }
+
+ override def shape(): Shape[TestNode] = Shapes.fixedHeight(1)
+ }
+
+ object ConvertUnaryTypeBToTypeC extends CboRule[TestNode] {
+ override def shift(node: TestNode): Iterable[TestNode] = node match {
+ case TypedUnary(TypeB, _, child) => List(TypedUnary(TypeC, 0, child))
+ case other => List.empty
+ }
+
+ override def shape(): Shape[TestNode] = Shapes.fixedHeight(1)
+ }
+
+ val cbo =
+ Cbo[TestNode](
+ CostModelImpl,
+ PlanModelImpl,
+ NodeTypePropertyModel,
+ ExplainImpl,
+ CboRule.Factory.reuse(List(ReduceTypeBCost, ConvertUnaryTypeBToTypeC)))
+ .withNewConfig(_ => conf)
+
+ val plan =
+ TypedUnary(TypeB, 10, TypedLeaf(TypeA, 20))
+ val planner = cbo.newPlanner(plan, PropertySet(Seq(TypeB)))
+ val out = planner.plan()
+ assert(
+ out == TypeEnforcer(
+ TypeB,
+ 1,
+ TypedUnary(TypeC, 0, TypeEnforcer(TypeC, 1, TypedLeaf(TypeA, 20)))))
+ }
+
+ test(
+ s"Property convert - (A, B, C), TypeC has lowest cost, binary root," +
+ s" right enforcer added after left is explored, disordered group creation") {
+
+ object RightOp extends CboRule[TestNode] {
+      // Let the right child tree add a new constraint group to the leaf's cluster.
+ override def shift(node: TestNode): Iterable[TestNode] = node match {
+ case TypedUnary(TypeA, 15, TypedLeaf(TypeA, 20)) =>
+          // This creates an enforcer at the leaf's cluster.
+ List(TypedUnary(TypeB, 15, TypedLeaf(TypeA, 20)))
+ case TypeEnforcer(TypeB, 1, TypedLeaf(TypeA, 20)) =>
+          // The cost is high, so this won't be chosen by the right tree.
+ List(TypedLeaf(TypeA, 100))
+ case other => List.empty
+ }
+
+ override def shape(): Shape[TestNode] = Shapes.fixedHeight(2)
+ }
+
+ object LeftOp extends CboRule[TestNode] {
+    // The left child tree should be aware of the enforcer and continue exploration.
+ override def shift(node: TestNode): Iterable[TestNode] = node match {
+ case TypedUnary(TypeA, 10, TypedLeaf(TypeA, 100)) =>
+          // The leaf was created by RightOp.
+ List(TypedLeaf(TypeC, 0))
+ case other => List.empty
+ }
+
+ override def shape(): Shape[TestNode] = Shapes.fixedHeight(2)
+ }
+
+ val cbo =
+ Cbo[TestNode](
+ CostModelImpl,
+ PlanModelImpl,
+ NodeTypePropertyModel,
+ ExplainImpl,
+ CboRule.Factory.reuse(List(LeftOp, RightOp))
+ )
+ .withNewConfig(_ => conf)
+
+ val left = TypedUnary(TypeA, 10, TypedLeaf(TypeA, 20))
+ val right = TypeEnforcer(TypeA, 1, TypedUnary(TypeA, 15, TypedLeaf(TypeA, 20)))
+ val plan = TypedBinary(TypeA, 10, left, right)
+ val planner = cbo.newPlanner(plan, PropertySet(Seq(TypeA)))
+ val out = planner.plan()
+ assert(
+ out == TypedBinary(
+ TypeA,
+ 10,
+ TypeEnforcer(TypeA, 1, TypedLeaf(TypeC, 0)),
+ TypeEnforcer(TypeA, 1, TypedUnary(TypeA, 15, TypedLeaf(TypeA, 20)))))
+ }
+
+ test(
+ s"Property convert - (A, B, C), TypeC has lowest cost, binary root," +
+ s" right enforcer added after left is explored, merge") {
+ object ConvertTypeBEnforcerAndLeafToTypeC extends CboRule[TestNode] {
+ override def shift(node: TestNode): Iterable[TestNode] = node match {
+ case TypeEnforcer(TypeB, _, _ @TypedLeaf(_, _)) =>
+ List(TypedLeaf(TypeC, 0))
+ case other => List.empty
+ }
+
+ override def shape(): Shape[TestNode] = Shapes.fixedHeight(2)
+ }
+
+ object ConvertTypeATypeCToTypeC extends CboRule[TestNode] {
+ override def shift(node: TestNode): Iterable[TestNode] = node match {
+ case TypedUnary(TypeA, _, TypeEnforcer(TypeA, _, _ @TypedLeaf(TypeC, _))) =>
+ List(TypedLeaf(TypeC, 0))
+ case other => List.empty
+ }
+
+ override def shape(): Shape[TestNode] = Shapes.fixedHeight(3)
+ }
+
+ val cbo =
+ Cbo[TestNode](
+ CostModelImpl,
+ PlanModelImpl,
+ NodeTypePropertyModel,
+ ExplainImpl,
+ CboRule.Factory.reuse(List(ConvertTypeBEnforcerAndLeafToTypeC, ConvertTypeATypeCToTypeC))
+ )
+ .withNewConfig(_ => conf)
+
+ val left =
+ TypedUnary(TypeA, 10, TypedUnary(TypeA, 10, TypedLeaf(TypeA, 20)))
+ val right =
+ TypedUnary(TypeB, 15, TypedUnary(TypeB, 15, TypedLeaf(TypeA, 20)))
+ val plan = TypedBinary(TypeB, 10, left, right)
+ val planner = cbo.newPlanner(plan, PropertySet(Seq(TypeB)))
+ val out = planner.plan()
+ assert(
+ out == TypedBinary(
+ TypeB,
+ 10,
+ TypeEnforcer(TypeB, 1, TypedLeaf(TypeC, 0)),
+ TypedUnary(TypeB, 15, TypedUnary(TypeB, 15, TypeEnforcer(TypeB, 1, TypedLeaf(TypeC, 0))))))
+ }
+
+ test(s"Property convert - (A, B), Unary only has TypeA, TypeB has lowest cost") {
+ // TODO: Apply enforce rules on low-cost nodes to propagate it to other groups.
+ object ReduceTypeBCost extends CboRule[TestNode] {
+ override def shift(node: TestNode): Iterable[TestNode] = {
+ node match {
+ case TypedLeaf(TypeB, _) => List(TypedLeaf(TypeB, 0))
+ case TypedUnary(TypeB, _, child) => List(TypedUnary(TypeB, 0, child))
+ case TypedBinary(TypeB, _, left, right) => List(TypedBinary(TypeB, 0, left, right))
+ case other => List.empty
+ }
+ }
+
+ override def shape(): Shape[TestNode] = Shapes.fixedHeight(1)
+ }
+
+ object ReplaceNonUnaryByTypeBRule extends CboRule[TestNode] {
+ override def shift(node: TestNode): Iterable[TestNode] = {
+ node match {
+ case TypedLeaf(_, cost) => List(TypedLeaf(TypeB, cost))
+ case TypedBinary(_, cost, left, right) => List(TypedBinary(TypeB, cost, left, right))
+ case other => List.empty
+ }
+ }
+
+ override def shape(): Shape[TestNode] = Shapes.fixedHeight(1)
+ }
+
+ val cbo =
+ Cbo[TestNode](
+ CostModelImpl,
+ PlanModelImpl,
+ NodeTypePropertyModel,
+ ExplainImpl,
+ CboRule.Factory.reuse(List(ReplaceNonUnaryByTypeBRule, ReduceTypeBCost)))
+ .withNewConfig(_ => conf)
+ val plan =
+ TypedBinary(TypeA, 5, TypedUnary(TypeA, 10, TypedLeaf(TypeA, 10)), TypedLeaf(TypeA, 10))
+ val planner = cbo.newPlanner(plan, PropertySet(Seq(TypeB)))
+ val out = planner.plan()
+ assert(
+ out == TypedBinary(
+ TypeB,
+ 0,
+ TypeEnforcer(TypeB, 1, TypedUnary(TypeA, 10, TypeEnforcer(TypeA, 1, TypedLeaf(TypeB, 0)))),
+ TypedLeaf(TypeB, 0)))
+ }
+}
+
+object CboPropertySuite extends CboSuiteBase {
+
+ case class NoopEnforcerRule[T <: AnyRef]() extends CboRule[T] {
+ override def shift(node: T): Iterable[T] = List.empty
+ override def shape(): Shape[T] = Shapes.none()
+ }
+
+ // Dummy property model
+
+ case class DummyProperty(id: Int) extends Property[TestNode] {
+ override def satisfies(other: Property[TestNode]): Boolean = {
+ other match {
+ case DummyProperty(otherId) =>
+ // Higher ID satisfies lower IDs.
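+          // e.g., DummyProperty(1).satisfies(DummyProperty(0)) is true,
+          // while DummyProperty(0).satisfies(DummyProperty(1)) is false.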
+ id >= otherId
+ case _ => throw new IllegalStateException()
+ }
+ }
+
+ override def definition(): PropertyDef[TestNode, DummyProperty] = {
+ DummyPropertyDef
+ }
+ }
+
+ case class PUnary(override val selfCost: Long, prop: DummyProperty, override val child: TestNode)
+ extends UnaryLike {
+ override def withNewChildren(child: TestNode): UnaryLike = copy(child = child)
+ }
+
+ case class PLeaf(override val selfCost: Long, prop: DummyProperty) extends LeafLike {
+ override def makeCopy(): LeafLike = copy()
+ }
+
+ case class PBinary(
+ override val selfCost: Long,
+ prop: DummyProperty,
+ override val left: TestNode,
+ override val right: TestNode)
+ extends BinaryLike {
+ override def withNewChildren(left: TestNode, right: TestNode): BinaryLike =
+ copy(left = left, right = right)
+ }
+
+ object DummyPropertyDef extends PropertyDef[TestNode, DummyProperty] {
+ override def getProperty(plan: TestNode): DummyProperty = {
+ plan match {
+ case Group(_, _) => throw new IllegalStateException()
+ case PUnary(_, prop, _) => prop
+ case PLeaf(_, prop) => prop
+ case PBinary(_, prop, _, _) => prop
+ case _ => DummyProperty(-1)
+ }
+ }
+
+ override def getChildrenConstraints(
+ constraint: Property[TestNode],
+ plan: TestNode): Seq[DummyProperty] = {
+ plan match {
+ case PUnary(_, _, _) => Seq(DummyProperty(0))
+ case PLeaf(_, _) => Seq.empty
+ case PBinary(_, _, _, _) => Seq(DummyProperty(0), DummyProperty(0))
+ case _ => throw new IllegalStateException()
+ }
+ }
+ }
+
+ object DummyPropertyModel extends PropertyModel[TestNode] {
+ override def propertyDefs: Seq[PropertyDef[TestNode, _ <: Property[TestNode]]] = Seq(
+ DummyPropertyDef)
+
+ override def newEnforcerRuleFactory(propertyDef: PropertyDef[TestNode, _ <: Property[TestNode]])
+ : EnforcerRuleFactory[TestNode] = (constraint: Property[TestNode]) => List.empty
+ }
+
+ // Node type property model
+
+ trait TypedNode extends TestNode {
+ def nodeType: NodeType
+ }
+
+ case class TypedLeaf(override val nodeType: NodeType, override val selfCost: Long)
+ extends LeafLike
+ with TypedNode {
+ override def makeCopy(): LeafLike = copy()
+ }
+
+ case class TypedUnary(
+ override val nodeType: NodeType,
+ override val selfCost: Long,
+ override val child: TestNode)
+ extends UnaryLike
+ with TypedNode {
+ override def withNewChildren(child: TestNode): UnaryLike = copy(child = child)
+ }
+
+ case class TypedBinary(
+ override val nodeType: NodeType,
+ override val selfCost: Long,
+ override val left: TestNode,
+ override val right: TestNode)
+ extends BinaryLike
+ with TypedNode {
+ override def withNewChildren(left: TestNode, right: TestNode): BinaryLike =
+ copy(left = left, right = right)
+ }
+
+ case class TypeEnforcer(
+ override val nodeType: NodeType,
+ override val selfCost: Long,
+ override val child: TestNode)
+ extends UnaryLike
+ with TypedNode {
+ override def withNewChildren(child: TestNode): UnaryLike = copy(child = child)
+ }
+
+ case class PassNodeType(override val selfCost: Long, child: TestNode) extends TypedNode {
+ override def nodeType: NodeType = child match {
+ case n: TypedNode => n.nodeType
+ case g: Group => g.propSet.get(NodeTypeDef)
+ case _ => throw new IllegalStateException()
+ }
+
+ override def children(): Seq[TestNode] = List(child)
+ override def withNewChildren(children: Seq[TestNode]): TestNode = copy(selfCost, children.head)
+ }
+
+ case class NodeTypeEnforcerRule(reqType: NodeType) extends CboRule[TestNode] {
+ override def shift(node: TestNode): Iterable[TestNode] = {
+ node match {
+ case typed: TypedNode if typed.nodeType.satisfies(reqType) => List(typed)
+ case typed: TypedNode => List(TypeEnforcer(reqType, 1, typed))
+ case _ => throw new IllegalStateException()
+ }
+ }
+
+ override def shape(): Shape[TestNode] = Shapes.fixedHeight(1)
+ }
+
+ object ReplaceByTypeARule extends CboRule[TestNode] {
+ override def shift(node: TestNode): Iterable[TestNode] = {
+ node match {
+ case TypedLeaf(_, cost) => List(TypedLeaf(TypeA, cost))
+ case TypedUnary(_, cost, child) => List(TypedUnary(TypeA, cost, child))
+ case TypedBinary(_, cost, left, right) => List(TypedBinary(TypeA, cost, left, right))
+ case other => List.empty
+ }
+ }
+
+ override def shape(): Shape[TestNode] = Shapes.fixedHeight(1)
+ }
+
+ object ReplaceByTypeBRule extends CboRule[TestNode] {
+ override def shift(node: TestNode): Iterable[TestNode] = {
+ node match {
+ case TypedLeaf(_, cost) => List(TypedLeaf(TypeB, cost))
+ case TypedUnary(_, cost, child) => List(TypedUnary(TypeB, cost, child))
+ case TypedBinary(_, cost, left, right) => List(TypedBinary(TypeB, cost, left, right))
+ case other => List.empty
+ }
+ }
+
+ override def shape(): Shape[TestNode] = Shapes.fixedHeight(1)
+ }
+
+ object NodeTypeDef extends PropertyDef[TestNode, NodeType] {
+ override def getProperty(plan: TestNode): NodeType = plan match {
+ case typed: TypedNode => typed.nodeType
+ case _ => throw new IllegalStateException()
+ }
+
+ override def getChildrenConstraints(
+ constraint: Property[TestNode],
+ plan: TestNode): Seq[NodeType] = plan match {
+ case TypedLeaf(_, _) => Seq.empty
+ case TypedUnary(t, _, _) => Seq(t)
+ case TypedBinary(t, _, _, _) => Seq(t, t)
+ case TypeEnforcer(_, _, _) => Seq(TypeAny)
+ case p @ PassNodeType(_, _) => Seq(constraint.asInstanceOf[NodeType])
+ case _ => throw new IllegalStateException()
+ }
+
+ override def toString: String = "NodeTypeDef"
+ }
+
+ trait NodeType extends Property[TestNode] {
+ override def definition(): PropertyDef[TestNode, NodeType] = NodeTypeDef
+ override def toString: String = getClass.getSimpleName
+ }
+
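+  // Illustrative: every concrete type satisfies itself and TypeAny, e.g.
+  // TypeA.satisfies(TypeAny) == true while TypeA.satisfies(TypeB) == false,
+  // and TypeAny satisfies nothing but TypeAny.
+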
+ object TypeAny extends NodeType {
+ override def satisfies(other: Property[TestNode]): Boolean = other match {
+ case TypeAny => true
+ case _: NodeType => false
+ case _ => throw new IllegalStateException()
+ }
+ }
+
+ object TypeA extends NodeType {
+ override def satisfies(other: Property[TestNode]): Boolean = other match {
+ case TypeA => true
+ case TypeAny => true
+ case _: NodeType => false
+ case _ => throw new IllegalStateException()
+ }
+ }
+
+ object TypeB extends NodeType {
+ override def satisfies(other: Property[TestNode]): Boolean = other match {
+ case TypeB => true
+ case TypeAny => true
+ case _: NodeType => false
+ case _ => throw new IllegalStateException()
+ }
+ }
+
+ object TypeC extends NodeType {
+ override def satisfies(other: Property[TestNode]): Boolean = other match {
+ case TypeC => true
+ case TypeAny => true
+ case _: NodeType => false
+ case _ => throw new IllegalStateException()
+ }
+ }
+
+ object NodeTypePropertyModel extends PropertyModel[TestNode] {
+ override def propertyDefs: Seq[PropertyDef[TestNode, _ <: Property[TestNode]]] = Seq(
+ NodeTypeDef)
+
+ override def newEnforcerRuleFactory(propertyDef: PropertyDef[TestNode, _ <: Property[TestNode]])
+ : EnforcerRuleFactory[TestNode] = {
+ (constraint: Property[TestNode]) =>
+ {
+ List(NodeTypeEnforcerRule(constraint.asInstanceOf[NodeType]))
+ }
+ }
+ }
+
+ object NodeTypePropertyModelWithOutEnforcerRules extends PropertyModel[TestNode] {
+ override def propertyDefs: Seq[PropertyDef[TestNode, _ <: Property[TestNode]]] = Seq(
+ NodeTypeDef)
+
+ override def newEnforcerRuleFactory(propertyDef: PropertyDef[TestNode, _ <: Property[TestNode]])
+ : EnforcerRuleFactory[TestNode] = (_: Property[TestNode]) => List.empty
+ }
+}
diff --git a/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/CboSuite.scala b/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/CboSuite.scala
new file mode 100644
index 000000000000..a930ac356d97
--- /dev/null
+++ b/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/CboSuite.scala
@@ -0,0 +1,468 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.glutenproject.cbo
+
+import io.glutenproject.cbo.CboConfig.PlannerType
+import io.glutenproject.cbo.memo.Memo
+import io.glutenproject.cbo.rule.{CboRule, Shape, Shapes}
+
+import org.scalatest.funsuite.AnyFunSuite
+
+class ExhaustivePlannerCboSuite extends CboSuite {
+ override protected def conf: CboConfig = CboConfig(plannerType = PlannerType.Exhaustive)
+}
+
+class DpPlannerCboSuite extends CboSuite {
+ override protected def conf: CboConfig = CboConfig(plannerType = PlannerType.Dp)
+}
+
+abstract class CboSuite extends AnyFunSuite {
+ import CboSuite._
+
+ protected def conf: CboConfig
+
+ test("Group memo - re-memorize") {
+ val cbo =
+ Cbo[TestNode](
+ CostModelImpl,
+ PlanModelImpl,
+ PropertyModelImpl,
+ ExplainImpl,
+ CboRule.Factory.none())
+ .withNewConfig(_ => conf)
+ val memo = Memo(cbo)
+ val group1 = memo.memorize(cbo, Unary(50, Unary(50, Leaf(30))))
+ val group2 = memo.memorize(cbo, Unary(50, Unary(50, Leaf(30))))
+ assert(group2 eq group1)
+ }
+
+ test("Group memo - define equivalence") {
+ val cbo =
+ Cbo[TestNode](
+ CostModelImpl,
+ PlanModelImpl,
+ PropertyModelImpl,
+ ExplainImpl,
+ CboRule.Factory.none())
+ .withNewConfig(_ => conf)
+ val memo = Memo(cbo)
+ val group = memo.memorize(cbo, Unary(50, Unary(50, Leaf(30))))
+ val state = memo.newState()
+ assert(group.nodes(state).size == 1)
+ val can = group.nodes(state).head.asCanonical()
+ memo.openFor(can).memorize(cbo, Unary(30, Leaf(90)))
+ assert(memo.newState().allGroups().size == 4)
+ }
+
+ test("Group memo - define equivalence: binary with similar children, 1") {
+ val cbo =
+ Cbo[TestNode](
+ CostModelImpl,
+ PlanModelImpl,
+ PropertyModelImpl,
+ ExplainImpl,
+ CboRule.Factory.none())
+ .withNewConfig(_ => conf)
+ val memo = Memo(cbo)
+ val group = memo.memorize(cbo, Binary(50, Leaf(30), Leaf(40)))
+ val state = memo.newState()
+ assert(group.nodes(state).size == 1)
+ val leaf40Group = memo.memorize(cbo, Leaf(40))
+ assert(leaf40Group.nodes(state).size == 1)
+ val can = leaf40Group.nodes(state).head.asCanonical()
+ memo.openFor(can).memorize(cbo, Leaf(30))
+ assert(memo.newState().allGroups().size == 3)
+ }
+
+ test("Group memo - define equivalence: binary with similar children, 2") {
+ val cbo =
+ Cbo[TestNode](
+ CostModelImpl,
+ PlanModelImpl,
+ PropertyModelImpl,
+ ExplainImpl,
+ CboRule.Factory.none())
+ .withNewConfig(_ => conf)
+ val memo = Memo(cbo)
+ val group = memo.memorize(cbo, Binary(50, Unary(20, Leaf(30)), Unary(20, Leaf(40))))
+ val state = memo.newState()
+ assert(group.nodes(state).size == 1)
+ val leaf40Group = memo.memorize(cbo, Leaf(40))
+ assert(leaf40Group.nodes(state).size == 1)
+ val can = leaf40Group.nodes(state).head.asCanonical()
+ memo.openFor(can).memorize(cbo, Leaf(30))
+ assert(memo.newState().allGroups().size == 5)
+ }
+
+ test("Group memo - partial canonical") {
+ val cbo =
+ Cbo[TestNode](
+ CostModelImpl,
+ PlanModelImpl,
+ PropertyModelImpl,
+ ExplainImpl,
+ CboRule.Factory.none())
+ .withNewConfig(_ => conf)
+ val memo = Memo(cbo)
+ val group1 = memo.memorize(cbo, Unary(50, Unary(50, Leaf(30))))
+ val group2 = memo.memorize(cbo, Unary(50, Group(1)))
+ assert(group2 eq group1)
+ }
+
+ test(s"Unary node") {
+ object DivideUnaryCost extends CboRule[TestNode] {
+ override def shift(node: TestNode): Iterable[TestNode] = node match {
+ case Unary(cost, child) =>
+ if (cost >= 35) {
+            val splitCost = cost / 3
+            List(Unary(splitCost, Unary(splitCost, child)))
+ } else {
+ List.empty
+ }
+ case other => List.empty
+ }
+
+ override def shape(): Shape[TestNode] = Shapes.fixedHeight(1)
+ }
+
+ object DecreaseUnaryCost extends CboRule[TestNode] {
+ override def shift(node: TestNode): Iterable[TestNode] = node match {
+ case Unary(cost, child) =>
+ if (cost >= 80) {
+ List(Unary(cost - 20, child))
+ } else {
+ List.empty
+ }
+ case other => List.empty
+ }
+
+ override def shape(): Shape[TestNode] = Shapes.fixedHeight(1)
+ }
+
+ val cbo =
+ Cbo[TestNode](
+ CostModelImpl,
+ PlanModelImpl,
+ PropertyModelImpl,
+ ExplainImpl,
+ CboRule.Factory.reuse(List(DivideUnaryCost, DecreaseUnaryCost)))
+ .withNewConfig(_ => conf)
+ val plan = Unary(90, Leaf(70))
+ val planner = cbo.newPlanner(plan)
+ val optimized = planner.plan()
+
+ assert(optimized == Unary(23, Unary(23, Leaf(70))))
+ }
+
+ test(s"Unary node insertion") {
+ object InsertUnary2 extends CboRule[TestNode] {
+ override def shift(node: TestNode): Iterable[TestNode] = node match {
+ case Unary(cost1, Unary(cost2, child)) =>
+ List(Unary(cost1 - 11, Unary2(10, Unary(cost2, child))))
+ case other => List.empty
+ }
+
+ override def shape(): Shape[TestNode] = Shapes.fixedHeight(2)
+ }
+
+ val cbo =
+ Cbo[TestNode](
+ CostModelImpl,
+ PlanModelImpl,
+ PropertyModelImpl,
+ ExplainImpl,
+ CboRule.Factory.reuse(List(InsertUnary2)))
+ .withNewConfig(_ => conf)
+
+ val plan = Unary(90, Unary(90, Leaf(70)))
+ val planner = cbo.newPlanner(plan)
+ val optimized = planner.plan()
+
+ assert(optimized == Unary(79, Unary2(10, Unary(90, Leaf(70)))))
+ }
+
+ test(s"Binary node") {
+ object DivideBinaryCost extends CboRule[TestNode] {
+ override def shift(node: TestNode): Iterable[TestNode] = node match {
+ case Binary(cost, left, right) =>
+ if (cost >= 35) {
+            val splitCost = cost / 3
+            List(Binary(splitCost, Binary(splitCost, left, right), Binary(splitCost, left, right)))
+ } else {
+ List.empty
+ }
+ case other => List.empty
+ }
+
+ override def shape(): Shape[TestNode] = Shapes.fixedHeight(1)
+ }
+
+ val cbo =
+ Cbo[TestNode](
+ CostModelImpl,
+ PlanModelImpl,
+ PropertyModelImpl,
+ ExplainImpl,
+ CboRule.Factory.reuse(List(DivideBinaryCost)))
+ .withNewConfig(_ => conf)
+
+ val plan = Binary(90, Leaf(70), Leaf(70))
+ val planner = cbo.newPlanner(plan)
+ val optimized = planner.plan()
+
+ assert(optimized == Binary(90, Leaf(70), Leaf(70)))
+ }
+
+ test(s"Symmetric rule") {
+ object SymmetricRule extends CboRule[TestNode] {
+ override def shift(node: TestNode): Iterable[TestNode] = node match {
+ case Unary(cost, child) => List(Unary2(cost, child))
+ case Unary2(cost, child) => List(Unary(cost, child))
+ case other => List.empty
+ }
+
+ override def shape(): Shape[TestNode] = Shapes.fixedHeight(1)
+ }
+
+ val cbo =
+ Cbo[TestNode](
+ CostModelImpl,
+ PlanModelImpl,
+ PropertyModelImpl,
+ ExplainImpl,
+ CboRule.Factory.reuse(List(SymmetricRule)))
+ .withNewConfig(_ => conf)
+
+ val plan = Unary(90, Leaf(70))
+ val planner = cbo.newPlanner(plan)
+ val optimized = planner.plan()
+ val state = planner.newState()
+
+    // The two plans have the same cost.
+ assert(optimized == Unary(90, Leaf(70)) || optimized == Unary2(90, Leaf(70)))
+ assert(state.memoState().getGroupCount() == 2)
+ }
+
+ test(s"Binary swap") {
+ object BinarySwap extends CboRule[TestNode] {
+ override def shift(node: TestNode): Iterable[TestNode] = node match {
+ case Binary(cost, left, right) if cost >= 1 =>
+ List(Binary((cost - 1).max(0), right, left))
+ case other => List.empty
+ }
+
+ override def shape(): Shape[TestNode] = Shapes.fixedHeight(1)
+ }
+
+ val cbo =
+ Cbo[TestNode](
+ CostModelImpl,
+ PlanModelImpl,
+ PropertyModelImpl,
+ ExplainImpl,
+ CboRule.Factory.reuse(List(BinarySwap)))
+ .withNewConfig(_ => conf)
+
+ val plan = Binary(90, Leaf(50), Leaf(70))
+ val planner = cbo.newPlanner(plan)
+ val optimized = planner.plan()
+
+ assert(optimized == Binary(0, Leaf(50), Leaf(70)))
+ }
+
+ test(s"Binary swap equivalent leaves") {
+ object BinarySwap extends CboRule[TestNode] {
+ override def shift(node: TestNode): Iterable[TestNode] = node match {
+ case Binary(cost, left, right) if cost >= 1 =>
+ List(Binary((cost - 1).max(0), right, left))
+ case other => List.empty
+ }
+
+ override def shape(): Shape[TestNode] = Shapes.fixedHeight(1)
+ }
+
+ val cbo =
+ Cbo[TestNode](
+ CostModelImpl,
+ PlanModelImpl,
+ PropertyModelImpl,
+ ExplainImpl,
+ CboRule.Factory.reuse(List(BinarySwap)))
+ .withNewConfig(_ => conf)
+
+ val plan = Binary(70, Binary(90, Leaf(50), Leaf(50)), Leaf(50))
+ val planner = cbo.newPlanner(plan)
+ val optimized = planner.plan()
+
+ assert(optimized == Binary(0, Binary(0, Leaf(50), Leaf(50)), Leaf(50)))
+ }
+
+ test(s"Avoid unused groups") {
+ object Unary2Unary3 extends CboRule[TestNode] {
+ override def shift(node: TestNode): Iterable[TestNode] = node match {
+ case Unary2(cost, child) if cost >= 1 => List(Unary3(cost - 1, child))
+ case Unary3(cost, child) if cost >= 1 => List(Unary2(cost + 1, child))
+ case other => List.empty
+ }
+
+ override def shape(): Shape[TestNode] = Shapes.fixedHeight(1)
+ }
+
+ val cbo =
+ Cbo[TestNode](
+ CostModelImpl,
+ PlanModelImpl,
+ PropertyModelImpl,
+ ExplainImpl,
+ CboRule.Factory.reuse(List(Unary2Unary3)))
+ .withNewConfig(_ => conf)
+
+ val plan = Unary(50, Unary2(50, Leaf(30)))
+ val planner = cbo.newPlanner(plan)
+ val optimized = planner.plan()
+ val state = planner.newState()
+
+ assert(state.memoState().getGroupCount() == 3)
+ assert(optimized == Unary(50, Unary3(49, Leaf(30))))
+ }
+
+ test(s"Rule application depth - depth 1") {
+ val l2l2 = new LeafToLeaf2()
+ val u2u2 = new UnaryToUnary2()
+ val cbo =
+ Cbo[TestNode](
+ CostModelImpl,
+ PlanModelImpl,
+ PropertyModelImpl,
+ ExplainImpl,
+ CboRule.Factory.reuse(List(l2l2, u2u2)))
+ .withNewConfig(_ => conf)
+
+ val plan = Unary(50, Unary2(50, Unary2(50, Unary2(50, Leaf(30)))))
+ val planner = cbo.newPlanner(plan)
+ val optimized = planner.plan()
+
+ assert(optimized == Unary2(49, Unary2(50, Unary2(50, Unary2(50, Leaf2(29))))))
+ }
+
+ test(s"Rule application depth - depth 2") {
+ val l2l2 = new LeafToLeaf2()
+ val u2u2 = new UnaryToUnary2()
+
+ object Unary2Unary2ToUnary3 extends CboRule[TestNode] {
+ var invocationCount: Int = 0
+ var effectiveInvocationCount: Int = 0
+ override def shift(node: TestNode): Iterable[TestNode] = {
+ invocationCount += 1
+ node match {
+ case Unary2(cost1, Unary2(cost2, child)) =>
+ effectiveInvocationCount += 1
+ List(Unary3(cost1 + cost2 - 1, child))
+ case other => List.empty
+ }
+ }
+
+ override def shape(): Shape[TestNode] = Shapes.fixedHeight(2)
+ }
+
+ val cbo =
+ Cbo[TestNode](
+ CostModelImpl,
+ PlanModelImpl,
+ PropertyModelImpl,
+ ExplainImpl,
+ CboRule.Factory.reuse(List(l2l2, u2u2, Unary2Unary2ToUnary3)))
+ .withNewConfig(_ => conf)
+
+ val plan = Unary(50, Unary2(50, Unary2(50, Unary2(50, Leaf(30)))))
+ val planner = cbo.newPlanner(plan)
+ val optimized = planner.plan()
+
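+    // The fixed-height-2 shape invokes the rule once per explored height-2 path;
+    // only the three Unary2-over-Unary2 matches rewrite anything.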
+ assert(Unary2Unary2ToUnary3.invocationCount == 14)
+ assert(Unary2Unary2ToUnary3.effectiveInvocationCount == 3)
+ assert(optimized == Unary3(98, Unary3(99, Leaf2(29))))
+ }
+}
+
+object CboSuite extends CboSuiteBase {
+
+ case class Binary(
+ override val selfCost: Long,
+ override val left: TestNode,
+ override val right: TestNode)
+ extends BinaryLike {
+ override def withNewChildren(left: TestNode, right: TestNode): BinaryLike =
+ copy(left = left, right = right)
+ }
+
+ case class Unary(override val selfCost: Long, override val child: TestNode) extends UnaryLike {
+ override def withNewChildren(child: TestNode): UnaryLike = copy(child = child)
+ }
+
+ case class Unary2(override val selfCost: Long, override val child: TestNode) extends UnaryLike {
+ override def withNewChildren(child: TestNode): UnaryLike = copy(child = child)
+ }
+
+ case class Unary3(override val selfCost: Long, override val child: TestNode) extends UnaryLike {
+ override def withNewChildren(child: TestNode): UnaryLike = copy(child = child)
+ }
+
+ case class Leaf(override val selfCost: Long) extends LeafLike {
+ override def makeCopy(): LeafLike = copy()
+ }
+
+ case class Leaf2(override val selfCost: Long) extends LeafLike {
+ override def makeCopy(): LeafLike = copy()
+ }
+
+ class LeafToLeaf2 extends CboRule[TestNode] {
+
+ override def shift(node: TestNode): Iterable[TestNode] = {
+ node match {
+ case Leaf(cost) =>
+ List(Leaf2(cost - 1))
+ case other => List.empty
+ }
+ }
+
+ override def shape(): Shape[TestNode] = Shapes.fixedHeight(1)
+ }
+
+ class UnaryToUnary2 extends CboRule[TestNode] {
+ override def shift(node: TestNode): Iterable[TestNode] = {
+ node match {
+ case Unary(cost, child) =>
+ List(Unary2(cost - 1, child))
+ case other => List.empty
+ }
+ }
+
+ override def shape(): Shape[TestNode] = Shapes.fixedHeight(1)
+ }
+
+ class Unary2ToUnary3 extends CboRule[TestNode] {
+ override def shift(node: TestNode): Iterable[TestNode] = {
+ node match {
+ case Unary2(cost, child) =>
+ List(Unary3(cost - 1, child))
+ case other => List.empty
+ }
+ }
+ override def shape(): Shape[TestNode] = Shapes.fixedHeight(1)
+ }
+
+}
diff --git a/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/CboSuiteBase.scala b/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/CboSuiteBase.scala
new file mode 100644
index 000000000000..d718682f64cf
--- /dev/null
+++ b/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/CboSuiteBase.scala
@@ -0,0 +1,195 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.glutenproject.cbo
+
+import io.glutenproject.cbo.memo.{MemoLike, MemoState}
+import io.glutenproject.cbo.path.{CboPath, PathFinder}
+import io.glutenproject.cbo.property.PropertySet
+
+trait CboSuiteBase {
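+  // A minimal test-only plan model: every node reports a self cost plus its
+  // children, and CostModelImpl below sums self costs across the whole plan.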
+ trait TestNode {
+ def selfCost(): Long
+ def children(): Seq[TestNode]
+ def withNewChildren(children: Seq[TestNode]): TestNode
+ }
+
+ trait UnaryLike extends TestNode {
+ def child: TestNode
+ def withNewChildren(child: TestNode): UnaryLike
+ def children(): Seq[TestNode] = List(child)
+ def withNewChildren(children: Seq[TestNode]): TestNode = withNewChildren(children.head)
+ }
+
+ trait BinaryLike extends TestNode {
+ def left: TestNode
+ def right: TestNode
+ def children(): Seq[TestNode] = List(left, right)
+ def withNewChildren(left: TestNode, right: TestNode): BinaryLike
+ def withNewChildren(children: Seq[TestNode]): TestNode =
+ withNewChildren(children.head, children(1))
+ }
+
+ trait LeafLike extends TestNode {
+ def makeCopy(): LeafLike
+ def children(): Seq[TestNode] = List.empty
+ def withNewChildren(children: Seq[TestNode]): TestNode = this
+ }
+
+ case class Group(id: Int, propSet: PropertySet[TestNode]) extends LeafLike {
+ override def selfCost(): Long = Long.MaxValue
+ override def makeCopy(): LeafLike = copy()
+ }
+
+ object Group {
+ def apply(id: Int): Group = {
+ Group(id, PropertySet(List.empty))
+ }
+ }
+
+ case class LongCost(value: Long) extends Cost
+
+ object CostModelImpl extends CostModel[TestNode] {
+
+ override def costComparator(): Ordering[Cost] = {
+ Ordering.Long.on { case LongCost(value) => value }
+ }
+
+ private def longCostOf(node: TestNode): Long = node match {
+ case n: TestNode =>
+ val selfCost = n.selfCost()
+
+        // Saturating sum: clamp to Long.MaxValue on overflow.
+ def safeSum(a: Long, b: Long): Long = {
+ val sum = a + b
+ if (sum < a || sum < b) Long.MaxValue else sum
+ }
+
+ (n.children().map(longCostOf).toSeq :+ selfCost).reduce(safeSum)
+ case _ => throw new IllegalStateException()
+ }
+
+ override def costOf(node: TestNode): Cost = node match {
+ case g: Group => throw new IllegalStateException()
+ case n => LongCost(longCostOf(n))
+ }
+
+ override def makeInfCost(): Cost = LongCost(Long.MaxValue)
+ }
+
+ object PlanModelImpl extends PlanModel[TestNode] {
+ override def childrenOf(node: TestNode): Seq[TestNode] = node match {
+ case n: TestNode => n.children()
+ case _ => throw new IllegalStateException()
+ }
+
+ override def withNewChildren(node: TestNode, children: Seq[TestNode]): TestNode =
+ node match {
+ case n: TestNode => n.withNewChildren(children)
+ case _ => throw new IllegalStateException()
+ }
+
+ override def hashCode(node: TestNode): Int = {
+ java.util.Objects.hashCode(node)
+ }
+
+ override def equals(one: TestNode, other: TestNode): Boolean = {
+ java.util.Objects.equals(one, other)
+ }
+
+ override def newGroupLeaf(groupId: Int, propSet: PropertySet[TestNode]): TestNode =
+ Group(groupId, propSet)
+
+ override def getGroupId(node: TestNode): Int = node match {
+ case ngl: Group => ngl.id
+ case _ => throw new IllegalStateException()
+ }
+
+ override def isGroupLeaf(node: TestNode): Boolean = node match {
+ case _: Group => true
+ case _ => false
+ }
+ }
+
+ object ExplainImpl extends CboExplain[TestNode] {
+ override def describeNode(node: TestNode): String = node match {
+ case n => n.toString
+ }
+ }
+
+ object PropertyModelImpl extends PropertyModel[TestNode] {
+ override def propertyDefs: Seq[PropertyDef[TestNode, _ <: Property[TestNode]]] = List.empty
+ override def newEnforcerRuleFactory(propertyDef: PropertyDef[TestNode, _ <: Property[TestNode]])
+ : EnforcerRuleFactory[TestNode] = (_: Property[TestNode]) => List.empty
+ }
+
+ implicit class GraphvizPrinter[T <: AnyRef](val planner: CboPlanner[T]) {
+ def printGraphviz(): Unit = {
+ // scalastyle:off println
+ println(planner.newState().formatGraphviz())
+ // scalastyle:on println
+ }
+ }
+
+ implicit class MemoLikeImplicits[T <: AnyRef](val memo: MemoLike[T]) {
+ def memorize(cbo: Cbo[T], node: T): CboGroup[T] = {
+ memo.memorize(node, cbo.propSetsOf(node))
+ }
+ }
+
+ implicit class MemoStateImplicits[T <: AnyRef](val state: MemoState[T]) {
+ def getGroupCount(): Int = {
+ state.allGroups().size
+ }
+
+ def collectAllPaths(depth: Int): Iterable[CboPath[T]] = {
+      val allClusters = state.allClusters()
+
+ val highestFinder = PathFinder
+ .builder(state.cbo(), state)
+ .depth(depth)
+ .build()
+
+ allClusters
+ .flatMap(c => c.nodes())
+ .flatMap(
+ node => {
+ val highest = highestFinder.find(node).maxBy(c => c.height())
+ val finder = (1 to highest.height())
+ .foldLeft(PathFinder
+ .builder(state.cbo(), state)) {
+ case (builder, d) =>
+ builder.depth(d)
+ }
+ .build()
+ finder.find(node)
+ })
+ }
+ }
+
+ implicit class TestNodeImplicits(val node: TestNode) {
+ def asCanonical(cbo: Cbo[TestNode]): CanonicalNode[TestNode] = {
+ CanonicalNode(cbo, node)
+ }
+ }
+
+ implicit class TestNodeGroupImplicits(val group: CboGroup[TestNode]) {
+ def asGroup(cbo: Cbo[TestNode]): GroupNode[TestNode] = {
+ GroupNode(cbo, group)
+ }
+ }
+}
diff --git a/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/mock/MockCboPath.scala b/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/mock/MockCboPath.scala
new file mode 100644
index 000000000000..f9bdc7b00065
--- /dev/null
+++ b/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/mock/MockCboPath.scala
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.glutenproject.cbo.mock
+
+import io.glutenproject.cbo.{CanonicalNode, Cbo}
+import io.glutenproject.cbo.memo.Memo
+import io.glutenproject.cbo.path.{CboPath, PathKeySet}
+
+object MockCboPath {
+ def mock[T <: AnyRef](cbo: Cbo[T], node: T): CboPath[T] = {
+ mock(cbo, node, PathKeySet.trivial)
+ }
+
+ def mock[T <: AnyRef](cbo: Cbo[T], node: T, keys: PathKeySet): CboPath[T] = {
+ val memo = Memo(cbo)
+ val g = memo.memorize(node, cbo.propSetsOf(node))
+ val state = memo.newState()
+ val groupSupplier = state.asGroupSupplier()
+ assert(g.nodes(state).size == 1)
+ val can = g.nodes(state).head
+
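+    // Rebuild a concrete path from the memo, assuming every group holds exactly
+    // one node (guarded by the asserts).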
+ def dfs(n: CanonicalNode[T]): CboPath[T] = {
+ if (n.isLeaf()) {
+ return CboPath.one(cbo, keys, groupSupplier, n)
+ }
+ CboPath(
+ cbo,
+ n,
+ n.getChildrenGroups(groupSupplier).map(_.group(groupSupplier)).map {
+ cg =>
+ assert(cg.nodes(state).size == 1)
+ dfs(cg.nodes(state).head)
+ }).get
+ }
+ dfs(can)
+ }
+}
diff --git a/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/mock/MockMemoState.scala b/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/mock/MockMemoState.scala
new file mode 100644
index 000000000000..6e95dceb6524
--- /dev/null
+++ b/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/mock/MockMemoState.scala
@@ -0,0 +1,155 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.glutenproject.cbo.mock
+
+import io.glutenproject.cbo._
+import io.glutenproject.cbo.memo.{MemoState, MemoStore}
+import io.glutenproject.cbo.property.PropertySet
+import io.glutenproject.cbo.vis.GraphvizVisualizer
+
+import scala.collection.mutable
+
+case class MockMemoState[T <: AnyRef] private (
+ override val cbo: Cbo[T],
+ override val clusterLookup: Map[CboClusterKey, CboCluster[T]],
+ override val allGroups: Seq[CboGroup[T]])
+ extends MemoState[T] {
+ def printGraphviz(group: CboGroup[T]): Unit = {
+ val graph = GraphvizVisualizer(cbo, this, group.id())
+ // scalastyle:off println
+ println(graph.format())
+ // scalastyle:on println
+ }
+
+ def printGraphviz(best: Best[T]): Unit = {
+ val graph = vis.GraphvizVisualizer(cbo, this, best)
+ // scalastyle:off println
+ println(graph.format())
+ // scalastyle:on println
+ }
+
+ override def allClusters(): Iterable[CboCluster[T]] = clusterLookup.values
+
+ override def getCluster(key: CboClusterKey): CboCluster[T] = clusterLookup(key)
+
+ override def getGroup(id: Int): CboGroup[T] = allGroups(id)
+}
+
+object MockMemoState {
+ class Builder[T <: AnyRef] private (cbo: Cbo[T]) {
+ private var propSet: PropertySet[T] = PropertySet[T](List.empty)
+ private val clusterBuffer = mutable.Map[CboClusterKey, MockMutableCluster[T]]()
+ private val groupFactory: MockMutableGroup.Factory[T] =
+ MockMutableGroup.Factory.create[T](cbo, propSet)
+
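+    // Note: groupFactory captures the initial propSet at construction time, so
+    // withPropertySet does not change the property set assigned to new groups.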
+ def withPropertySet(propSet: PropertySet[T]): Builder[T] = {
+ this.propSet = propSet
+ this
+ }
+
+ def newCluster(): MockMutableCluster[T] = {
+ val id = clusterBuffer.size
+ val key = MockMutableCluster.DummyIntClusterKey(id)
+ val cluster = MockMutableCluster[T](cbo, key, propSet, groupFactory)
+ clusterBuffer += (key -> cluster)
+ cluster
+ }
+
+ def build(): MockMemoState[T] = {
+ MockMemoState[T](cbo, clusterBuffer.toMap, groupFactory.allGroups())
+ }
+ }
+
+ object Builder {
+ def apply[T <: AnyRef](cbo: Cbo[T]): Builder[T] = {
+ new Builder[T](cbo)
+ }
+ }
+
+ // TODO add groups with different property sets
+ class MockMutableCluster[T <: AnyRef] private (
+ cbo: Cbo[T],
+ key: CboClusterKey,
+ groupFactory: MockMutableGroup.Factory[T])
+ extends CboCluster[T] {
+ private val nodeBuffer = mutable.ArrayBuffer[CanonicalNode[T]]()
+
+ def newGroup(): MockMutableGroup[T] = {
+ groupFactory.newGroup(key)
+ }
+
+ def addNodes(nodes: Seq[CanonicalNode[T]]): Unit = {
+ nodeBuffer ++= nodes
+ }
+
+ override def nodes(): Seq[CanonicalNode[T]] = nodeBuffer
+ }
+
+ object MockMutableCluster {
+ def apply[T <: AnyRef](
+ cbo: Cbo[T],
+ key: CboClusterKey,
+ propSet: PropertySet[T],
+ groupFactory: MockMutableGroup.Factory[T]): MockMutableCluster[T] = {
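+      // propSet is accepted here but currently unused; groups get their property
+      // set from the shared factory.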
+ new MockMutableCluster[T](cbo, key, groupFactory)
+ }
+
+ case class DummyIntClusterKey(id: Int) extends CboClusterKey
+ }
+
+ class MockMutableGroup[T <: AnyRef] private (
+ override val id: Int,
+ override val clusterKey: CboClusterKey,
+ override val propSet: PropertySet[T],
+ override val self: T)
+ extends CboGroup[T] {
+ private val nodes: mutable.ArrayBuffer[CanonicalNode[T]] = mutable.ArrayBuffer()
+
+ def add(node: CanonicalNode[T]): Unit = {
+ nodes += node
+ }
+
+ def add(newNodes: Seq[CanonicalNode[T]]): Unit = {
+ nodes ++= newNodes
+ }
+
+ override def nodes(store: MemoStore[T]): Iterable[CanonicalNode[T]] = nodes
+ }
+
+ object MockMutableGroup {
+ class Factory[T <: AnyRef] private (cbo: Cbo[T], propSet: PropertySet[T]) {
+ private val groupBuffer = mutable.ArrayBuffer[MockMutableGroup[T]]()
+
+ def newGroup(clusterKey: CboClusterKey): MockMutableGroup[T] = {
+ val id = groupBuffer.size
+ val group =
+ new MockMutableGroup[T](id, clusterKey, propSet, cbo.planModel.newGroupLeaf(id, propSet))
+ groupBuffer += group
+ group
+ }
+
+ def allGroups(): Seq[MockMutableGroup[T]] = groupBuffer
+ }
+
+ object Factory {
+ def create[T <: AnyRef](cbo: Cbo[T], propSet: PropertySet[T]): Factory[T] = {
+ new Factory[T](cbo, propSet)
+ }
+ }
+ }
+
+}
diff --git a/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/path/CboPathSuite.scala b/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/path/CboPathSuite.scala
new file mode 100644
index 000000000000..fff5838741fd
--- /dev/null
+++ b/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/path/CboPathSuite.scala
@@ -0,0 +1,111 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.glutenproject.cbo.path
+
+import io.glutenproject.cbo.{Cbo, CboSuiteBase}
+import io.glutenproject.cbo.mock.MockCboPath
+import io.glutenproject.cbo.rule.CboRule
+
+import org.scalatest.funsuite.AnyFunSuite
+
+class CboPathSuite extends AnyFunSuite {
+ import CboPathSuite._
+
+ test("Path aggregate - empty") {
+ val cbo =
+ Cbo[TestNode](
+ CostModelImpl,
+ PlanModelImpl,
+ PropertyModelImpl,
+ ExplainImpl,
+ CboRule.Factory.reuse(List.empty))
+ assert(CboPath.aggregate(cbo, List.empty) == List.empty)
+ }
+
+ test("Path aggregate") {
+ val cbo =
+ Cbo[TestNode](
+ CostModelImpl,
+ PlanModelImpl,
+ PropertyModelImpl,
+ ExplainImpl,
+ CboRule.Factory.reuse(List.empty))
+
+ val n1 = "n1"
+ val n2 = "n2"
+ val n3 = "n3"
+ val n4 = "n4"
+ val n5 = "n5"
+ val n6 = "n6"
+ val path1 = MockCboPath.mock(
+ cbo,
+ Unary(n5, Leaf(n6, 1)),
+ PathKeySet(Set(DummyPathKey(1), DummyPathKey(3)))
+ )
+ val path2 = MockCboPath.mock(
+ cbo,
+ Unary(n1, Unary(n2, Leaf(n3, 1))),
+ PathKeySet(Set(DummyPathKey(1)))
+ )
+ val path3 = MockCboPath.mock(
+ cbo,
+ Unary(n1, Unary(n2, Leaf(n3, 1))),
+ PathKeySet(Set(DummyPathKey(1), DummyPathKey(2)))
+ )
+ val path4 = MockCboPath.mock(
+ cbo,
+ Unary(n1, Unary(n2, Leaf(n3, 1))),
+ PathKeySet(Set(DummyPathKey(4)))
+ )
+ val path5 = MockCboPath.mock(
+ cbo,
+ Unary(n5, Leaf(n6, 1)),
+ PathKeySet(Set(DummyPathKey(4)))
+ )
+ val out = CboPath
+ .aggregate(cbo, List(path1, path2, path3, path4, path5))
+ .toSeq
+ .sortBy(_.height())
+ assert(out.size == 2)
+ assert(out.head.height() == 2)
+ assert(out.head.plan() == Unary(n5, Leaf(n6, 1)))
+ assert(out.head.keys() == PathKeySet(Set(DummyPathKey(1), DummyPathKey(3), DummyPathKey(4))))
+
+ assert(out(1).height() == 3)
+ assert(out(1).plan() == Unary(n1, Unary(n2, Leaf(n3, 1))))
+ assert(out(1).keys() == PathKeySet(Set(DummyPathKey(1), DummyPathKey(2), DummyPathKey(4))))
+ }
+}
+
+object CboPathSuite extends CboSuiteBase {
+ case class Leaf(name: String, override val selfCost: Long) extends LeafLike {
+ override def makeCopy(): LeafLike = this
+ }
+
+ case class Unary(name: String, child: TestNode) extends UnaryLike {
+ override def selfCost(): Long = 1
+ override def withNewChildren(child: TestNode): UnaryLike = copy(child = child)
+ }
+
+ case class Binary(name: String, left: TestNode, right: TestNode) extends BinaryLike {
+ override def selfCost(): Long = 1
+ override def withNewChildren(left: TestNode, right: TestNode): BinaryLike =
+ copy(left = left, right = right)
+ }
+
+ case class DummyPathKey(value: Int) extends PathKey
+}
diff --git a/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/path/PathFinderSuite.scala b/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/path/PathFinderSuite.scala
new file mode 100644
index 000000000000..a25cc2dda603
--- /dev/null
+++ b/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/path/PathFinderSuite.scala
@@ -0,0 +1,306 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.glutenproject.cbo.path
+
+import io.glutenproject.cbo.{CanonicalNode, Cbo, CboSuiteBase}
+import io.glutenproject.cbo.mock.MockMemoState
+import io.glutenproject.cbo.rule.CboRule
+
+import org.scalatest.funsuite.AnyFunSuite
+
+class PathFinderSuite extends AnyFunSuite {
+ import PathFinderSuite._
+
+ test("Base") {
+ val cbo =
+ Cbo[TestNode](
+ CostModelImpl,
+ PlanModelImpl,
+ PropertyModelImpl,
+ ExplainImpl,
+ CboRule.Factory.none())
+
+ val mock = MockMemoState.Builder(cbo)
+ val cluster = mock.newCluster()
+ val groupA = cluster.newGroup()
+ val groupB = cluster.newGroup()
+ val groupC = cluster.newGroup()
+ val groupD = cluster.newGroup()
+ val groupE = cluster.newGroup()
+ val n1 = "n1"
+ val n2 = "n2"
+ val n3 = "n3"
+ val n4 = "n4"
+ val n5 = "n5"
+ val n6 = "n6"
+ val node1 = Binary(n1, groupB.self, groupC.self).asCanonical(cbo)
+ val node2 = Unary(n2, groupD.self).asCanonical(cbo)
+ val node3 = Unary(n3, groupE.self).asCanonical(cbo)
+ val node4 = Leaf(n4, 1).asCanonical(cbo)
+ val node5 = Leaf(n5, 1).asCanonical(cbo)
+ val node6 = Leaf(n6, 1).asCanonical(cbo)
+
+ groupA.add(node1)
+ groupB.add(node2)
+ groupC.add(node3)
+ groupD.add(node4)
+ groupE.add(List(node5, node6))
+
+ val state = mock.build()
+
+ def find(can: CanonicalNode[TestNode], depth: Int): Iterable[CboPath[TestNode]] = {
+ val finder = PathFinder.builder(cbo, state).depth(depth).build()
+ finder.find(can)
+ }
+
+ val height1 = find(node1, 1).map(_.plan()).toSeq
+ val height2 = find(node1, 2).map(_.plan()).toSeq
+ val heightInf = find(node1, CboPath.INF_DEPTH).map(_.plan()).toSeq
+
+ assert(height1 == List(Binary(n1, Group(1), Group(2))))
+ assert(height2 == List(Binary(n1, Unary(n2, Group(3)), Unary(n3, Group(4)))))
+ assert(
+ heightInf == List(
+ Binary(n1, Unary(n2, Leaf(n4, 1)), Unary(n3, Leaf(n5, 1))),
+ Binary(n1, Unary(n2, Leaf(n4, 1)), Unary(n3, Leaf(n6, 1)))))
+ }
+
+ test("Find - multiple depths") {
+ val cbo =
+ Cbo[TestNode](
+ CostModelImpl,
+ PlanModelImpl,
+ PropertyModelImpl,
+ ExplainImpl,
+ CboRule.Factory.none())
+
+ val mock = MockMemoState.Builder(cbo)
+ val cluster = mock.newCluster()
+ val groupA = cluster.newGroup()
+ val groupB = cluster.newGroup()
+ val groupC = cluster.newGroup()
+ val groupD = cluster.newGroup()
+ val groupE = cluster.newGroup()
+
+ val n1 = "n1"
+ val n2 = "n2"
+ val n3 = "n3"
+ val n4 = "n4"
+ val n5 = "n5"
+ val n6 = "n6"
+ val node1 = Binary(n1, groupB.self, groupC.self).asCanonical(cbo)
+ val node2 = Unary(n2, groupD.self).asCanonical(cbo)
+ val node3 = Unary(n3, groupE.self).asCanonical(cbo)
+ val node4 = Leaf(n4, 1).asCanonical(cbo)
+ val node5 = Leaf(n5, 1).asCanonical(cbo)
+ val node6 = Leaf(n6, 1).asCanonical(cbo)
+
+ groupA.add(node1)
+ groupB.add(node2)
+ groupC.add(node3)
+ groupD.add(node4)
+ groupE.add(List(node5, node6))
+
+ val state = mock.build()
+
+ val finder1 = PathFinder
+ .builder(cbo, state)
+ .depth(1)
+ .depth(3)
+ .build()
+
+ assert(
+ finder1.find(node1).map(_.plan()).toSeq == List(
+ Binary(n1, Group(1), Group(2)),
+ Binary(n1, Unary(n2, Leaf(n4, 1)), Unary(n3, Leaf(n5, 1))),
+ Binary(n1, Unary(n2, Leaf(n4, 1)), Unary(n3, Leaf(n6, 1)))))
+
+ val finder2 = PathFinder
+ .builder(cbo, state)
+ .depth(2)
+ .depth(CboPath.INF_DEPTH)
+ .build()
+
+ assert(
+ finder2.find(node1).map(_.plan()).toSeq == List(
+ Binary(n1, Unary(n2, Group(3)), Unary(n3, Group(4))),
+ Binary(n1, Unary(n2, Leaf(n4, 1)), Unary(n3, Leaf(n5, 1))),
+ Binary(n1, Unary(n2, Leaf(n4, 1)), Unary(n3, Leaf(n6, 1)))
+ ))
+
+ val finder3 = PathFinder
+ .builder(cbo, state)
+ .depth(2)
+ .depth(2)
+ .depth(CboPath.INF_DEPTH)
+ .depth(CboPath.INF_DEPTH)
+ .build()
+
+ assert(
+ finder3.find(node1).map(_.plan()).toSeq == List(
+ Binary(n1, Unary(n2, Group(3)), Unary(n3, Group(4))),
+ Binary(n1, Unary(n2, Leaf(n4, 1)), Unary(n3, Leaf(n5, 1))),
+ Binary(n1, Unary(n2, Leaf(n4, 1)), Unary(n3, Leaf(n6, 1)))
+ ))
+ }
+
+ test("Dive - basic") {
+ val cbo =
+ Cbo[TestNode](
+ CostModelImpl,
+ PlanModelImpl,
+ PropertyModelImpl,
+ ExplainImpl,
+ CboRule.Factory.none())
+
+ val mock = MockMemoState.Builder(cbo)
+ val cluster = mock.newCluster()
+ val groupA = cluster.newGroup()
+ val groupB = cluster.newGroup()
+ val groupC = cluster.newGroup()
+ val groupD = cluster.newGroup()
+ val groupE = cluster.newGroup()
+
+ val n1 = "n1"
+ val n2 = "n2"
+ val n3 = "n3"
+ val n4 = "n4"
+ val n5 = "n5"
+ val n6 = "n6"
+ val node1 = Binary(n1, groupB.self, groupC.self).asCanonical(cbo)
+ val node2 = Unary(n2, groupD.self).asCanonical(cbo)
+ val node3 = Unary(n3, groupE.self).asCanonical(cbo)
+ val node4 = Leaf(n4, 1).asCanonical(cbo)
+ val node5 = Leaf(n5, 1).asCanonical(cbo)
+ val node6 = Leaf(n6, 1).asCanonical(cbo)
+
+ groupA.add(node1)
+ groupB.add(node2)
+ groupC.add(node3)
+ groupD.add(node4)
+ groupE.add(List(node5, node6))
+
+ val state = mock.build()
+
+ val path = CboPath.one(cbo, PathKeySet.trivial, state.allGroups, node1)
+
+ assert(path.plan() == Binary(n1, Group(1), Group(2)))
+ assert(
+ path.dive(state, 1).map(_.plan()) == List(
+ Binary(n1, Unary(n2, Group(3)), Unary(n3, Group(4)))))
+ assert(
+ path.dive(state, 2).map(_.plan()) == List(
+ Binary(n1, Unary(n2, Leaf(n4, 1)), Unary(n3, Leaf(n5, 1))),
+ Binary(n1, Unary(n2, Leaf(n4, 1)), Unary(n3, Leaf(n6, 1)))))
+ assert(
+ path.dive(state, 3).map(_.plan()) == List(
+ Binary(n1, Unary(n2, Leaf(n4, 1)), Unary(n3, Leaf(n5, 1))),
+ Binary(n1, Unary(n2, Leaf(n4, 1)), Unary(n3, Leaf(n6, 1)))))
+ assert(
+ path.dive(state, CboPath.INF_DEPTH).map(_.plan()) == List(
+ Binary(n1, Unary(n2, Leaf(n4, 1)), Unary(n3, Leaf(n5, 1))),
+ Binary(n1, Unary(n2, Leaf(n4, 1)), Unary(n3, Leaf(n6, 1)))))
+ }
+
+ test("Find/Dive - binary with different children heights") {
+ val cbo =
+ Cbo[TestNode](
+ CostModelImpl,
+ PlanModelImpl,
+ PropertyModelImpl,
+ ExplainImpl,
+ CboRule.Factory.none())
+
+ val mock = MockMemoState.Builder(cbo)
+ val cluster = mock.newCluster()
+ val groupA = cluster.newGroup()
+ val groupB = cluster.newGroup()
+ val groupC = cluster.newGroup()
+ val groupD = cluster.newGroup()
+ val groupE = cluster.newGroup()
+
+ val n1 = "n1"
+ val n2 = "n2"
+ val n3 = "n3"
+ val n4 = "n4"
+ val n5 = "n5"
+ val node1 = Binary(n1, groupB.self, groupC.self).asCanonical(cbo)
+ val node2 = Binary(n2, groupD.self, groupE.self).asCanonical(cbo)
+ val node3 = Leaf(n3, 1).asCanonical(cbo)
+ val node4 = Leaf(n4, 1).asCanonical(cbo)
+ val node5 = Leaf(n5, 1).asCanonical(cbo)
+
+ groupA.add(node1)
+ groupB.add(node2)
+ groupC.add(node3)
+ groupD.add(node4)
+ groupE.add(node5)
+
+ val state = mock.build()
+
+ def find(can: CanonicalNode[TestNode], depth: Int): Iterable[CboPath[TestNode]] = {
+ PathFinder.builder(cbo, state).depth(depth).build().find(can)
+ }
+
+ val height1 = find(node1, 1).map(_.plan()).toSeq
+ val height2 = find(node1, 2).map(_.plan()).toSeq
+ val height3 = find(node1, 3).map(_.plan()).toSeq
+ val height4 = find(node1, 4).map(_.plan()).toSeq
+ val heightInf = find(node1, CboPath.INF_DEPTH).map(_.plan()).toSeq
+
+ assert(height1 == List(Binary(n1, Group(1), Group(2))))
+ assert(height2 == List(Binary(n1, Binary(n2, Group(3), Group(4)), Leaf(n3, 1))))
+ assert(height3 == List(Binary(n1, Binary(n2, Leaf(n4, 1), Leaf(n5, 1)), Leaf(n3, 1))))
+ assert(height4 == List(Binary(n1, Binary(n2, Leaf(n4, 1), Leaf(n5, 1)), Leaf(n3, 1))))
+ assert(heightInf == List(Binary(n1, Binary(n2, Leaf(n4, 1), Leaf(n5, 1)), Leaf(n3, 1))))
+
+ val path = CboPath.one(cbo, PathKeySet.trivial, state.allGroups, node1)
+
+ assert(path.plan() == Binary(n1, Group(1), Group(2)))
+ assert(
+ path.dive(state, 1).map(_.plan()).toSeq == List(
+ Binary(n1, Binary(n2, Group(3), Group(4)), Leaf(n3, 1))))
+ assert(
+ path.dive(state, 2).map(_.plan()) == List(
+ Binary(n1, Binary(n2, Leaf(n4, 1), Leaf(n5, 1)), Leaf(n3, 1))))
+ assert(
+ path.dive(state, 3).map(_.plan()) == List(
+ Binary(n1, Binary(n2, Leaf(n4, 1), Leaf(n5, 1)), Leaf(n3, 1))))
+ assert(
+ path.dive(state, CboPath.INF_DEPTH).map(_.plan()) == List(
+ Binary(n1, Binary(n2, Leaf(n4, 1), Leaf(n5, 1)), Leaf(n3, 1))))
+ }
+}
+
+object PathFinderSuite extends CboSuiteBase {
+ case class Leaf(name: String, override val selfCost: Long) extends LeafLike {
+ override def makeCopy(): LeafLike = this
+ }
+
+ case class Unary(name: String, child: TestNode) extends UnaryLike {
+ override def selfCost(): Long = 1
+
+ override def withNewChildren(child: TestNode): UnaryLike = copy(child = child)
+ }
+
+ case class Binary(name: String, left: TestNode, right: TestNode) extends BinaryLike {
+ override def selfCost(): Long = 1
+
+ override def withNewChildren(left: TestNode, right: TestNode): BinaryLike =
+ copy(left = left, right = right)
+ }
+}
diff --git a/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/path/PathMaskSuite.scala b/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/path/PathMaskSuite.scala
new file mode 100644
index 000000000000..ef784e928ad9
--- /dev/null
+++ b/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/path/PathMaskSuite.scala
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.glutenproject.cbo.path
+
+import io.glutenproject.cbo.CboSuiteBase
+
+import org.scalatest.funsuite.AnyFunSuite
+
+class PathMaskSuite extends AnyFunSuite {
+
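+  // A mask appears to be a pre-order list of per-node child counts, with -1
+  // marking a folded "any subtree" slot (an encoding inferred from the cases below).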
+ test("Mask operation - fold 1") {
+ val in = PathMask(List(3, -1, 2, 0, 0, 0))
+ assert(in.fold(0) == PathMask(List(-1)))
+ assert(in.fold(1) == PathMask(List(3, -1, -1, -1)))
+ assert(in.fold(2) == PathMask(List(3, -1, 2, -1, -1, 0)))
+ assert(in.fold(3) == PathMask(List(3, -1, 2, 0, 0, 0)))
+ assert(in.fold(CboPath.INF_DEPTH) == PathMask(List(3, -1, 2, 0, 0, 0)))
+ }
+
+ test("Mask operation - fold 2") {
+ val in = PathMask(List(3, 2, 0, -1, 0, 1, 0))
+ assert(in.fold(0) == PathMask(List(-1)))
+ assert(in.fold(1) == PathMask(List(3, -1, -1, -1)))
+ assert(in.fold(2) == PathMask(List(3, 2, -1, -1, 0, 1, -1)))
+ assert(in.fold(3) == PathMask(List(3, 2, 0, -1, 0, 1, 0)))
+ assert(in.fold(CboPath.INF_DEPTH) == PathMask(List(3, 2, 0, -1, 0, 1, 0)))
+ }
+
+ test("Mask operation - sub-mask") {
+ val in = PathMask(List(3, 2, 0, -1, 0, 1, 0))
+ assert(in.subMaskAt(0) == PathMask(List(3, 2, 0, -1, 0, 1, 0)))
+ assert(in.subMaskAt(1) == PathMask(List(2, 0, -1)))
+ assert(in.subMaskAt(2) == PathMask(List(0)))
+ assert(in.subMaskAt(3) == PathMask(List(-1)))
+ assert(in.subMaskAt(4) == PathMask(List(0)))
+ assert(in.subMaskAt(5) == PathMask(List(1, 0)))
+ assert(in.subMaskAt(6) == PathMask(List(0)))
+ }
+
+ test("Mask operation - union 1") {
+ val in1 = PathMask(List(3, 2, 0, -1, 0, 1, 0))
+ val in2 = PathMask(List(-1))
+ assert(PathMask.union(List(in1, in2)) == PathMask(List(-1)))
+ }
+
+ test("Mask operation - union 2") {
+ val in1 = PathMask(List(3, 2, 0, -1, 0, 1, 0))
+ val in2 = PathMask(List(3, -1, 0, -1))
+ assert(PathMask.union(List(in1, in2)) == PathMask(List(3, -1, 0, -1)))
+ }
+
+ test("Mask operation - union 3") {
+ val in1 = PathMask(List(3, 2, 0, -1, 0, 1, 0))
+ val in2 = PathMask(List(3, -1, 0, 1, 0))
+ assert(PathMask.union(List(in1, in2)) == PathMask(List(3, -1, 0, 1, 0)))
+ }
+
+ test("Mask operation - union 4") {
+ val in1 = PathMask(List(3, 2, 0, -1, 0, 1, 0))
+ val in2 = PathMask(List(3, 2, 0, 0, 0, 1, 0))
+ assert(PathMask.union(List(in1, in2)) == PathMask(List(3, 2, 0, -1, 0, 1, 0)))
+ }
+
+ test("Mask operation - union 5") {
+ val in1 = PathMask(List(3, 2, 0, 0, -1, 1, 0))
+ val in2 = PathMask(List(3, -1, 2, -1, 0, 1, 0))
+ assert(PathMask.union(List(in1, in2)) == PathMask(List(3, -1, -1, 1, 0)))
+ }
+
+ test("Mask operation - satisfaction 1") {
+ val m1 = PathMask(List(1, 0))
+ val m2 = PathMask(List(-1))
+ assert(m1.satisfies(m2))
+ assert(!m2.satisfies(m1))
+ }
+
+ test("Mask operation - satisfaction 2") {
+ val m1 = PathMask(List(3, 2, 0, -1, 0, 1, 0))
+ val m2 = PathMask(List(3, -1, -1, -1))
+ assert(m1.satisfies(m2))
+ assert(!m2.satisfies(m1))
+ }
+
+ test("Mask operation - satisfaction 3") {
+ val m1 = PathMask(List(3, 0, 0, 0))
+ val m2 = PathMask(List(3, 0, 0, 0))
+ assert(m1.satisfies(m2))
+ assert(m2.satisfies(m1))
+ }
+
+ test("Mask operation - satisfaction 4") {
+ val m1 = PathMask(List(1, -1))
+ val m2 = PathMask(List(0))
+ assert(!m1.satisfies(m2))
+ assert(!m2.satisfies(m1))
+ }
+}
+
+object PathMaskSuite extends CboSuiteBase {}
diff --git a/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/path/WizardSuite.scala b/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/path/WizardSuite.scala
new file mode 100644
index 000000000000..3bcb6d55eb9c
--- /dev/null
+++ b/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/path/WizardSuite.scala
@@ -0,0 +1,304 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.glutenproject.cbo.path
+
+import io.glutenproject.cbo.{Cbo, CboSuiteBase}
+import io.glutenproject.cbo.mock.MockMemoState
+import io.glutenproject.cbo.rule.CboRule
+
+import org.scalatest.funsuite.AnyFunSuite
+
+class WizardSuite extends AnyFunSuite {
+ import WizardSuite._
+
+ test("None") {
+ val cbo =
+ Cbo[TestNode](
+ CostModelImpl,
+ PlanModelImpl,
+ PropertyModelImpl,
+ ExplainImpl,
+ CboRule.Factory.none())
+
+ val mock = MockMemoState.Builder(cbo)
+ val cluster = mock.newCluster()
+ val groupA = cluster.newGroup()
+
+ val n1 = "n1"
+ val node1 = Leaf(n1, 1).asCanonical(cbo)
+
+ groupA.add(node1)
+
+ val state = mock.build()
+
+ val finder = PathFinder.builder(cbo, state).output(OutputWizards.none()).build()
+ assert(finder.find(node1).map(_.plan()).toSeq == List.empty)
+ }
+
+ test("Prune by maximum depth") {
+ val cbo =
+ Cbo[TestNode](
+ CostModelImpl,
+ PlanModelImpl,
+ PropertyModelImpl,
+ ExplainImpl,
+ CboRule.Factory.none())
+
+ val mock = MockMemoState.Builder(cbo)
+ val cluster = mock.newCluster()
+ val groupA = cluster.newGroup()
+ val groupB = cluster.newGroup()
+ val groupC = cluster.newGroup()
+ val groupD = cluster.newGroup()
+ val groupE = cluster.newGroup()
+
+ val n1 = "n1"
+ val n2 = "n2"
+ val n3 = "n3"
+ val n4 = "n4"
+ val n5 = "n5"
+ val n6 = "n6"
+ val node1 = Binary(n1, groupB.self, groupC.self).asCanonical(cbo)
+ val node2 = Unary(n2, groupD.self).asCanonical(cbo)
+ val node3 = Unary(n3, groupE.self).asCanonical(cbo)
+ val node4 = Leaf(n4, 1).asCanonical(cbo)
+ val node5 = Leaf(n5, 1).asCanonical(cbo)
+ val node6 = Leaf(n6, 1).asCanonical(cbo)
+
+ groupA.add(node1)
+ groupB.add(node2)
+ groupC.add(node3)
+ groupD.add(node4)
+ groupE.add(List(node5, node6))
+
+ val state = mock.build()
+
+ def findWithMaxDepths(maxDepths: Seq[Int]): Seq[TestNode] = {
+ val builder = PathFinder.builder(cbo, state)
+ val finder = maxDepths
+ .foldLeft(builder) {
+ case (builder, d) =>
+ builder.depth(d)
+ }
+ .build()
+ finder.find(node1).map(_.plan()).toSeq
+ }
+
+ assert(findWithMaxDepths(List(1)) == List(Binary(n1, Group(1), Group(2))))
+ assert(findWithMaxDepths(List(2)) == List(Binary(n1, Unary(n2, Group(3)), Unary(n3, Group(4)))))
+ assert(
+ findWithMaxDepths(List(3)) == List(
+ Binary(n1, Unary(n2, Leaf(n4, 1)), Unary(n3, Leaf(n5, 1))),
+ Binary(n1, Unary(n2, Leaf(n4, 1)), Unary(n3, Leaf(n6, 1)))))
+ assert(
+ findWithMaxDepths(List(4)) == List(
+ Binary(n1, Unary(n2, Leaf(n4, 1)), Unary(n3, Leaf(n5, 1))),
+ Binary(n1, Unary(n2, Leaf(n4, 1)), Unary(n3, Leaf(n6, 1)))))
+ assert(
+ findWithMaxDepths(List(CboPath.INF_DEPTH)) == List(
+ Binary(n1, Unary(n2, Leaf(n4, 1)), Unary(n3, Leaf(n5, 1))),
+ Binary(n1, Unary(n2, Leaf(n4, 1)), Unary(n3, Leaf(n6, 1)))))
+
+ assert(
+ findWithMaxDepths(List(1, 2)) == List(
+ Binary(n1, Group(1), Group(2)),
+ Binary(n1, Unary(n2, Group(3)), Unary(n3, Group(4)))))
+ assert(
+ findWithMaxDepths(List(2, CboPath.INF_DEPTH)) == List(
+ Binary(n1, Unary(n2, Group(3)), Unary(n3, Group(4))),
+ Binary(n1, Unary(n2, Leaf(n4, 1)), Unary(n3, Leaf(n5, 1))),
+ Binary(n1, Unary(n2, Leaf(n4, 1)), Unary(n3, Leaf(n6, 1)))
+ ))
+ }
+
+ test("Prune by pattern") {
+ val cbo =
+ Cbo[TestNode](
+ CostModelImpl,
+ PlanModelImpl,
+ PropertyModelImpl,
+ ExplainImpl,
+ CboRule.Factory.none())
+
+ val mock = MockMemoState.Builder(cbo)
+ val cluster = mock.newCluster()
+ val groupA = cluster.newGroup()
+ val groupB = cluster.newGroup()
+ val groupC = cluster.newGroup()
+ val groupD = cluster.newGroup()
+ val groupE = cluster.newGroup()
+
+ val n1 = "n1"
+ val n2 = "n2"
+ val n3 = "n3"
+ val n4 = "n4"
+ val n5 = "n5"
+ val n6 = "n6"
+ val node1 = Binary(n1, groupB.self, groupC.self).asCanonical(cbo)
+ val node2 = Unary(n2, groupD.self).asCanonical(cbo)
+ val node3 = Unary(n3, groupE.self).asCanonical(cbo)
+ val node4 = Leaf(n4, 1).asCanonical(cbo)
+ val node5 = Leaf(n5, 1).asCanonical(cbo)
+ val node6 = Leaf(n6, 1).asCanonical(cbo)
+
+ groupA.add(node1)
+ groupB.add(node2)
+ groupC.add(node3)
+ groupD.add(node4)
+ groupE.add(List(node5, node6))
+
+ val state = mock.build()
+
+ def findWithPatterns(patterns: Seq[Pattern[TestNode]]): Seq[TestNode] = {
+ val builder = PathFinder.builder(cbo, state)
+ val finder = patterns
+ .foldLeft(builder) {
+ case (builder, pattern) =>
+ builder.output(OutputWizards.withPattern(pattern))
+ }
+ .build()
+ finder.find(node1).map(_.plan()).toSeq
+ }
+
+ assert(
+ findWithPatterns(List(Pattern.any[TestNode].build())) == List(Binary(n1, Group(1), Group(2))))
+ assert(
+ findWithPatterns(
+ List(
+ Pattern
+ .node[TestNode](
+ _ => true,
+ Pattern.node(_ => true, Pattern.ignore),
+ Pattern.node(_ => true, Pattern.ignore))
+ .build())) == List(Binary(n1, Unary(n2, Group(3)), Unary(n3, Group(4)))))
+
+    // With multiple patterns, matches from every pattern should be emitted
+ val pattern1 = Pattern
+ .node[TestNode](_ => true, Pattern.node(_ => true, Pattern.ignore), Pattern.ignore)
+ .build()
+ val pattern2 = Pattern
+ .node[TestNode](_ => true, Pattern.ignore, Pattern.node(_ => true, Pattern.ignore))
+ .build()
+
+ assert(
+ findWithPatterns(List(pattern1, pattern2)) == List(
+ Binary(n1, Group(1), Unary(n3, Group(4))),
+ Binary(n1, Unary(n2, Group(3)), Group(2))))
+
+ // Distinguish between ignore and any
+ val pattern3 = Pattern
+ .node[TestNode](_ => true, Pattern.node(_ => true, Pattern.any), Pattern.ignore)
+ .build()
+ val pattern4 = Pattern
+ .node[TestNode](_ => true, Pattern.ignore, Pattern.node(_ => true, Pattern.any))
+ .build()
+
+ assert(
+ findWithPatterns(List(pattern3, pattern4)) == List(
+ Binary(n1, Group(1), Unary(n3, Leaf(n5, 1))),
+ Binary(n1, Group(1), Unary(n3, Leaf(n6, 1))),
+ Binary(n1, Unary(n2, Leaf(n4, 1)), Group(2))))
+ }
+
+ test("Prune by mask") {
+ val cbo =
+ Cbo[TestNode](
+ CostModelImpl,
+ PlanModelImpl,
+ PropertyModelImpl,
+ ExplainImpl,
+ CboRule.Factory.none())
+
+ val mock = MockMemoState.Builder(cbo)
+ val cluster = mock.newCluster()
+ val groupA = cluster.newGroup()
+ val groupB = cluster.newGroup()
+ val groupC = cluster.newGroup()
+ val groupD = cluster.newGroup()
+ val groupE = cluster.newGroup()
+
+ val n1 = "n1"
+ val n2 = "n2"
+ val n3 = "n3"
+ val n4 = "n4"
+ val n5 = "n5"
+ val n6 = "n6"
+ val node1 = Binary(n1, groupB.self, groupC.self).asCanonical(cbo)
+ val node2 = Unary(n2, groupD.self).asCanonical(cbo)
+ val node3 = Unary(n3, groupE.self).asCanonical(cbo)
+ val node4 = Leaf(n4, 1).asCanonical(cbo)
+ val node5 = Leaf(n5, 1).asCanonical(cbo)
+ val node6 = Leaf(n6, 1).asCanonical(cbo)
+
+ groupA.add(node1)
+ groupB.add(node2)
+ groupC.add(node3)
+ groupD.add(node4)
+ groupE.add(List(node5, node6))
+
+ val state = mock.build()
+
+ def findWithMask(mask: Seq[Int]): Seq[TestNode] = {
+ PathFinder
+ .builder(cbo, state)
+ .output(OutputWizards.withMask(PathMask(mask)))
+ .build()
+ .find(node1)
+ .map(_.plan())
+ .toSeq
+ }
+
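+    // Masks follow the pre-order child-count encoding exercised in PathMaskSuite;
+    // a -1 entry leaves the corresponding group unexpanded in the emitted plan.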
+ assert(findWithMask(List(2, -1, -1)) == List(Binary(n1, Group(1), Group(2))))
+ assert(findWithMask(List(2, 1, -1, -1)) == List(Binary(n1, Unary(n2, Group(3)), Group(2))))
+ assert(findWithMask(List(2, -1, 1, -1)) == List(Binary(n1, Group(1), Unary(n3, Group(4)))))
+ assert(
+ findWithMask(List(2, -1, 1, 0)) == List(
+ Binary(n1, Group(1), Unary(n3, Leaf(n5, 1))),
+ Binary(n1, Group(1), Unary(n3, Leaf(n6, 1)))))
+ assert(findWithMask(List(2, 1, 0, -1)) == List(Binary(n1, Unary(n2, Leaf(n4, 1)), Group(2))))
+ assert(
+ findWithMask(List(2, 1, -1, 1, 0)) == List(
+ Binary(n1, Unary(n2, Group(3)), Unary(n3, Leaf(n5, 1))),
+ Binary(n1, Unary(n2, Group(3)), Unary(n3, Leaf(n6, 1)))))
+ assert(
+ findWithMask(List(2, 1, 0, 1, -1)) ==
+ List(Binary(n1, Unary(n2, Leaf(n4, 1)), Unary(n3, Group(4)))))
+ assert(
+ findWithMask(List(2, 1, 0, 1, 0)) == List(
+ Binary(n1, Unary(n2, Leaf(n4, 1)), Unary(n3, Leaf(n5, 1))),
+ Binary(n1, Unary(n2, Leaf(n4, 1)), Unary(n3, Leaf(n6, 1)))))
+ }
+}
+
+object WizardSuite extends CboSuiteBase {
+ case class Leaf(name: String, override val selfCost: Long) extends LeafLike {
+ override def makeCopy(): LeafLike = this
+ }
+
+ case class Unary(name: String, child: TestNode) extends UnaryLike {
+ override def selfCost(): Long = 1
+ override def withNewChildren(child: TestNode): UnaryLike = copy(child = child)
+ }
+
+ case class Binary(name: String, left: TestNode, right: TestNode) extends BinaryLike {
+ override def selfCost(): Long = 1
+ override def withNewChildren(left: TestNode, right: TestNode): BinaryLike =
+ copy(left = left, right = right)
+ }
+}
diff --git a/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/rule/PatternSuite.scala b/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/rule/PatternSuite.scala
new file mode 100644
index 000000000000..c45eea162809
--- /dev/null
+++ b/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/rule/PatternSuite.scala
@@ -0,0 +1,200 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.glutenproject.cbo.rule
+
+import io.glutenproject.cbo.{rule, Cbo, CboSuiteBase}
+import io.glutenproject.cbo.mock.MockCboPath
+import io.glutenproject.cbo.path.{CboPath, Pattern}
+
+import org.scalatest.funsuite.AnyFunSuite
+
+class PatternSuite extends AnyFunSuite {
+  import PatternSuite._
+
+  test("Match any") {
+ val cbo =
+ Cbo[TestNode](
+ CostModelImpl,
+ PlanModelImpl,
+ PropertyModelImpl,
+ ExplainImpl,
+ CboRule.Factory.none())
+
+    val pattern = Pattern.any[TestNode].build()
+ val path = MockCboPath.mock(cbo, Leaf("n1", 1))
+ assert(path.height() == 1)
+
+ assert(pattern.matches(path, 1))
+ }
+
+ test("Match ignore") {
+ val cbo =
+ Cbo[TestNode](
+ CostModelImpl,
+ PlanModelImpl,
+ PropertyModelImpl,
+ ExplainImpl,
+ CboRule.Factory.none())
+
+ val pattern = Pattern.ignore[TestNode].build()
+ val path = MockCboPath.mock(cbo, Leaf("n1", 1))
+ assert(path.height() == 1)
+
+ assert(pattern.matches(path, 1))
+ }
+
+ test("Match unary") {
+ val cbo =
+ Cbo[TestNode](
+ CostModelImpl,
+ PlanModelImpl,
+ PropertyModelImpl,
+ ExplainImpl,
+ CboRule.Factory.none())
+
+ val path = MockCboPath.mock(cbo, Unary("n1", Leaf("n2", 1)))
+ assert(path.height() == 2)
+
+ val pattern1 = Pattern.node[TestNode](n => n.isInstanceOf[Unary], Pattern.ignore).build()
+ assert(pattern1.matches(path, 1))
+ assert(pattern1.matches(path, 2))
+
+ val pattern2 =
+ Pattern.node[TestNode](n => n.asInstanceOf[Unary].name == "foo", Pattern.ignore).build()
+ assert(!pattern2.matches(path, 1))
+ assert(!pattern2.matches(path, 2))
+ }
+
+ test("Match binary") {
+ val cbo =
+ Cbo[TestNode](
+ CostModelImpl,
+ PlanModelImpl,
+ PropertyModelImpl,
+ ExplainImpl,
+ CboRule.Factory.none())
+
+ val path = MockCboPath.mock(
+ cbo,
+ Binary("n7", Unary("n1", Unary("n2", Leaf("n3", 1))), Unary("n5", Leaf("n6", 1))))
+ assert(path.height() == 4)
+
+ val pattern = Pattern
+ .node[TestNode](
+ n => n.isInstanceOf[Binary],
+ Pattern.node(
+ n => n.isInstanceOf[Unary],
+ Pattern.node(
+ n => n.isInstanceOf[Unary],
+ Pattern.ignore
+ )
+ ),
+ Pattern.ignore)
+ .build()
+ assert(pattern.matches(path, 1))
+ assert(pattern.matches(path, 2))
+ assert(pattern.matches(path, 3))
+ assert(pattern.matches(path, 4))
+ }
+
+ test("Matches above a certain depth") {
+ val cbo =
+ Cbo[TestNode](
+ CostModelImpl,
+ PlanModelImpl,
+ PropertyModelImpl,
+ ExplainImpl,
+ CboRule.Factory.none())
+
+ val path = MockCboPath.mock(
+ cbo,
+ Binary("n7", Unary("n1", Unary("n2", Leaf("n3", 1))), Unary("n5", Leaf("n6", 1))))
+ assert(path.height() == 4)
+
+ val pattern1 = Pattern
+ .node[TestNode](
+ n => n.isInstanceOf[Binary],
+ Pattern.node(
+ n => n.isInstanceOf[Unary],
+ Pattern.node(
+ n => n.isInstanceOf[Unary],
+ Pattern.leaf(
+ _.asInstanceOf[Leaf].name == "foo"
+ )
+ )
+ ),
+ Pattern.ignore
+ )
+ .build()
+
+ assert(pattern1.matches(path, 1))
+ assert(pattern1.matches(path, 2))
+ assert(pattern1.matches(path, 3))
+ assert(!pattern1.matches(path, 4))
+
+ val pattern2 = Pattern
+ .node[TestNode](
+ n => n.isInstanceOf[Binary],
+ Pattern.node(
+ n => n.isInstanceOf[Unary],
+ Pattern.node(
+ n => n.isInstanceOf[Unary],
+ Pattern.node(
+ n => n.isInstanceOf[Unary],
+ Pattern.ignore
+ )
+ )
+ ),
+ Pattern.ignore
+ )
+ .build()
+
+ assert(pattern2.matches(path, 1))
+ assert(pattern2.matches(path, 2))
+ assert(pattern2.matches(path, 3))
+ assert(!pattern2.matches(path, 4))
+ }
+}
+
+object PatternSuite extends CboSuiteBase {
+ case class Leaf(name: String, override val selfCost: Long) extends LeafLike {
+ override def makeCopy(): LeafLike = this
+ }
+
+ case class Unary(name: String, child: TestNode) extends UnaryLike {
+ override def selfCost(): Long = 1
+
+ override def withNewChildren(child: TestNode): UnaryLike = copy(child = child)
+ }
+
+ case class Binary(name: String, left: TestNode, right: TestNode) extends BinaryLike {
+ override def selfCost(): Long = 1
+
+ override def withNewChildren(left: TestNode, right: TestNode): BinaryLike =
+ copy(left = left, right = right)
+ }
+
+ case class DummyGroup() extends LeafLike {
+ override def makeCopy(): rule.PatternSuite.LeafLike = throw new UnsupportedOperationException()
+ override def selfCost(): Long = throw new UnsupportedOperationException()
+ }
+
+ implicit class PatternImplicits[T <: AnyRef](pattern: Pattern[T]) {
+ def matchesAll(paths: Seq[CboPath[T]], depth: Int): Boolean = {
+ paths.forall(pattern.matches(_, depth))
+ }
+ }
+}
diff --git a/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/specific/CyclicSearchSpaceSuite.scala b/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/specific/CyclicSearchSpaceSuite.scala
new file mode 100644
index 000000000000..80faa09db5ac
--- /dev/null
+++ b/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/specific/CyclicSearchSpaceSuite.scala
@@ -0,0 +1,224 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.glutenproject.cbo.specific
+
+import io.glutenproject.cbo._
+import io.glutenproject.cbo.best.BestFinder
+import io.glutenproject.cbo.memo.MemoState
+import io.glutenproject.cbo.mock.MockMemoState
+import io.glutenproject.cbo.path.{CboPath, PathFinder}
+import io.glutenproject.cbo.rule.CboRule
+
+import org.scalatest.funsuite.AnyFunSuite
+
+class GroupBasedBestFinderCyclicSearchSpaceSuite extends CyclicSearchSpaceSuite {
+ override protected def newBestFinder[T <: AnyRef](
+ cbo: Cbo[T],
+ memoState: MemoState[T]): BestFinder[T] = BestFinder(cbo, memoState)
+}
+
+abstract class CyclicSearchSpaceSuite extends AnyFunSuite {
+ import CyclicSearchSpaceSuite._
+
+ protected def newBestFinder[T <: AnyRef](cbo: Cbo[T], memoState: MemoState[T]): BestFinder[T]
+
+ test("Cyclic - find paths, simple self cycle") {
+ val cbo =
+ Cbo[TestNode](
+ CostModelImpl,
+ PlanModelImpl,
+ PropertyModelImpl,
+ ExplainImpl,
+ CboRule.Factory.none())
+
+ val mock = MockMemoState.Builder(cbo)
+ val cluster = mock.newCluster()
+
+ val groupA = cluster.newGroup()
+
+ val node1 = Unary("node1", groupA.self).asCanonical(cbo)
+ val node2 = Leaf("node2", 1).asCanonical(cbo)
+
+ groupA.add(List(node1, node2))
+
+ cluster.addNodes(List(node1, node2))
+
+ val mockState = mock.build()
+
+ def find(can: CanonicalNode[TestNode], depth: Int): Iterable[CboPath[TestNode]] = {
+ PathFinder.builder(cbo, mockState).depth(depth).build().find(can)
+ }
+
+ assert(find(node1, 1).map(p => p.plan()) == List(Unary("node1", Group(0))))
+ assert(find(node1, 2).map(p => p.plan()) == List(Unary("node1", Leaf("node2", 1))))
+ assert(find(node1, 3).map(p => p.plan()) == List(Unary("node1", Leaf("node2", 1))))
+ assert(
+ find(node1, CboPath.INF_DEPTH).map(p => p.plan()) == List(Unary("node1", Leaf("node2", 1))))
+ }
+
+ test("Cyclic - find best, simple self cycle") {
+ val cbo =
+ Cbo[TestNode](
+ CostModelImpl,
+ PlanModelImpl,
+ PropertyModelImpl,
+ ExplainImpl,
+ CboRule.Factory.none())
+
+ val mock = MockMemoState.Builder(cbo)
+ val cluster = mock.newCluster()
+
+ val groupA = cluster.newGroup()
+
+ val node1 = Unary("node1", groupA.self).asCanonical(cbo)
+ val node2 = Leaf("node2", 1).asCanonical(cbo)
+
+ groupA.add(List(node1, node2))
+
+ cluster.addNodes(List(node1, node2))
+
+ val mockState = mock.build()
+ val bestFinder = newBestFinder(cbo, mockState)
+ val best = bestFinder.bestOf(groupA.id).path()
+ assert(best.cboPath.plan() == Leaf("node2", 1))
+ assert(best.cost == LongCost(1))
+ }
+
+ test("Cyclic - find best, case 1") {
+ val cbo =
+ Cbo[TestNode](
+ CostModelImpl,
+ PlanModelImpl,
+ PropertyModelImpl,
+ ExplainImpl,
+ CboRule.Factory.none())
+
+ val mock = MockMemoState.Builder(cbo)
+ val cluster = mock.newCluster()
+
+ val groupA = cluster.newGroup()
+ val groupB = cluster.newGroup()
+ val groupC = cluster.newGroup()
+ val groupD = cluster.newGroup()
+ val groupE = cluster.newGroup()
+ val groupF = cluster.newGroup()
+ val groupG = cluster.newGroup()
+ val groupH = cluster.newGroup()
+
+ val node1 = Binary("node1", groupB.self, groupC.self).asCanonical(cbo)
+ val node2 = Unary("node2", groupF.self).asCanonical(cbo)
+ val node3 = Binary("node3", groupD.self, groupF.self).asCanonical(cbo)
+ val node4 = Binary("node4", groupG.self, groupH.self).asCanonical(cbo)
+ val node5 = Unary("node5", groupC.self).asCanonical(cbo)
+ val node6 = Unary("node6", groupE.self).asCanonical(cbo)
+ val node7 = Leaf("node7", 1).asCanonical(cbo)
+ val node8 = Leaf("node8", 1).asCanonical(cbo)
+ val node9 = Leaf("node9", 1).asCanonical(cbo)
+    // The best path should avoid including this node as much as possible.
+ val node10 = Leaf("node10", 100).asCanonical(cbo)
+
+ groupA.add(node1)
+ groupB.add(node2)
+ groupC.add(List(node3, node4))
+ groupD.add(node9)
+ groupE.add(node5)
+ groupF.add(List(node6, node10))
+ groupG.add(node7)
+ groupH.add(node8)
+
+ cluster.addNodes(List(node1, node2, node3, node4, node5, node6, node7, node8, node9))
+
+ val mockState = mock.build()
+
+ val bestFinder = newBestFinder(cbo, mockState)
+
+ def assertBestOf(group: CboGroup[TestNode])(assertion: Best[TestNode] => Unit): Unit = {
+ val best = bestFinder.bestOf(group.id())
+ assertion(best)
+ }
+
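+    // Costs are self-cost sums along the chosen path: groupC's best is node4
+    // (1 + 1 + 1 = 3), so groupF's best is node6 (1 + 4 = 5) rather than node10 (100).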
+ assertBestOf(groupA)(best => assert(best.path().cost == LongCost(10)))
+ assertBestOf(groupB)(best => assert(best.path().cost == LongCost(6)))
+ assertBestOf(groupC)(best => assert(best.path().cost == LongCost(3)))
+ assertBestOf(groupD)(best => assert(best.path().cost == LongCost(1)))
+ assertBestOf(groupE)(best => assert(best.path().cost == LongCost(4)))
+ assertBestOf(groupF)(best => assert(best.path().cost == LongCost(5)))
+ assertBestOf(groupG)(best => assert(best.path().cost == LongCost(1)))
+ assertBestOf(groupH)(best => assert(best.path().cost == LongCost(1)))
+ }
+
+ test("Cyclic - find best, case 2") {
+ val cbo =
+ Cbo[TestNode](
+ CostModelImpl,
+ PlanModelImpl,
+ PropertyModelImpl,
+ ExplainImpl,
+ CboRule.Factory.none())
+
+ val mock = MockMemoState.Builder(cbo)
+ val cluster = mock.newCluster()
+
+ val groupA = cluster.newGroup()
+ val groupB = cluster.newGroup()
+ val groupC = cluster.newGroup()
+ val groupD = cluster.newGroup()
+
+ val node1 = Unary("node1", groupB.self).asCanonical(cbo)
+ val node2 = Unary("node2", groupC.self).asCanonical(cbo)
+ val node3 = Unary("node3", groupC.self).asCanonical(cbo)
+ val node4 = Unary("node4", groupD.self).asCanonical(cbo)
+ val node5 = Unary("node5", groupB.self).asCanonical(cbo)
+ val node6 = Leaf("node6", 1).asCanonical(cbo)
+
+ groupA.add(node1)
+ groupA.add(node2)
+ groupB.add(node3)
+ groupB.add(node4)
+ groupC.add(node5)
+ groupD.add(node6)
+
+ cluster.addNodes(List(node1, node2, node3, node4, node5, node6))
+
+ val mockState = mock.build()
+
+ val bestFinder = newBestFinder(cbo, mockState)
+ val best = bestFinder.bestOf(groupA.id)
+
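+ // node3 depends on groupC, whose only node (node5) points back at groupB,
+ // forming a group-level cycle; the best finder therefore leaves node3's cost
+ // empty, while node4 prices normally through groupD's leaf.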
+ assert(best.costs()(InGroupNode(groupA.id, node1)).contains(LongCost(3)))
+ assert(best.costs()(InGroupNode(groupA.id, node2)).contains(LongCost(4)))
+ assert(best.costs()(InGroupNode(groupB.id, node3)).isEmpty)
+ assert(best.costs()(InGroupNode(groupB.id, node4)).contains(LongCost(2)))
+ assert(best.costs()(InGroupNode(groupC.id, node5)).contains(LongCost(3)))
+ assert(best.costs()(InGroupNode(groupD.id, node6)).contains(LongCost(1)))
+ }
+}
+
+object CyclicSearchSpaceSuite extends CboSuiteBase {
+ case class Leaf(name: String, override val selfCost: Long) extends LeafLike {
+ override def makeCopy(): LeafLike = this
+ }
+ case class Unary(name: String, child: TestNode) extends UnaryLike {
+ override def selfCost(): Long = 1
+ override def withNewChildren(child: TestNode): UnaryLike = copy(child = child)
+ }
+ case class Binary(name: String, left: TestNode, right: TestNode) extends BinaryLike {
+ override def selfCost(): Long = 1
+ override def withNewChildren(left: TestNode, right: TestNode): BinaryLike =
+ copy(left = left, right = right)
+ }
+}
diff --git a/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/specific/DistributedSuite.scala b/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/specific/DistributedSuite.scala
new file mode 100644
index 000000000000..6910ebcc15d0
--- /dev/null
+++ b/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/specific/DistributedSuite.scala
@@ -0,0 +1,508 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.glutenproject.cbo.specific
+
+import io.glutenproject.cbo._
+import io.glutenproject.cbo.CboConfig.PlannerType
+import io.glutenproject.cbo.property.PropertySet
+import io.glutenproject.cbo.rule.{CboRule, Shape, Shapes}
+
+import org.scalatest.funsuite.AnyFunSuite
+
+class ExhaustivePlannerDistributedSuite extends DistributedSuite {
+ override protected def conf: CboConfig = CboConfig(plannerType = PlannerType.Exhaustive)
+}
+
+class DpPlannerDistributedSuite extends DistributedSuite {
+ override protected def conf: CboConfig = CboConfig(plannerType = PlannerType.Dp)
+}
+
+abstract class DistributedSuite extends AnyFunSuite {
+ import DistributedSuite._
+
+ protected def conf: CboConfig
+
+ test("Project - dry run") {
+ val cbo =
+ Cbo[TestNode](
+ CostModelImpl,
+ PlanModelImpl,
+ DistributedPropertyModel,
+ ExplainImpl,
+ CboRule.Factory.none())
+ .withNewConfig(_ => conf)
+
+ val plan = DProject(DLeaf())
+ val planner = cbo.newPlanner(plan, PropertySet(List(AnyDistribution, AnyOrdering)))
+ val out = planner.plan()
+ assert(out == DProject(DLeaf()))
+ }
+
+ test("Project - required distribution") {
+ val cbo =
+ Cbo[TestNode](
+ CostModelImpl,
+ PlanModelImpl,
+ DistributedPropertyModel,
+ ExplainImpl,
+ CboRule.Factory.none())
+ .withNewConfig(_ => conf)
+
+ val plan = DProject(DLeaf())
+ val planner =
+ cbo.newPlanner(plan, PropertySet(List(HashDistribution(List("a", "b")), AnyOrdering)))
+ val out = planner.plan()
+ assert(out == DProject(DExchange(List("a", "b"), DLeaf())))
+ }
+
+ test("Aggregate - none-distribution constraint") {
+ val cbo =
+ Cbo[TestNode](
+ CostModelImpl,
+ PlanModelImpl,
+ DistributedPropertyModel,
+ ExplainImpl,
+ CboRule.Factory.none())
+ .withNewConfig(_ => conf)
+
+ val plan = DAggregate(List("a", "b"), DLeaf())
+ val planner =
+ cbo.newPlanner(plan, PropertySet(List(HashDistribution(List("b", "c")), AnyOrdering)))
+ val out = planner.plan()
+ assert(
+ out == DExchange(
+ List("b", "c"),
+ DAggregate(List("a", "b"), DExchange(List("a", "b"), DLeaf()))))
+ }
+
+ test("Project - required ordering") {
+ val cbo =
+ Cbo[TestNode](
+ CostModelImpl,
+ PlanModelImpl,
+ DistributedPropertyModel,
+ ExplainImpl,
+ CboRule.Factory.none())
+ .withNewConfig(_ => conf)
+
+ val plan = DProject(DLeaf())
+ val planner =
+ cbo.newPlanner(plan, PropertySet(List(AnyDistribution, SimpleOrdering(List("a", "b")))))
+ val out = planner.plan()
+ assert(out == DProject(DSort(List("a", "b"), DLeaf())))
+ }
+
+ test("Project - required distribution and ordering") {
+ val cbo =
+ Cbo[TestNode](
+ CostModelImpl,
+ PlanModelImpl,
+ DistributedPropertyModel,
+ ExplainImpl,
+ CboRule.Factory.none())
+ .withNewConfig(_ => conf)
+
+ val plan = DProject(DLeaf())
+ val planner =
+ cbo.newPlanner(
+ plan,
+ PropertySet(List(HashDistribution(List("a", "b")), SimpleOrdering(List("b", "c")))))
+ val out = planner.plan()
+ assert(out == DProject(DSort(List("b", "c"), DExchange(List("a", "b"), DLeaf()))))
+ }
+
+ test("Aggregate - avoid re-exchange") {
+ val cbo =
+ Cbo[TestNode](
+ CostModelImpl,
+ PlanModelImpl,
+ DistributedPropertyModel,
+ ExplainImpl,
+ CboRule.Factory.none())
+ .withNewConfig(_ => conf)
+
+ val plan = DAggregate(List("a"), DProject(DAggregate(List("a", "b"), DLeaf())))
+ val planner = cbo.newPlanner(plan, PropertySet(List(AnyDistribution, AnyOrdering)))
+ val out = planner.plan()
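+ // A single exchange hash-partitioned by "a" satisfies both the inner
+ // aggregate's (a, b) requirement (prefix rule) and the outer aggregate's
+ // (a) requirement, so only one exchange is inserted.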
+ assert(
+ out == DAggregate(
+ List("a"),
+ DProject(DAggregate(List("a", "b"), DExchange(List("a"), DLeaf())))))
+ }
+
+ test("Aggregate - avoid re-exchange, required ordering") {
+ val cbo =
+ Cbo[TestNode](
+ CostModelImpl,
+ PlanModelImpl,
+ DistributedPropertyModel,
+ ExplainImpl,
+ CboRule.Factory.none())
+ .withNewConfig(_ => conf)
+
+ val plan = DAggregate(List("a"), DProject(DAggregate(List("a", "b"), DLeaf())))
+ val planner =
+ cbo.newPlanner(plan, PropertySet(List(AnyDistribution, SimpleOrdering(List("b", "c")))))
+ val out = planner.plan()
+ assert(
+ out == DSort(
+ List("b", "c"),
+ DAggregate(List("a"), DProject(DAggregate(List("a", "b"), DExchange(List("a"), DLeaf()))))))
+ }
+
+ ignore("Aggregate - avoid re-exchange, partial") {
+ val cbo =
+ Cbo[TestNode](
+ CostModelImpl,
+ PlanModelImpl,
+ DistributedPropertyModel,
+ ExplainImpl,
+ CboRule.Factory.reuse(List(PartialAggregateRule)))
+ .withNewConfig(_ => conf)
+
+ val plan = DAggregate(List("a"), DProject(DAggregate(List("a", "b"), DLeaf())))
+ val planner = cbo.newPlanner(plan, PropertySet(List(AnyDistribution, AnyOrdering)))
+ val out = planner.plan()
+ // FIXME: The partial aggregate should be pushed down through the exchange;
+ // otherwise we'd have to write a dedicated rule for that.
+ assert(
+ out == DFinalAggregate(
+ List("a"),
+ DPartialAggregate(
+ List("a"),
+ DProject(
+ DFinalAggregate(
+ List("a", "b"),
+ DExchange(List("a"), DPartialAggregate(List("a", "b"), DLeaf())))))))
+ }
+}
+
+object DistributedSuite extends CboSuiteBase {
+
+ object PartialAggregateRule extends CboRule[TestNode] {
+
+ override def shift(node: TestNode): Iterable[TestNode] = node match {
+ case DAggregate(keys, child) => List(DFinalAggregate(keys, DPartialAggregate(keys, child)))
+ case _ => List.empty
+ }
+
+ override def shape(): Shape[TestNode] = Shapes.fixedHeight(1)
+ }
+
+ trait Distribution extends Property[TestNode]
+
+ case class HashDistribution(keys: Seq[String]) extends Distribution {
+ override def satisfies(other: Property[TestNode]): Boolean = other match {
+ case HashDistribution(otherKeys) if keys.size > otherKeys.size => false
+ case HashDistribution(otherKeys) =>
+ // (a) satisfies (a, b)
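+ // e.g. rows hash-partitioned by (a) are also co-located for an (a, b)
+ // grouping, since equal (a, b) implies equal (a).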
+ keys.zipWithIndex.forall {
+ case (key, index) =>
+ key == otherKeys(index)
+ }
+ case AnyDistribution => true
+ case NoneDistribution => false
+ case _ => throw new UnsupportedOperationException()
+ }
+ override def definition(): PropertyDef[TestNode, _ <: Property[TestNode]] = DistributionDef
+ }
+
+ case object AnyDistribution extends Distribution {
+ override def satisfies(other: Property[TestNode]): Boolean = other match {
+ case HashDistribution(_) => false
+ case AnyDistribution => true
+ case NoneDistribution => false
+ case _ => throw new UnsupportedOperationException()
+ }
+ override def definition(): PropertyDef[TestNode, _ <: Property[TestNode]] = DistributionDef
+ }
+
+ case object NoneDistribution extends Distribution {
+ override def satisfies(other: Property[TestNode]): Boolean = other match {
+ case _: Distribution => false
+ case _ => throw new UnsupportedOperationException()
+ }
+ override def definition(): PropertyDef[TestNode, _ <: Property[TestNode]] = DistributionDef
+ }
+
+ private object DistributionDef extends PropertyDef[TestNode, Distribution] {
+ override def getProperty(plan: TestNode): Distribution = plan match {
+ case d: DNode => d.getDistribution()
+ case _ =>
+ throw new UnsupportedOperationException()
+ }
+
+ override def getChildrenConstraints(
+ constraint: Property[TestNode],
+ plan: TestNode): Seq[Distribution] = (constraint, plan) match {
+ case (NoneDistribution, p: DNode) => p.children().map(_ => NoneDistribution)
+ case (d: Distribution, p: DNode) => p.getDistributionConstraints(d)
+ case _ => throw new UnsupportedOperationException()
+ }
+ }
+
+ trait Ordering extends Property[TestNode]
+
+ case class SimpleOrdering(keys: Seq[String]) extends Ordering {
+ override def satisfies(other: Property[TestNode]): Boolean = other match {
+ case SimpleOrdering(otherKeys) if keys.size < otherKeys.size => false
+ case SimpleOrdering(otherKeys) =>
+ // (a, b) satisfies (a)
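+ // e.g. data sorted by (a, b) is already sorted by its prefix (a).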
+ otherKeys.zipWithIndex.forall {
+ case (otherKey, index) =>
+ otherKey == keys(index)
+ }
+ case AnyOrdering => true
+ case NoneOrdering => false
+ case _ => throw new UnsupportedOperationException()
+ }
+ override def definition(): PropertyDef[TestNode, _ <: Property[TestNode]] = OrderingDef
+ }
+
+ case object AnyOrdering extends Ordering {
+ override def satisfies(other: Property[TestNode]): Boolean = other match {
+ case SimpleOrdering(_) => false
+ case AnyOrdering => true
+ case NoneOrdering => false
+ case _ => throw new UnsupportedOperationException()
+ }
+ override def definition(): PropertyDef[TestNode, _ <: Property[TestNode]] = OrderingDef
+ }
+
+ case object NoneOrdering extends Ordering {
+ override def satisfies(other: Property[TestNode]): Boolean = other match {
+ case _: Ordering => false
+ case _ => throw new UnsupportedOperationException()
+ }
+ override def definition(): PropertyDef[TestNode, _ <: Property[TestNode]] = OrderingDef
+
+ }
+
+ // FIXME: Handle non-ordering as well as non-distribution
+ private object OrderingDef extends PropertyDef[TestNode, Ordering] {
+ override def getProperty(plan: TestNode): Ordering = plan match {
+ case d: DNode => d.getOrdering()
+ case _ => throw new UnsupportedOperationException()
+ }
+ override def getChildrenConstraints(
+ constraint: Property[TestNode],
+ plan: TestNode): Seq[Ordering] =
+ (constraint, plan) match {
+ case (NoneOrdering, p: DNode) => p.children().map(_ => NoneOrdering)
+ case (o: Ordering, p: DNode) => p.getOrderingConstraints(o)
+ case _ => throw new UnsupportedOperationException()
+ }
+ }
+
+ private class EnforceDistribution(distribution: Distribution) extends CboRule[TestNode] {
+ override def shift(node: TestNode): Iterable[TestNode] = (node, distribution) match {
+ case (d: DNode, HashDistribution(keys)) => List(DExchange(keys, d))
+ case (d: DNode, AnyDistribution) => List(d)
+ case (d: DNode, NoneDistribution) => List.empty
+ case _ =>
+ throw new UnsupportedOperationException()
+ }
+ override def shape(): Shape[TestNode] = Shapes.fixedHeight(1)
+ }
+
+ private class EnforceOrdering(ordering: Ordering) extends CboRule[TestNode] {
+ override def shift(node: TestNode): Iterable[TestNode] = (node, ordering) match {
+ case (d: DNode, SimpleOrdering(keys)) => List(DSort(keys, d))
+ case (d: DNode, AnyOrdering) => List(d)
+ case (d: DNode, NoneOrdering) => List.empty
+ case _ => throw new UnsupportedOperationException()
+ }
+ override def shape(): Shape[TestNode] = Shapes.fixedHeight(1)
+ }
+
+ private object DistributedPropertyModel extends PropertyModel[TestNode] {
+ override def propertyDefs: Seq[PropertyDef[TestNode, _ <: Property[TestNode]]] =
+ List(DistributionDef, OrderingDef)
+
+ override def newEnforcerRuleFactory(propertyDef: PropertyDef[TestNode, _ <: Property[TestNode]])
+ : EnforcerRuleFactory[TestNode] = new EnforcerRuleFactory[TestNode] {
+ override def newEnforcerRules(constraint: Property[TestNode]): Seq[CboRule[TestNode]] = {
+ constraint match {
+ case distribution: Distribution => List(new EnforceDistribution(distribution))
+ case ordering: Ordering => List(new EnforceOrdering(ordering))
+ case _ => throw new UnsupportedOperationException()
+ }
+ }
+ }
+ }
+
+ trait DNode extends TestNode {
+ def getDistribution(): Distribution
+ def getDistributionConstraints(req: Distribution): Seq[Distribution]
+ def getOrdering(): Ordering
+ def getOrderingConstraints(req: Ordering): Seq[Ordering]
+ def card(): Int
+ }
+
+ case class DLeaf() extends DNode with LeafLike {
+ override def getDistribution(): Distribution = AnyDistribution
+ override def getDistributionConstraints(req: Distribution): Seq[Distribution] = List.empty
+ override def getOrdering(): Ordering = AnyOrdering
+ override def getOrderingConstraints(req: Ordering): Seq[Ordering] = List.empty
+ override def makeCopy(): LeafLike = this
+ override def selfCost(): Long = card()
+ override def card(): Int = 1000
+ }
+
+ case class DAggregate(keys: Seq[String], override val child: TestNode)
+ extends DNode
+ with UnaryLike {
+ override def getDistribution(): Distribution = {
+ val childDistribution = child match {
+ case g: Group => g.propSet.get(DistributionDef)
+ case other => DistributionDef.getProperty(other)
+ }
+ if (childDistribution == NoneDistribution) {
+ return NoneDistribution
+ }
+ if (childDistribution.satisfies(HashDistribution(keys))) {
+ return childDistribution
+ }
+ HashDistribution(keys)
+ }
+
+ override def getDistributionConstraints(req: Distribution): Seq[Distribution] = {
+ if (HashDistribution(keys).satisfies(req)) {
+ return List(HashDistribution(keys))
+ }
+ if (req.satisfies(HashDistribution(keys))) {
+ return List(req)
+ }
+ List(NoneDistribution)
+ }
+ override def getOrdering(): Ordering = AnyOrdering
+ override def getOrderingConstraints(req: Ordering): Seq[Ordering] = List(AnyOrdering)
+ override def withNewChildren(child: TestNode): UnaryLike = copy(child = child)
+ override def selfCost(): Long = 100 * child.asInstanceOf[DNode].card()
+ override def card(): Int = (0.2 * child.asInstanceOf[DNode].card()).toInt
+ }
+
+ case class DPartialAggregate(keys: Seq[String], override val child: TestNode)
+ extends DNode
+ with UnaryLike {
+ override def getDistribution(): Distribution = child match {
+ case g: Group => g.propSet.get(DistributionDef)
+ case other => DistributionDef.getProperty(other)
+ }
+
+ override def getDistributionConstraints(req: Distribution): Seq[Distribution] = List(req)
+ override def getOrdering(): Ordering = AnyOrdering
+ override def getOrderingConstraints(req: Ordering): Seq[Ordering] = List(AnyOrdering)
+ override def withNewChildren(child: TestNode): UnaryLike = copy(child = child)
+ override def selfCost(): Long = 20 * child.asInstanceOf[DNode].card()
+
+ override def card(): Int = (0.2 * child.asInstanceOf[DNode].card()).toInt
+ }
+
+ case class DFinalAggregate(keys: Seq[String], override val child: TestNode)
+ extends DNode
+ with UnaryLike {
+ override def getDistribution(): Distribution = {
+ val childDistribution = child match {
+ case g: Group => g.propSet.get(DistributionDef)
+ case other => DistributionDef.getProperty(other)
+ }
+ if (childDistribution == NoneDistribution) {
+ return NoneDistribution
+ }
+ if (childDistribution.satisfies(HashDistribution(keys))) {
+ return childDistribution
+ }
+ HashDistribution(keys)
+ }
+ override def getDistributionConstraints(req: Distribution): Seq[Distribution] = {
+ if (HashDistribution(keys).satisfies(req)) {
+ return List(HashDistribution(keys))
+ }
+ if (req.satisfies(HashDistribution(keys))) {
+ return List(req)
+ }
+ List(NoneDistribution)
+ }
+ override def getOrdering(): Ordering = AnyOrdering
+ override def getOrderingConstraints(req: Ordering): Seq[Ordering] = List(AnyOrdering)
+ override def withNewChildren(child: TestNode): UnaryLike = copy(child = child)
+ override def selfCost(): Long = 50 * child.asInstanceOf[DNode].card()
+
+ override def card(): Int = (0.2 * child.asInstanceOf[DNode].card()).toInt
+ }
+
+ case class DProject(override val child: TestNode) extends DNode with UnaryLike {
+ override def getDistribution(): Distribution = child match {
+ case g: Group => g.propSet.get(DistributionDef)
+ case other => DistributionDef.getProperty(other)
+ }
+ override def getDistributionConstraints(req: Distribution): Seq[Distribution] = List(req)
+ override def getOrdering(): Ordering = child match {
+ case g: Group => g.propSet.get(OrderingDef)
+ case other => OrderingDef.getProperty(other)
+ }
+ override def getOrderingConstraints(req: Ordering): Seq[Ordering] = List(req)
+ override def withNewChildren(child: TestNode): UnaryLike = copy(child)
+ override def selfCost(): Long = 10 * child.asInstanceOf[DNode].card()
+ override def card(): Int = child.asInstanceOf[DNode].card()
+ }
+
+ case class DExchange(keys: Seq[String], override val child: TestNode)
+ extends DNode
+ with UnaryLike {
+ override def getDistribution(): Distribution = {
+ val childDistribution = child match {
+ case g: Group => g.propSet.get(DistributionDef)
+ case other => DistributionDef.getProperty(other)
+ }
+ if (childDistribution == NoneDistribution) {
+ return NoneDistribution
+ }
+ HashDistribution(keys)
+ }
+ override def getDistributionConstraints(req: Distribution): Seq[Distribution] = List(
+ AnyDistribution)
+ override def getOrdering(): Ordering = AnyOrdering
+ override def getOrderingConstraints(req: Ordering): Seq[Ordering] = Seq(AnyOrdering)
+ override def withNewChildren(child: TestNode): UnaryLike = copy(child = child)
+ override def selfCost(): Long = 50 * child.asInstanceOf[DNode].card()
+ override def card(): Int = child.asInstanceOf[DNode].card()
+ }
+
+ case class DSort(keys: Seq[String], override val child: TestNode) extends DNode with UnaryLike {
+ override def getDistribution(): Distribution = child match {
+ case g: Group => g.propSet.get(DistributionDef)
+ case other => DistributionDef.getProperty(other)
+ }
+ override def getDistributionConstraints(req: Distribution): Seq[Distribution] = List(req)
+ override def getOrdering(): Ordering = {
+ val childOrdering = child match {
+ case g: Group => g.propSet.get(OrderingDef)
+ case other => OrderingDef.getProperty(other)
+ }
+ if (childOrdering.satisfies(SimpleOrdering(keys))) {
+ return childOrdering
+ }
+ SimpleOrdering(keys)
+ }
+ override def getOrderingConstraints(req: Ordering): Seq[Ordering] = Seq(AnyOrdering)
+ override def withNewChildren(child: TestNode): UnaryLike = copy(child = child)
+ override def selfCost(): Long = 40 * child.asInstanceOf[DNode].card()
+ override def card(): Int = child.asInstanceOf[DNode].card()
+ }
+}
diff --git a/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/specific/JoinReorderSuite.scala b/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/specific/JoinReorderSuite.scala
new file mode 100644
index 000000000000..e8284649b72b
--- /dev/null
+++ b/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/specific/JoinReorderSuite.scala
@@ -0,0 +1,420 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.glutenproject.cbo.specific
+
+import io.glutenproject.cbo.{Cbo, CboConfig, CboSuiteBase}
+import io.glutenproject.cbo.CboConfig.PlannerType
+import io.glutenproject.cbo.rule.{CboRule, Shape, Shapes}
+
+import org.scalatest.funsuite.AnyFunSuite
+
+class ExhaustivePlannerJoinReorderSuite extends JoinReorderSuite {
+ override protected def conf: CboConfig = CboConfig(plannerType = PlannerType.Exhaustive)
+}
+
+class DpPlannerJoinReorderSuite extends JoinReorderSuite {
+ override protected def conf: CboConfig = CboConfig(plannerType = PlannerType.Dp)
+}
+
+abstract class JoinReorderSuite extends AnyFunSuite {
+ import JoinReorderSuite._
+
+ protected def conf: CboConfig
+
+ test("3 way join - dry run") {
+ val cbo =
+ Cbo[TestNode](
+ CostModelImpl,
+ PlanModelImpl,
+ PropertyModelImpl,
+ ExplainImpl,
+ CboRule.Factory.none())
+ .withNewConfig(_ => conf)
+ val plan = LeftJoin(LeftJoin(Scan(50), Scan(200)), Scan(100))
+ val planner = cbo.newPlanner(plan)
+ val out = planner.plan()
+ assert(out == plan)
+ }
+
+ test("3 way join - reorder") {
+ val cbo =
+ Cbo[TestNode](
+ CostModelImpl,
+ PlanModelImpl,
+ PropertyModelImpl,
+ ExplainImpl,
+ CboRule.Factory.reuse(List(JoinAssociateRule, JoinCommuteRule)))
+ .withNewConfig(_ => conf)
+ val plan = LeftJoin(LeftJoin(Scan(200), Scan(100)), Scan(30))
+ val planner = cbo.newPlanner(plan)
+ val out = planner.plan()
+ assert(out == LeftJoin(LeftJoin(Scan(30), Scan(200)), Scan(100)))
+ }
+
+ test("5 way join - reorder") {
+ val cbo =
+ Cbo[TestNode](
+ CostModelImpl,
+ PlanModelImpl,
+ PropertyModelImpl,
+ ExplainImpl,
+ CboRule.Factory.reuse(List(JoinAssociateRule, JoinCommuteRule)))
+ .withNewConfig(_ => conf)
+ val plan =
+ LeftJoin(LeftJoin(Scan(2000), Scan(300)), LeftJoin(LeftJoin(Scan(200), Scan(1000)), Scan(50)))
+ val planner = cbo.newPlanner(plan)
+ val out = planner.plan()
+ assert(
+ out == LeftJoin(
+ LeftJoin(LeftJoin(LeftJoin(Scan(50), Scan(2000)), Scan(1000)), Scan(300)),
+ Scan(200)))
+ }
+
+ // too slow
+ ignore("7 way join - reorder") {
+ val cbo =
+ Cbo[TestNode](
+ CostModelImpl,
+ PlanModelImpl,
+ PropertyModelImpl,
+ ExplainImpl,
+ CboRule.Factory.reuse(List(JoinAssociateRule, JoinCommuteRule)))
+ .withNewConfig(_ => conf)
+ val plan = LeftJoin(
+ LeftJoin(
+ LeftJoin(Scan(2000), Scan(300)),
+ LeftJoin(LeftJoin(Scan(200), Scan(1000)), Scan(50))),
+ LeftJoin(Scan(700), Scan(3000)))
+ val planner = cbo.newPlanner(plan)
+ val out = planner.plan()
+ throw new UnsupportedOperationException("Not yet implemented")
+ }
+
+ // too slow
+ ignore("9 way join - reorder") {
+ val cbo =
+ Cbo[TestNode](
+ CostModelImpl,
+ PlanModelImpl,
+ PropertyModelImpl,
+ ExplainImpl,
+ CboRule.Factory.reuse(List(JoinAssociateRule, JoinCommuteRule)))
+ .withNewConfig(_ => conf)
+ val plan = LeftJoin(
+ LeftJoin(
+ LeftJoin(Scan(2000), Scan(300)),
+ LeftJoin(LeftJoin(Scan(200), Scan(1000)), Scan(50))),
+ LeftJoin(LeftJoin(Scan(700), Scan(3000)), LeftJoin(Scan(9000), Scan(1000)))
+ )
+ val planner = cbo.newPlanner(plan)
+ val out = planner.plan()
+ throw new UnsupportedOperationException("Not yet implemented")
+ }
+
+ // too slow
+ ignore("12 way join - reorder") {
+ val cbo =
+ Cbo[TestNode](
+ CostModelImpl,
+ PlanModelImpl,
+ PropertyModelImpl,
+ ExplainImpl,
+ CboRule.Factory.reuse(List(JoinAssociateRule, JoinCommuteRule)))
+ .withNewConfig(_ => conf)
+ val plan = LeftJoin(
+ LeftJoin(
+ LeftJoin(
+ LeftJoin(Scan(2000), Scan(300)),
+ LeftJoin(LeftJoin(Scan(200), Scan(1000)), Scan(50))),
+ LeftJoin(LeftJoin(Scan(700), Scan(3000)), LeftJoin(Scan(9000), Scan(1000)))
+ ),
+ LeftJoin(LeftJoin(Scan(5000), Scan(1200)), Scan(150))
+ )
+ val planner = cbo.newPlanner(plan)
+ val out = planner.plan()
+ throw new UnsupportedOperationException("Not yet implemented")
+ }
+
+ test("2 way join - reorder, left deep only") {
+ val cbo =
+ Cbo[TestNode](
+ CostModelImpl,
+ PlanModelImpl,
+ PropertyModelImpl,
+ ExplainImpl,
+ CboRule.Factory.reuse(leftDeepJoinRules(2)))
+ .withNewConfig(_ => conf)
+ val plan = LeftJoin(Scan(200), Scan(30))
+ val planner = cbo.newPlanner(plan)
+ val out = planner.plan()
+ assert(out == LeftDeepJoin(Scan(30), Scan(200)))
+ }
+
+ test("3 way join - reorder, left deep only") {
+ val cbo =
+ Cbo[TestNode](
+ CostModelImpl,
+ PlanModelImpl,
+ PropertyModelImpl,
+ ExplainImpl,
+ CboRule.Factory.reuse(leftDeepJoinRules(3)))
+ .withNewConfig(_ => conf)
+ val plan = LeftJoin(LeftJoin(Scan(200), Scan(100)), Scan(30))
+ val planner = cbo.newPlanner(plan)
+ val out = planner.plan()
+ assert(out == LeftDeepJoin(LeftDeepJoin(Scan(30), Scan(200)), Scan(100)))
+ }
+
+ test("5 way join - reorder, left deep only") {
+ val cbo =
+ Cbo[TestNode](
+ CostModelImpl,
+ PlanModelImpl,
+ PropertyModelImpl,
+ ExplainImpl,
+ CboRule.Factory.reuse(leftDeepJoinRules(5)))
+ .withNewConfig(_ => conf)
+ val plan =
+ LeftJoin(LeftJoin(Scan(2000), Scan(300)), LeftJoin(LeftJoin(Scan(200), Scan(1000)), Scan(50)))
+ val planner = cbo.newPlanner(plan)
+ val out = planner.plan()
+ assert(
+ out == LeftDeepJoin(
+ LeftDeepJoin(LeftDeepJoin(LeftDeepJoin(Scan(50), Scan(2000)), Scan(1000)), Scan(300)),
+ Scan(200)))
+ }
+
+ test("7 way join - reorder, left deep only") {
+ val cbo =
+ Cbo[TestNode](
+ CostModelImpl,
+ PlanModelImpl,
+ PropertyModelImpl,
+ ExplainImpl,
+ CboRule.Factory.reuse(leftDeepJoinRules(7)))
+ .withNewConfig(_ => conf)
+ val plan = LeftJoin(
+ LeftJoin(
+ LeftJoin(Scan(2000), Scan(300)),
+ LeftJoin(LeftJoin(Scan(200), Scan(1000)), Scan(50))),
+ LeftJoin(Scan(700), Scan(3000)))
+ val planner = cbo.newPlanner(plan)
+ val out = planner.plan()
+ assert(
+ out == LeftDeepJoin(
+ LeftDeepJoin(
+ LeftDeepJoin(
+ LeftDeepJoin(LeftDeepJoin(LeftDeepJoin(Scan(50), Scan(3000)), Scan(2000)), Scan(1000)),
+ Scan(700)),
+ Scan(300)),
+ Scan(200)))
+ }
+
+ test("9 way join - reorder, left deep only") {
+ val cbo =
+ Cbo[TestNode](
+ CostModelImpl,
+ PlanModelImpl,
+ PropertyModelImpl,
+ ExplainImpl,
+ CboRule.Factory.reuse(leftDeepJoinRules(9)))
+ .withNewConfig(_ => conf)
+ val plan = LeftJoin(
+ LeftJoin(
+ LeftJoin(Scan(2000), Scan(300)),
+ LeftJoin(LeftJoin(Scan(200), Scan(1000)), Scan(50))),
+ LeftJoin(LeftJoin(Scan(700), Scan(3000)), LeftJoin(Scan(9000), Scan(1000)))
+ )
+ val planner = cbo.newPlanner(plan)
+ val out = planner.plan()
+ assert(
+ out == LeftDeepJoin(
+ LeftDeepJoin(
+ LeftDeepJoin(
+ LeftDeepJoin(
+ LeftDeepJoin(
+ LeftDeepJoin(
+ LeftDeepJoin(LeftDeepJoin(Scan(50), Scan(9000)), Scan(3000)),
+ Scan(2000)),
+ Scan(1000)),
+ Scan(1000)),
+ Scan(700)),
+ Scan(300)
+ ),
+ Scan(200)
+ ))
+ }
+
+ // too slow
+ ignore("12 way join - reorder, left deep only") {
+ val cbo =
+ Cbo[TestNode](
+ CostModelImpl,
+ PlanModelImpl,
+ PropertyModelImpl,
+ ExplainImpl,
+ CboRule.Factory.reuse(leftDeepJoinRules(12)))
+ .withNewConfig(_ => conf)
+ val plan = LeftJoin(
+ LeftJoin(
+ LeftJoin(
+ LeftJoin(Scan(2000), Scan(300)),
+ LeftJoin(LeftJoin(Scan(200), Scan(1000)), Scan(50))),
+ LeftJoin(LeftJoin(Scan(700), Scan(3000)), LeftJoin(Scan(9000), Scan(1000)))
+ ),
+ LeftJoin(LeftJoin(Scan(5000), Scan(1200)), Scan(150))
+ )
+ val planner = cbo.newPlanner(plan)
+ val out = planner.plan()
+ throw new UnsupportedOperationException("Not yet implemented")
+ }
+}
+
+object JoinReorderSuite extends CboSuiteBase {
+
+ object JoinAssociateRule extends CboRule[TestNode] {
+ override def shift(node: TestNode): Iterable[TestNode] = node match {
+ case LeftJoin(LeftJoin(a, b), c) => List(LeftJoin(a, LeftJoin(b, c)))
+ case LeftJoin(a, LeftJoin(b, c)) => List(LeftJoin(LeftJoin(a, b), c))
+ case other => List.empty
+ }
+
+ override def shape(): Shape[TestNode] = Shapes.fixedHeight(2)
+ }
+
+ object JoinCommuteRule extends CboRule[TestNode] {
+ override def shift(node: TestNode): Iterable[TestNode] = node match {
+ case LeftJoin(a, b) => List(LeftJoin(b, a))
+ case other => List.empty
+ }
+
+ override def shape(): Shape[TestNode] = Shapes.fixedHeight(1)
+ }
+
+ trait KnownCardNode extends TestNode {
+ def card(): Long
+ }
+
+ abstract class LeftJoinBase(left: TestNode, right: TestNode)
+ extends BinaryLike
+ with KnownCardNode {
+ private def leftCard() = left match {
+ case l: KnownCardNode => l.card()
+ case _ => -1L
+ }
+
+ private def rightCard() = right match {
+ case r: KnownCardNode => r.card()
+ case _ => -1L
+ }
+
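+ // The left cardinality is squared, so the cost model strongly favors plans
+ // that keep the smaller relation on the left.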
+ override def selfCost(): Long = leftCard() * leftCard() * rightCard()
+
+ override def card(): Long = {
+ leftCard().min(rightCard()) + 1
+ }
+ }
+
+ case class LeftJoin(left: TestNode, right: TestNode) extends LeftJoinBase(left, right) {
+ override def withNewChildren(left: TestNode, right: TestNode): BinaryLike =
+ copy(left = left, right = right)
+ }
+
+ case class Scan(card: Long) extends LeafLike with KnownCardNode {
+ override def selfCost(): Long = card
+ override def makeCopy(): LeafLike = copy()
+ }
+
+ // Rules and node types for left deep join
+
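+ // Pipeline: collapse LeftJoins into a flat MultiJoin, merge nested
+ // MultiJoins, expand into a left-deep chain once at least `expandThreshold`
+ // inputs are collected, then let the two commute rules permute the chain.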
+ private def leftDeepJoinRules(expandThreshold: Int): List[CboRule[TestNode]] = {
+ List(
+ MultiJoinRule,
+ MultiJoinMergeRule,
+ MultiJoinToLeftDeepJoinRule(expandThreshold),
+ LeftDeepJoinCommuteRule2,
+ LeftDeepJoinCommuteRule3
+ )
+ }
+
+ object MultiJoinRule extends CboRule[TestNode] {
+ override def shift(node: TestNode): Iterable[TestNode] = node match {
+ case LeftJoin(left, right) => List(MultiJoin(List(left, right)))
+ case _ => List.empty
+ }
+
+ override def shape(): Shape[TestNode] = Shapes.fixedHeight(1)
+ }
+
+ object MultiJoinMergeRule extends CboRule[TestNode] {
+
+ override def shift(node: TestNode): Iterable[TestNode] = node match {
+ case MultiJoin(children) =>
+ val newChildren = children.flatMap {
+ case MultiJoin(c) => c
+ case n => List(n)
+ }
+ List(MultiJoin(newChildren))
+ case _ => List.empty
+ }
+
+ override def shape(): Shape[TestNode] = Shapes.fixedHeight(2)
+ }
+
+ case class MultiJoinToLeftDeepJoinRule(expandThreshold: Int) extends CboRule[TestNode] {
+ override def shift(node: TestNode): Iterable[TestNode] = node match {
+ case MultiJoin(children) if children.size >= expandThreshold =>
+ List(children.reduce((a, b) => LeftDeepJoin(a, b)))
+ case _ => List.empty
+ }
+
+ override def shape(): Shape[TestNode] = Shapes.fixedHeight(1)
+ }
+
+ object LeftDeepJoinCommuteRule2 extends CboRule[TestNode] {
+ override def shift(node: TestNode): Iterable[TestNode] = node match {
+ case LeftDeepJoin(s1 @ Scan(_), s2 @ Scan(_)) => List(LeftDeepJoin(s2, s1))
+ case _ => List.empty
+ }
+
+ override def shape(): Shape[TestNode] = Shapes.fixedHeight(2)
+ }
+
+ object LeftDeepJoinCommuteRule3 extends CboRule[TestNode] {
+ override def shift(node: TestNode): Iterable[TestNode] = node match {
+ case LeftDeepJoin(LeftDeepJoin(a, s1 @ Scan(_)), s2 @ Scan(_)) =>
+ List(LeftDeepJoin(LeftDeepJoin(a, s2), s1))
+ case _ => List.empty
+ }
+ override def shape(): Shape[TestNode] = Shapes.fixedHeight(3)
+ }
+
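+ // MultiJoin is a virtual intermediate node; its extreme self cost keeps it
+ // out of any final plan.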
+ case class MultiJoin(override val children: Seq[TestNode]) extends TestNode {
+ override def withNewChildren(newChildren: Seq[TestNode]): MultiJoin = copy(newChildren)
+ override def selfCost(): Long = Long.MaxValue
+ }
+
+ case class LeftDeepJoin(left: TestNode, right: TestNode) extends LeftJoinBase(left, right) {
+ override def withNewChildren(left: TestNode, right: TestNode): BinaryLike =
+ copy(left = left, right = right)
+ }
+}
diff --git a/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/util/IndexDisjointSetSuite.scala b/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/util/IndexDisjointSetSuite.scala
new file mode 100644
index 000000000000..ca14a1815441
--- /dev/null
+++ b/gluten-cbo/common/src/test/scala/io/glutenproject/cbo/util/IndexDisjointSetSuite.scala
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.glutenproject.cbo.util
+
+import org.scalatest.funsuite.AnyFunSuite
+
+class IndexDisjointSetSuite extends AnyFunSuite {
+ test("Size") {
+ val set = IndexDisjointSet[String]()
+ set.grow()
+ set.grow()
+ assert(set.size() == 2)
+ }
+
+ test("Union, Set") {
+ val set = IndexDisjointSet[String]()
+ set.grow()
+ set.grow()
+ set.grow()
+ set.grow()
+ set.grow()
+
+ assert(set.setOf(0) == Set(0))
+ assert(set.setOf(1) == Set(1))
+ assert(set.setOf(2) == Set(2))
+ assert(set.setOf(3) == Set(3))
+ assert(set.setOf(4) == Set(4))
+
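+ // forward(x, y) merges the set rooted at x into the set rooted at y, making
+ // y's root the representative, as the find() calls below verify.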
+ set.forward(set.find(2), set.find(3))
+ assert(set.find(0) == 0)
+ assert(set.find(1) == 1)
+ assert(set.find(2) == 3)
+ assert(set.find(3) == 3)
+ assert(set.find(4) == 4)
+
+ assert(set.setOf(0) == Set(0))
+ assert(set.setOf(1) == Set(1))
+ assert(set.setOf(2) == Set(2, 3))
+ assert(set.setOf(3) == Set(2, 3))
+ assert(set.setOf(4) == Set(4))
+
+ set.forward(set.find(4), set.find(3))
+ assert(set.find(0) == 0)
+ assert(set.find(1) == 1)
+ assert(set.find(2) == 3)
+ assert(set.find(3) == 3)
+ assert(set.find(4) == 3)
+
+ assert(set.setOf(0) == Set(0))
+ assert(set.setOf(1) == Set(1))
+ assert(set.setOf(2) == Set(2, 3, 4))
+ assert(set.setOf(3) == Set(2, 3, 4))
+ assert(set.setOf(4) == Set(2, 3, 4))
+
+ set.forward(set.find(3), set.find(0))
+ assert(set.find(0) == 0)
+ assert(set.find(1) == 1)
+ assert(set.find(2) == 0)
+ assert(set.find(3) == 0)
+ assert(set.find(4) == 0)
+
+ assert(set.setOf(0) == Set(0, 2, 3, 4))
+ assert(set.setOf(1) == Set(1))
+ assert(set.setOf(2) == Set(0, 2, 3, 4))
+ assert(set.setOf(3) == Set(0, 2, 3, 4))
+ assert(set.setOf(4) == Set(0, 2, 3, 4))
+
+ set.forward(set.find(2), set.find(1))
+
+ assert(set.find(0) == 1)
+ assert(set.find(1) == 1)
+ assert(set.find(2) == 1)
+ assert(set.find(3) == 1)
+ assert(set.find(4) == 1)
+
+ assert(set.setOf(0) == Set(0, 1, 2, 3, 4))
+ assert(set.setOf(1) == Set(0, 1, 2, 3, 4))
+ assert(set.setOf(2) == Set(0, 1, 2, 3, 4))
+ assert(set.setOf(3) == Set(0, 1, 2, 3, 4))
+ assert(set.setOf(4) == Set(0, 1, 2, 3, 4))
+ }
+}
diff --git a/gluten-cbo/planner/pom.xml b/gluten-cbo/planner/pom.xml
new file mode 100644
index 000000000000..2054e8537ecf
--- /dev/null
+++ b/gluten-cbo/planner/pom.xml
@@ -0,0 +1,42 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+  <modelVersion>4.0.0</modelVersion>
+  <parent>
+    <groupId>io.glutenproject</groupId>
+    <artifactId>gluten-cbo</artifactId>
+    <version>1.2.0-SNAPSHOT</version>
+  </parent>
+  <artifactId>gluten-cbo-planner</artifactId>
+  <name>Gluten Cbo Planner</name>
+
+  <dependencies>
+    <dependency>
+      <groupId>io.glutenproject</groupId>
+      <artifactId>gluten-cbo-common</artifactId>
+      <version>${project.version}</version>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-core_${scala.binary.version}</artifactId>
+      <scope>provided</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-sql_${scala.binary.version}</artifactId>
+      <scope>provided</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-core_${scala.binary.version}</artifactId>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-sql_${scala.binary.version}</artifactId>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
+  </dependencies>
+</project>
diff --git a/gluten-cbo/planner/src/main/scala/io/glutenproject/planner/.gitkeep b/gluten-cbo/planner/src/main/scala/io/glutenproject/planner/.gitkeep
new file mode 100644
index 000000000000..54d9eab911a0
--- /dev/null
+++ b/gluten-cbo/planner/src/main/scala/io/glutenproject/planner/.gitkeep
@@ -0,0 +1 @@
+This module is kept so that code can be added in the future to test Vanilla Spark with CBO, without Gluten.
diff --git a/gluten-cbo/planner/src/test/scala/io/glutenproject/planner/.gitkeep b/gluten-cbo/planner/src/test/scala/io/glutenproject/planner/.gitkeep
new file mode 100644
index 000000000000..54d9eab911a0
--- /dev/null
+++ b/gluten-cbo/planner/src/test/scala/io/glutenproject/planner/.gitkeep
@@ -0,0 +1 @@
+This module is kept so that code can be added in the future to test Vanilla Spark with CBO, without Gluten.
diff --git a/gluten-cbo/pom.xml b/gluten-cbo/pom.xml
new file mode 100644
index 000000000000..5d32c7b4fb69
--- /dev/null
+++ b/gluten-cbo/pom.xml
@@ -0,0 +1,111 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+  <modelVersion>4.0.0</modelVersion>
+  <parent>
+    <groupId>io.glutenproject</groupId>
+    <artifactId>gluten-parent</artifactId>
+    <version>1.2.0-SNAPSHOT</version>
+  </parent>
+  <artifactId>gluten-cbo</artifactId>
+  <packaging>pom</packaging>
+  <name>Gluten Cbo</name>
+
+  <modules>
+    <module>common</module>
+    <module>planner</module>
+  </modules>
+
+  <dependencies>
+    <dependency>
+      <groupId>org.scalacheck</groupId>
+      <artifactId>scalacheck_${scala.binary.version}</artifactId>
+      <version>1.13.5</version>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.scala-lang</groupId>
+      <artifactId>scala-library</artifactId>
+      <version>${scala.version}</version>
+      <scope>provided</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.scalatest</groupId>
+      <artifactId>scalatest_${scala.binary.version}</artifactId>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.scalatestplus</groupId>
+      <artifactId>scalatestplus-mockito_2.12</artifactId>
+      <version>1.0.0-M2</version>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.scalatestplus</groupId>
+      <artifactId>scalatestplus-scalacheck_2.12</artifactId>
+      <version>3.1.0.0-RC2</version>
+      <scope>test</scope>
+    </dependency>
+  </dependencies>
+
+  <build>
+    <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
+    <testOutputDirectory>target/scala-${scala.binary.version}/test-classes</testOutputDirectory>
+    <plugins>
+      <plugin>
+        <groupId>org.apache.maven.plugins</groupId>
+        <artifactId>maven-resources-plugin</artifactId>
+      </plugin>
+      <plugin>
+        <groupId>net.alchim31.maven</groupId>
+        <artifactId>scala-maven-plugin</artifactId>
+      </plugin>
+      <plugin>
+        <groupId>org.apache.maven.plugins</groupId>
+        <artifactId>maven-compiler-plugin</artifactId>
+      </plugin>
+      <plugin>
+        <groupId>org.scalastyle</groupId>
+        <artifactId>scalastyle-maven-plugin</artifactId>
+      </plugin>
+      <plugin>
+        <groupId>org.apache.maven.plugins</groupId>
+        <artifactId>maven-checkstyle-plugin</artifactId>
+      </plugin>
+      <plugin>
+        <groupId>com.diffplug.spotless</groupId>
+        <artifactId>spotless-maven-plugin</artifactId>
+      </plugin>
+      <plugin>
+        <groupId>org.scalatest</groupId>
+        <artifactId>scalatest-maven-plugin</artifactId>
+        <version>${scalatest-maven-plugin.version}</version>
+      </plugin>
+      <plugin>
+        <groupId>org.apache.maven.plugins</groupId>
+        <artifactId>maven-jar-plugin</artifactId>
+        <version>${maven.jar.plugin}</version>
+        <executions>
+          <execution>
+            <id>prepare-test-jar</id>
+            <phase>test-compile</phase>
+            <goals>
+              <goal>test-jar</goal>
+            </goals>
+          </execution>
+        </executions>
+      </plugin>
+    </plugins>
+  </build>
+</project>
diff --git a/gluten-core/pom.xml b/gluten-core/pom.xml
index a420b50f89a8..e6206e948711 100644
--- a/gluten-core/pom.xml
+++ b/gluten-core/pom.xml
@@ -34,6 +34,12 @@
${project.version}
compile
+
+ io.glutenproject
+ gluten-cbo-common
+ ${project.version}
+ compile
+
org.apache.spark
diff --git a/gluten-core/src/main/scala/io/glutenproject/execution/LimitTransformer.scala b/gluten-core/src/main/scala/io/glutenproject/execution/LimitTransformer.scala
index e9320cc5eab1..028ea854553d 100644
--- a/gluten-core/src/main/scala/io/glutenproject/execution/LimitTransformer.scala
+++ b/gluten-core/src/main/scala/io/glutenproject/execution/LimitTransformer.scala
@@ -60,6 +60,12 @@ case class LimitTransformer(child: SparkPlan, offset: Long, count: Long)
case c: TransformSupport => c.doTransform(context).root
case _ => null
}
+ // If ACBO is enabled and becomes capable of moving limit operators up and down the plan tree,
+ // an independent limit could end up converted into an order-by-limit. In that case we should
+ // revisit the validation procedure: either move limit validation into ACBO, re-validate the
+ // plan inside ACBO, or add properties or rules to ACBO that prevent such moves.
+ //
+ // This is not an issue for now since ACBO doesn't perform such moves.
val relNode = getRelNode(context, operatorId, offset, count, child.output, input, true)
doNativeValidation(context, relNode)
diff --git a/gluten-core/src/main/scala/io/glutenproject/extension/ColumnarOverrides.scala b/gluten-core/src/main/scala/io/glutenproject/extension/ColumnarOverrides.scala
index e3478720a4ad..279ce2fb2d43 100644
--- a/gluten-core/src/main/scala/io/glutenproject/extension/ColumnarOverrides.scala
+++ b/gluten-core/src/main/scala/io/glutenproject/extension/ColumnarOverrides.scala
@@ -20,6 +20,7 @@ import io.glutenproject.{GlutenConfig, GlutenSparkExtensionsInjector}
import io.glutenproject.backendsapi.BackendsApiManager
import io.glutenproject.extension.columnar._
import io.glutenproject.extension.columnar.MiscColumnarRules.{RemoveGlutenTableCacheColumnarToRow, RemoveTopmostColumnarToRow, TransformPostOverrides, TransformPreOverrides}
+import io.glutenproject.extension.columnar.transform._
import io.glutenproject.metrics.GlutenTimeMetric
import io.glutenproject.utils.{LogLevelUtil, PhysicalPlanSelector, PlanUtil}
@@ -223,6 +224,19 @@ case class ColumnarOverrideRules(session: SparkSession)
* the plan will be breakdown and decided to be fallen back or not.
*/
private def transformRules(outputsColumnar: Boolean): List[SparkSession => Rule[SparkPlan]] = {
+
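+ // ACBO pipeline: implement filters ahead of enumeration, run the enumerated
+ // transform over the remaining plan, strip transitions, then implement
+ // aggregates on the chosen plan. With ACBO disabled, the single-pass
+ // TransformPreOverrides is used instead.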
+ def maybeCbo(outputsColumnar: Boolean): List[SparkSession => Rule[SparkPlan]] = {
+ if (GlutenConfig.getConf.enableAdvancedCbo) {
+ return List(
+ (_: SparkSession) => TransformPreOverrides(List(ImplementFilter()), List.empty),
+ (session: SparkSession) => EnumeratedTransform(session, outputsColumnar),
+ (_: SparkSession) => RemoveTransitions,
+ (_: SparkSession) => TransformPreOverrides(List.empty, List(ImplementAggregate()))
+ )
+ }
+ List((_: SparkSession) => TransformPreOverrides())
+ }
+
List(
(_: SparkSession) => RemoveTransitions,
(spark: SparkSession) => FallbackOnANSIMode(spark),
@@ -235,23 +249,18 @@ case class ColumnarOverrideRules(session: SparkSession)
(spark: SparkSession) => MergeTwoPhasesHashBaseAggregate(spark),
(_: SparkSession) => rewriteSparkPlanRule(),
(_: SparkSession) => AddTransformHintRule(),
- (_: SparkSession) => FallbackBloomFilterAggIfNeeded(),
- // We are planning to merge rule "TransformPreOverrides" and "InsertTransitions"
- // together. So temporarily have both `InsertTransitions` and `RemoveTransitions`
- // set in there to make sure the rule list (after insert transitions) is compatible
- // with input plans that have C2Rs/R2Cs inserted.
- (_: SparkSession) => TransformPreOverrides(),
- (_: SparkSession) => InsertTransitions(outputsColumnar),
- (_: SparkSession) => RemoveTransitions,
+ (_: SparkSession) => FallbackBloomFilterAggIfNeeded()
+ ) :::
+ maybeCbo(outputsColumnar) :::
+ List(
(_: SparkSession) => RemoveNativeWriteFilesSortAndProject(),
(spark: SparkSession) => RewriteTransformer(spark),
(_: SparkSession) => EnsureLocalSortRequirements,
(_: SparkSession) => CollapseProjectExecTransformer
) :::
BackendsApiManager.getSparkPlanExecApiInstance.genExtendedColumnarTransformRules() :::
- SparkRuleUtil.extendedColumnarRules(
- session,
- GlutenConfig.getConf.extendedColumnarTransformRules) :::
+ SparkRuleUtil
+ .extendedColumnarRules(session, GlutenConfig.getConf.extendedColumnarTransformRules) :::
List((_: SparkSession) => InsertTransitions(outputsColumnar))
}
diff --git a/gluten-core/src/main/scala/io/glutenproject/extension/columnar/ColumnarTransitions.scala b/gluten-core/src/main/scala/io/glutenproject/extension/columnar/ColumnarTransitions.scala
index 0247431700e4..2d82d213ab9d 100644
--- a/gluten-core/src/main/scala/io/glutenproject/extension/columnar/ColumnarTransitions.scala
+++ b/gluten-core/src/main/scala/io/glutenproject/extension/columnar/ColumnarTransitions.scala
@@ -17,7 +17,7 @@
package io.glutenproject.extension.columnar
import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.execution.{ApplyColumnarRulesAndInsertTransitions, ColumnarToRowExec, RowToColumnarExec, SparkPlan}
+import org.apache.spark.sql.execution.{ApplyColumnarRulesAndInsertTransitions, ColumnarToRowExec, ColumnarToRowTransition, RowToColumnarExec, RowToColumnarTransition, SparkPlan}
/** See rule code from vanilla Spark: [[ApplyColumnarRulesAndInsertTransitions]]. */
case class InsertTransitions(outputsColumnar: Boolean) extends Rule[SparkPlan] {
@@ -37,9 +37,10 @@ case class InsertTransitions(outputsColumnar: Boolean) extends Rule[SparkPlan] {
}
object RemoveTransitions extends Rule[SparkPlan] {
+ import ColumnarTransitions._
override def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
- case ColumnarToRowExec(child) => child
- case RowToColumnarExec(child) => child
+ case ColumnarToRowLike(child) => child
+ case RowToColumnarLike(child) => child
}
}
@@ -47,4 +48,26 @@ object ColumnarTransitions {
def insertTransitions(plan: SparkPlan, outputsColumnar: Boolean): SparkPlan = {
InsertTransitions(outputsColumnar).apply(plan)
}
+
+ // Extractor for Spark/Gluten's C2R
+ object ColumnarToRowLike {
+ def unapply(plan: SparkPlan): Option[SparkPlan] = {
+ plan match {
+ case c2r: ColumnarToRowTransition =>
+ Some(c2r.child)
+ case _ => None
+ }
+ }
+ }
+
+ // Extractor for Spark/Gluten's R2C
+ object RowToColumnarLike {
+ def unapply(plan: SparkPlan): Option[SparkPlan] = {
+ plan match {
+ case r2c: RowToColumnarTransition =>
+ Some(r2c.child)
+ case _ => None
+ }
+ }
+ }
}
diff --git a/gluten-core/src/main/scala/io/glutenproject/extension/columnar/EnumeratedTransform.scala b/gluten-core/src/main/scala/io/glutenproject/extension/columnar/EnumeratedTransform.scala
new file mode 100644
index 000000000000..42ffec1555f4
--- /dev/null
+++ b/gluten-core/src/main/scala/io/glutenproject/extension/columnar/EnumeratedTransform.scala
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.glutenproject.extension.columnar
+
+import io.glutenproject.cbo.property.PropertySet
+import io.glutenproject.cbo.rule.{CboRule, Shape, Shapes}
+import io.glutenproject.extension.columnar.transform.{ImplementExchange, ImplementJoin, ImplementOthers, ImplementSingleNode}
+import io.glutenproject.planner.GlutenOptimization
+import io.glutenproject.planner.property.GlutenProperties
+import io.glutenproject.utils.LogLevelUtil
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution.SparkPlan
+
+case class EnumeratedTransform(session: SparkSession, outputsColumnar: Boolean)
+ extends Rule[SparkPlan]
+ with LogLevelUtil {
+ import EnumeratedTransform._
+
+ private val cboRules = List(
+ CboImplement(ImplementOthers()),
+ CboImplement(ImplementExchange()),
+ CboImplement(ImplementJoin())
+ )
+
+ private val optimization = GlutenOptimization(cboRules)
+
+ private val reqConvention = GlutenProperties.Conventions.ANY
+ private val altConventions =
+ Seq(GlutenProperties.Conventions.GLUTEN_COLUMNAR, GlutenProperties.Conventions.ROW_BASED)
+
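+ // Plans under an "any convention" constraint while registering the Gluten
+ // columnar and row-based conventions as alternative constraint sets for the
+ // planner to explore.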
+ override def apply(plan: SparkPlan): SparkPlan = {
+ val constraintSet = PropertySet(List(GlutenProperties.Schemas.ANY, reqConvention))
+ val altConstraintSets =
+ altConventions.map(altConv => PropertySet(List(GlutenProperties.Schemas.ANY, altConv)))
+ val planner = optimization.newPlanner(plan, constraintSet, altConstraintSets)
+ val out = planner.plan()
+ out
+ }
+}
+
+object EnumeratedTransform {
+ private case class CboImplement(delegate: ImplementSingleNode) extends CboRule[SparkPlan] {
+ override def shift(node: SparkPlan): Iterable[SparkPlan] = List(delegate.impl(node))
+
+ override def shape(): Shape[SparkPlan] = Shapes.fixedHeight(1)
+ }
+}
diff --git a/gluten-core/src/main/scala/io/glutenproject/extension/columnar/MiscColumnarRules.scala b/gluten-core/src/main/scala/io/glutenproject/extension/columnar/MiscColumnarRules.scala
index d21956bfcc97..b2622f31d4d6 100644
--- a/gluten-core/src/main/scala/io/glutenproject/extension/columnar/MiscColumnarRules.scala
+++ b/gluten-core/src/main/scala/io/glutenproject/extension/columnar/MiscColumnarRules.scala
@@ -16,522 +16,45 @@
*/
package io.glutenproject.extension.columnar
-import io.glutenproject.GlutenConfig
import io.glutenproject.backendsapi.BackendsApiManager
-import io.glutenproject.exception.GlutenNotSupportException
-import io.glutenproject.execution._
-import io.glutenproject.expression.ExpressionConverter
-import io.glutenproject.extension.{ColumnarToRowLike, GlutenPlan}
-import io.glutenproject.sql.shims.SparkShimLoader
+import io.glutenproject.extension.ColumnarToRowLike
+import io.glutenproject.extension.columnar.transform.{ImplementAggregate, ImplementExchange, ImplementFilter, ImplementJoin, ImplementOthers, ImplementSingleNode}
import io.glutenproject.utils.{LogLevelUtil, PlanUtil}
-import org.apache.spark.api.python.EvalPythonExecTransformer
-import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.catalyst.expressions.Expression
-import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide}
-import org.apache.spark.sql.catalyst.plans.{LeftOuter, LeftSemi, RightOuter}
import org.apache.spark.sql.catalyst.rules.{PlanChangeLogger, Rule}
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.BroadcastQueryStageExec
-import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
-import org.apache.spark.sql.execution.datasources.WriteFilesExec
-import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, FileScan}
-import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, BroadcastExchangeLike, ShuffleExchangeExec, ShuffleExchangeLike}
-import org.apache.spark.sql.execution.joins._
-import org.apache.spark.sql.execution.python.EvalPythonExec
-import org.apache.spark.sql.execution.window.WindowExec
-import org.apache.spark.sql.hive.HiveTableScanExecTransformer
+import org.apache.spark.sql.execution.exchange.{BroadcastExchangeLike, ShuffleExchangeLike}
object MiscColumnarRules {
object TransformPreOverrides {
- // Sub-rules of TransformPreOverrides.
-
- // Aggregation transformation.
- private case class AggregationTransformRule() extends Rule[SparkPlan] with LogLevelUtil {
- override def apply(plan: SparkPlan): SparkPlan = plan match {
- case plan if TransformHints.isNotTransformable(plan) =>
- plan
- case agg: HashAggregateExec =>
- genHashAggregateExec(agg)
- case other => other
- }
-
- /**
- * Generate a plan for hash aggregation.
- *
- * @param plan
- * : the original Spark plan.
- * @return
- * the actually used plan for execution.
- */
- private def genHashAggregateExec(plan: HashAggregateExec): SparkPlan = {
- if (TransformHints.isNotTransformable(plan)) {
- return plan
- }
-
- val aggChild = plan.child
-
- def transformHashAggregate(): GlutenPlan = {
- BackendsApiManager.getSparkPlanExecApiInstance
- .genHashAggregateExecTransformer(
- plan.requiredChildDistributionExpressions,
- plan.groupingExpressions,
- plan.aggregateExpressions,
- plan.aggregateAttributes,
- plan.initialInputBufferOffset,
- plan.resultExpressions,
- aggChild
- )
- }
-
- // If child's output is empty, fallback or offload both the child and aggregation.
- if (
- plan.child.output.isEmpty && BackendsApiManager.getSettings.fallbackAggregateWithChild()
- ) {
- aggChild match {
- case _: TransformSupport =>
- // If the child is transformable, transform aggregation as well.
- logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
- transformHashAggregate()
- case p: SparkPlan if PlanUtil.isGlutenTableCache(p) =>
- transformHashAggregate()
- case _ =>
- // If the child is not transformable, do not transform the agg.
- TransformHints.tagNotTransformable(plan, "child output schema is empty")
- plan
- }
- } else {
- logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
- transformHashAggregate()
- }
- }
- }
-
- // Exchange transformation.
- private case class ExchangeTransformRule() extends Rule[SparkPlan] with LogLevelUtil {
- override def apply(plan: SparkPlan): SparkPlan = plan match {
- case plan if TransformHints.isNotTransformable(plan) =>
- plan
- case plan: ShuffleExchangeExec =>
- logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
- val child = plan.child
- if (
- (child.supportsColumnar || GlutenConfig.getConf.enablePreferColumnar) &&
- BackendsApiManager.getSettings.supportColumnarShuffleExec()
- ) {
- BackendsApiManager.getSparkPlanExecApiInstance.genColumnarShuffleExchange(plan, child)
- } else {
- plan.withNewChildren(Seq(child))
- }
- case plan: BroadcastExchangeExec =>
- val child = plan.child
- logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
- ColumnarBroadcastExchangeExec(plan.mode, child)
- case other => other
- }
- }
-
- // Join transformation.
- private case class JoinTransformRule() extends Rule[SparkPlan] with LogLevelUtil {
-
- /**
- * Get the build side supported by the execution of vanilla Spark.
- *
- * @param plan
- * : shuffled hash join plan
- * @return
- * the supported build side
- */
- private def getSparkSupportedBuildSide(plan: ShuffledHashJoinExec): BuildSide = {
- plan.joinType match {
- case LeftOuter | LeftSemi => BuildRight
- case RightOuter => BuildLeft
- case _ => plan.buildSide
- }
- }
-
- override def apply(plan: SparkPlan): SparkPlan = {
- if (TransformHints.isNotTransformable(plan)) {
- logDebug(s"Columnar Processing for ${plan.getClass} is under row guard.")
- plan match {
- case shj: ShuffledHashJoinExec =>
- if (BackendsApiManager.getSettings.recreateJoinExecOnFallback()) {
- // Because we manually removed the build side limitation for LeftOuter, LeftSemi and
- // RightOuter, need to change the build side back if this join fallback into vanilla
- // Spark for execution.
- return ShuffledHashJoinExec(
- shj.leftKeys,
- shj.rightKeys,
- shj.joinType,
- getSparkSupportedBuildSide(shj),
- shj.condition,
- shj.left,
- shj.right,
- shj.isSkewJoin
- )
- } else {
- return shj
- }
- case p =>
- return p
- }
- }
- plan match {
- case plan: ShuffledHashJoinExec =>
- val left = plan.left
- val right = plan.right
- logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
- BackendsApiManager.getSparkPlanExecApiInstance
- .genShuffledHashJoinExecTransformer(
- plan.leftKeys,
- plan.rightKeys,
- plan.joinType,
- plan.buildSide,
- plan.condition,
- left,
- right,
- plan.isSkewJoin)
- case plan: SortMergeJoinExec =>
- val left = plan.left
- val right = plan.right
- logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
- SortMergeJoinExecTransformer(
- plan.leftKeys,
- plan.rightKeys,
- plan.joinType,
- plan.condition,
- left,
- right,
- plan.isSkewJoin)
- case plan: BroadcastHashJoinExec =>
- val left = plan.left
- val right = plan.right
- BackendsApiManager.getSparkPlanExecApiInstance
- .genBroadcastHashJoinExecTransformer(
- plan.leftKeys,
- plan.rightKeys,
- plan.joinType,
- plan.buildSide,
- plan.condition,
- left,
- right,
- isNullAwareAntiJoin = plan.isNullAwareAntiJoin)
- case plan: CartesianProductExec =>
- val left = plan.left
- val right = plan.right
- BackendsApiManager.getSparkPlanExecApiInstance
- .genCartesianProductExecTransformer(left, right, plan.condition)
- case plan: BroadcastNestedLoopJoinExec =>
- val left = plan.left
- val right = plan.right
- BackendsApiManager.getSparkPlanExecApiInstance
- .genBroadcastNestedLoopJoinExecTransformer(
- left,
- right,
- plan.buildSide,
- plan.joinType,
- plan.condition)
- case other => other
- }
- }
-
- }
-
- // Filter transformation.
- private case class FilterTransformRule() extends Rule[SparkPlan] with LogLevelUtil {
- private val replace = new ReplaceSingleNode()
-
- override def apply(plan: SparkPlan): SparkPlan = plan match {
- case filter: FilterExec =>
- genFilterExec(filter)
- case other => other
- }
-
- /**
- * Generate a plan for filter.
- *
- * @param plan
- * : the original Spark plan.
- * @return
- * the actually used plan for execution.
- */
- private def genFilterExec(plan: FilterExec): SparkPlan = {
- if (TransformHints.isNotTransformable(plan)) {
- return plan
- }
-
- // FIXME: Filter push-down should be better done by Vanilla Spark's planner or by
- // a individual rule.
- val scan = plan.child
- // Push down the left conditions in Filter into FileSourceScan.
- val newChild: SparkPlan = scan match {
- case _: FileSourceScanExec | _: BatchScanExec =>
- if (TransformHints.isTransformable(scan)) {
- val newScan = FilterHandler.applyFilterPushdownToScan(plan)
- newScan match {
- case ts: TransformSupport if ts.doValidate().isValid => ts
- // TODO remove the call
- case _ => replace.replaceWithTransformerPlan(scan)
- }
- } else {
- replace.replaceWithTransformerPlan(scan)
- }
- case _ => replace.replaceWithTransformerPlan(plan.child)
- }
- logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
- BackendsApiManager.getSparkPlanExecApiInstance
- .genFilterExecTransformer(plan.condition, newChild)
- }
- }
-
- // Other transformations.
- private case class RegularTransformRule() extends Rule[SparkPlan] with LogLevelUtil {
- private val replace = new ReplaceSingleNode()
-
- override def apply(plan: SparkPlan): SparkPlan = replace.replaceWithTransformerPlan(plan)
- }
-
- // Utility to replace single node within transformed Gluten node.
- // Children will be preserved as they are as children of the output node.
- class ReplaceSingleNode() extends LogLevelUtil with Logging {
-
- def replaceWithTransformerPlan(p: SparkPlan): SparkPlan = {
- val plan = p
- if (TransformHints.isNotTransformable(plan)) {
- logDebug(s"Columnar Processing for ${plan.getClass} is under row guard.")
- plan match {
- case plan: BatchScanExec =>
- return applyScanNotTransformable(plan)
- case plan: FileSourceScanExec =>
- return applyScanNotTransformable(plan)
- case plan if HiveTableScanExecTransformer.isHiveTableScan(plan) =>
- return applyScanNotTransformable(plan)
- case p =>
- return p
- }
- }
- plan match {
- case plan: BatchScanExec =>
- applyScanTransformer(plan)
- case plan: FileSourceScanExec =>
- applyScanTransformer(plan)
- case plan if HiveTableScanExecTransformer.isHiveTableScan(plan) =>
- applyScanTransformer(plan)
- case plan: CoalesceExec =>
- logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
- CoalesceExecTransformer(plan.numPartitions, plan.child)
- case plan: ProjectExec =>
- val columnarChild = plan.child
- logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
- ProjectExecTransformer(plan.projectList, columnarChild)
- case plan: SortAggregateExec =>
- logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
- BackendsApiManager.getSparkPlanExecApiInstance
- .genHashAggregateExecTransformer(
- plan.requiredChildDistributionExpressions,
- plan.groupingExpressions,
- plan.aggregateExpressions,
- plan.aggregateAttributes,
- plan.initialInputBufferOffset,
- plan.resultExpressions,
- plan.child match {
- case sort: SortExecTransformer if !sort.global =>
- sort.child
- case sort: SortExec if !sort.global =>
- sort.child
- case _ => plan.child
- }
- )
- case plan: ObjectHashAggregateExec =>
- val child = plan.child
- logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
- BackendsApiManager.getSparkPlanExecApiInstance
- .genHashAggregateExecTransformer(
- plan.requiredChildDistributionExpressions,
- plan.groupingExpressions,
- plan.aggregateExpressions,
- plan.aggregateAttributes,
- plan.initialInputBufferOffset,
- plan.resultExpressions,
- child
- )
- case plan: UnionExec =>
- val children = plan.children
- logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
- ColumnarUnionExec(children)
- case plan: ExpandExec =>
- val child = plan.child
- logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
- ExpandExecTransformer(plan.projections, plan.output, child)
- case plan: WriteFilesExec =>
- val child = plan.child
- logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
- val writeTransformer = WriteFilesExecTransformer(
- child,
- plan.fileFormat,
- plan.partitionColumns,
- plan.bucketSpec,
- plan.options,
- plan.staticPartitions)
- BackendsApiManager.getSparkPlanExecApiInstance.createColumnarWriteFilesExec(
- writeTransformer,
- plan.fileFormat,
- plan.partitionColumns,
- plan.bucketSpec,
- plan.options,
- plan.staticPartitions
- )
- case plan: SortExec =>
- val child = plan.child
- logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
- SortExecTransformer(plan.sortOrder, plan.global, child, plan.testSpillFrequency)
- case plan: TakeOrderedAndProjectExec =>
- logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
- val child = plan.child
- val (limit, offset) = SparkShimLoader.getSparkShims.getLimitAndOffsetFromTopK(plan)
- TakeOrderedAndProjectExecTransformer(
- limit,
- plan.sortOrder,
- plan.projectList,
- child,
- offset)
- case plan: WindowExec =>
- WindowExecTransformer(
- plan.windowExpression,
- plan.partitionSpec,
- plan.orderSpec,
- plan.child)
- case plan: GlobalLimitExec =>
- logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
- val child = plan.child
- val (limit, offset) =
- SparkShimLoader.getSparkShims.getLimitAndOffsetFromGlobalLimit(plan)
- LimitTransformer(child, offset, limit)
- case plan: LocalLimitExec =>
- logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
- val child = plan.child
- LimitTransformer(child, 0L, plan.limit)
- case plan: GenerateExec =>
- logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
- val child = plan.child
- BackendsApiManager.getSparkPlanExecApiInstance.genGenerateTransformer(
- plan.generator,
- plan.requiredChildOutput,
- plan.outer,
- plan.generatorOutput,
- child)
- case plan: EvalPythonExec =>
- logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
- val child = plan.child
- EvalPythonExecTransformer(plan.udfs, plan.resultAttrs, child)
- case p if !p.isInstanceOf[GlutenPlan] =>
- logDebug(s"Transformation for ${p.getClass} is currently not supported.")
- val children = plan.children
- p.withNewChildren(children)
- case other => other
- }
- }
-
- private def applyScanNotTransformable(plan: SparkPlan): SparkPlan = plan match {
- case plan: FileSourceScanExec =>
- val newPartitionFilters =
- ExpressionConverter.transformDynamicPruningExpr(plan.partitionFilters)
- val newSource = plan.copy(partitionFilters = newPartitionFilters)
- if (plan.logicalLink.nonEmpty) {
- newSource.setLogicalLink(plan.logicalLink.get)
- }
- TransformHints.tag(newSource, TransformHints.getHint(plan))
- newSource
- case plan: BatchScanExec =>
- val newPartitionFilters: Seq[Expression] = plan.scan match {
- case scan: FileScan =>
- ExpressionConverter.transformDynamicPruningExpr(scan.partitionFilters)
- case _ =>
- ExpressionConverter.transformDynamicPruningExpr(plan.runtimeFilters)
- }
- val newSource = plan.copy(runtimeFilters = newPartitionFilters)
- if (plan.logicalLink.nonEmpty) {
- newSource.setLogicalLink(plan.logicalLink.get)
- }
- TransformHints.tag(newSource, TransformHints.getHint(plan))
- newSource
- case plan if HiveTableScanExecTransformer.isHiveTableScan(plan) =>
- val newPartitionFilters: Seq[Expression] =
- ExpressionConverter.transformDynamicPruningExpr(
- HiveTableScanExecTransformer.getPartitionFilters(plan))
- val newSource = HiveTableScanExecTransformer.copyWith(plan, newPartitionFilters)
- if (plan.logicalLink.nonEmpty) {
- newSource.setLogicalLink(plan.logicalLink.get)
- }
- TransformHints.tag(newSource, TransformHints.getHint(plan))
- newSource
- case other =>
- throw new UnsupportedOperationException(s"${other.getClass.toString} is not supported.")
- }
-
- /**
- * Apply scan transformer for file source and batch source,
- * 1. create new filter and scan transformer, 2. validate, tag new scan as unsupported if
- * failed, 3. return new source.
- */
- private def applyScanTransformer(plan: SparkPlan): SparkPlan = plan match {
- case plan: FileSourceScanExec =>
- val transformer = ScanTransformerFactory.createFileSourceScanTransformer(plan)
- val validationResult = transformer.doValidate()
- if (validationResult.isValid) {
- logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
- transformer
- } else {
- logDebug(s"Columnar Processing for ${plan.getClass} is currently unsupported.")
- val newSource = plan.copy(partitionFilters = transformer.getPartitionFilters())
- TransformHints.tagNotTransformable(newSource, validationResult.reason.get)
- newSource
- }
- case plan: BatchScanExec =>
- ScanTransformerFactory.createBatchScanTransformer(plan)
-
- case plan if HiveTableScanExecTransformer.isHiveTableScan(plan) =>
- // TODO: Add DynamicPartitionPruningHiveScanSuite.scala
- val newPartitionFilters: Seq[Expression] =
- ExpressionConverter.transformDynamicPruningExpr(
- HiveTableScanExecTransformer.getPartitionFilters(plan))
- val hiveTableScanExecTransformer =
- BackendsApiManager.getSparkPlanExecApiInstance.genHiveTableScanExecTransformer(plan)
- val validateResult = hiveTableScanExecTransformer.doValidate()
- if (validateResult.isValid) {
- logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
- return hiveTableScanExecTransformer
- }
- logDebug(s"Columnar Processing for ${plan.getClass} is currently unsupported.")
- val newSource = HiveTableScanExecTransformer.copyWith(plan, newPartitionFilters)
- TransformHints.tagNotTransformable(newSource, validateResult.reason.get)
- newSource
- case other =>
- throw new GlutenNotSupportException(s"${other.getClass.toString} is not supported.")
- }
+ def apply(): TransformPreOverrides = {
+ TransformPreOverrides(
+ List(ImplementFilter()),
+ List(
+ ImplementOthers(),
+ ImplementAggregate(),
+ ImplementExchange(),
+ ImplementJoin()
+ )
+ )
}
}
// This rule will conduct the conversion from Spark plan to the plan transformer.
- case class TransformPreOverrides() extends Rule[SparkPlan] with LogLevelUtil {
- import TransformPreOverrides._
-
- private val topdownRules = List(
- FilterTransformRule()
- )
- private val bottomupRules = List(
- RegularTransformRule(),
- AggregationTransformRule(),
- ExchangeTransformRule(),
- JoinTransformRule()
- )
-
+ case class TransformPreOverrides(
+ topDownRules: Seq[ImplementSingleNode],
+ bottomUpRules: Seq[ImplementSingleNode])
+ extends Rule[SparkPlan]
+ with LogLevelUtil {
@transient private val planChangeLogger = new PlanChangeLogger[SparkPlan]()
def apply(plan: SparkPlan): SparkPlan = {
- val plan0 = topdownRules.foldLeft(plan)((p, rule) => p.transformDown { case p => rule(p) })
- val plan1 = bottomupRules.foldLeft(plan0)((p, rule) => p.transformUp { case p => rule(p) })
+ val plan0 =
+ topDownRules.foldLeft(plan)((p, rule) => p.transformDown { case p => rule.impl(p) })
+ val plan1 =
+ bottomUpRules.foldLeft(plan0)((p, rule) => p.transformUp { case p => rule.impl(p) })
planChangeLogger.logRule(ruleName, plan, plan1)
plan1
}
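
The split between topDownRules and bottomUpRules above is deliberate: ImplementFilter runs first via transformDown, so the filter rule still sees the vanilla scan child and can push conditions into it before any other rule replaces that scan; the remaining rules run via transformUp, so children are implemented before their parents. A self-contained sketch of the two traversal orders (toy tree, not Spark's TreeNode):

    object TraversalSketch {
      final case class N(name: String, children: List[N] = Nil)

      // Parent first, then children -- like Spark's transformDown.
      def transformDown(n: N)(f: N => N): N = {
        val r = f(n)
        r.copy(children = r.children.map(c => transformDown(c)(f)))
      }

      // Children first, then parent -- like Spark's transformUp.
      def transformUp(n: N)(f: N => N): N =
        f(n.copy(children = n.children.map(c => transformUp(c)(f))))

      def main(args: Array[String]): Unit = {
        val plan = N("Filter", List(N("Scan")))
        val down = scala.collection.mutable.Buffer[String]()
        transformDown(plan) { n => down += n.name; n }
        println(down) // ArrayBuffer(Filter, Scan)
        val up = scala.collection.mutable.Buffer[String]()
        transformUp(plan) { n => up += n.name; n }
        println(up) // ArrayBuffer(Scan, Filter)
      }
    }
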
diff --git a/gluten-core/src/main/scala/io/glutenproject/extension/columnar/transform/ImplementSingleNode.scala b/gluten-core/src/main/scala/io/glutenproject/extension/columnar/transform/ImplementSingleNode.scala
new file mode 100644
index 000000000000..5598014573c7
--- /dev/null
+++ b/gluten-core/src/main/scala/io/glutenproject/extension/columnar/transform/ImplementSingleNode.scala
@@ -0,0 +1,517 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.glutenproject.extension.columnar.transform
+
+import io.glutenproject.GlutenConfig
+import io.glutenproject.backendsapi.BackendsApiManager
+import io.glutenproject.exception.GlutenNotSupportException
+import io.glutenproject.execution._
+import io.glutenproject.expression.ExpressionConverter
+import io.glutenproject.extension.GlutenPlan
+import io.glutenproject.extension.columnar.TransformHints
+import io.glutenproject.sql.shims.SparkShimLoader
+import io.glutenproject.utils.{LogLevelUtil, PlanUtil}
+
+import org.apache.spark.api.python.EvalPythonExecTransformer
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide}
+import org.apache.spark.sql.catalyst.plans.{LeftOuter, LeftSemi, RightOuter}
+import org.apache.spark.sql.execution._
+import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
+import org.apache.spark.sql.execution.datasources.WriteFilesExec
+import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, FileScan}
+import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec}
+import org.apache.spark.sql.execution.joins._
+import org.apache.spark.sql.execution.python.EvalPythonExec
+import org.apache.spark.sql.execution.window.WindowExec
+import org.apache.spark.sql.hive.HiveTableScanExecTransformer
+
+sealed trait ImplementSingleNode extends Logging {
+ def impl(plan: SparkPlan): SparkPlan
+}
+
+// Aggregation transformation.
+case class ImplementAggregate() extends ImplementSingleNode with LogLevelUtil {
+ override def impl(plan: SparkPlan): SparkPlan = plan match {
+ case plan if TransformHints.isNotTransformable(plan) =>
+ plan
+ case agg: HashAggregateExec =>
+ genHashAggregateExec(agg)
+ case other => other
+ }
+
+ /**
+ * Generate a plan for hash aggregation.
+ *
+ * @param plan
+ * : the original Spark plan.
+ * @return
+ * the plan actually used for execution.
+ */
+ private def genHashAggregateExec(plan: HashAggregateExec): SparkPlan = {
+ if (TransformHints.isNotTransformable(plan)) {
+ return plan
+ }
+
+ val aggChild = plan.child
+
+ def transformHashAggregate(): GlutenPlan = {
+ BackendsApiManager.getSparkPlanExecApiInstance
+ .genHashAggregateExecTransformer(
+ plan.requiredChildDistributionExpressions,
+ plan.groupingExpressions,
+ plan.aggregateExpressions,
+ plan.aggregateAttributes,
+ plan.initialInputBufferOffset,
+ plan.resultExpressions,
+ aggChild
+ )
+ }
+
+ // If the child's output is empty, fall back or offload the child and the aggregation together.
+ if (plan.child.output.isEmpty && BackendsApiManager.getSettings.fallbackAggregateWithChild()) {
+ aggChild match {
+ case _: TransformSupport =>
+ // If the child is transformable, transform aggregation as well.
+ logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
+ transformHashAggregate()
+ case p: SparkPlan if PlanUtil.isGlutenTableCache(p) =>
+ transformHashAggregate()
+ case _ =>
+ // If the child is not transformable, do not transform the agg.
+ TransformHints.tagNotTransformable(plan, "child output schema is empty")
+ plan
+ }
+ } else {
+ logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
+ transformHashAggregate()
+ }
+ }
+}
+
+// Exchange transformation.
+case class ImplementExchange() extends ImplementSingleNode with LogLevelUtil {
+ override def impl(plan: SparkPlan): SparkPlan = plan match {
+ case plan if TransformHints.isNotTransformable(plan) =>
+ plan
+ case plan: ShuffleExchangeExec =>
+ logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
+ val child = plan.child
+ if (
+ (child.supportsColumnar || GlutenConfig.getConf.enablePreferColumnar) &&
+ BackendsApiManager.getSettings.supportColumnarShuffleExec()
+ ) {
+ BackendsApiManager.getSparkPlanExecApiInstance.genColumnarShuffleExchange(plan, child)
+ } else {
+ plan.withNewChildren(Seq(child))
+ }
+ case plan: BroadcastExchangeExec =>
+ val child = plan.child
+ logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
+ ColumnarBroadcastExchangeExec(plan.mode, child)
+ case other => other
+ }
+}
+
+// Join transformation.
+case class ImplementJoin() extends ImplementSingleNode with LogLevelUtil {
+
+ /**
+ * Get the build side supported by the execution of vanilla Spark.
+ *
+ * @param plan
+ * : shuffled hash join plan
+ * @return
+ * the supported build side
+ */
+ private def getSparkSupportedBuildSide(plan: ShuffledHashJoinExec): BuildSide = {
+ plan.joinType match {
+ case LeftOuter | LeftSemi => BuildRight
+ case RightOuter => BuildLeft
+ case _ => plan.buildSide
+ }
+ }
+
+ override def impl(plan: SparkPlan): SparkPlan = {
+ if (TransformHints.isNotTransformable(plan)) {
+ logDebug(s"Columnar Processing for ${plan.getClass} is under row guard.")
+ plan match {
+ case shj: ShuffledHashJoinExec =>
+ if (BackendsApiManager.getSettings.recreateJoinExecOnFallback()) {
+ // Because we manually removed the build side limitation for LeftOuter, LeftSemi and
+ // RightOuter, we need to change the build side back if this join falls back to vanilla
+ // Spark for execution.
+ return ShuffledHashJoinExec(
+ shj.leftKeys,
+ shj.rightKeys,
+ shj.joinType,
+ getSparkSupportedBuildSide(shj),
+ shj.condition,
+ shj.left,
+ shj.right,
+ shj.isSkewJoin
+ )
+ } else {
+ return shj
+ }
+ case p =>
+ return p
+ }
+ }
+ plan match {
+ case plan: ShuffledHashJoinExec =>
+ val left = plan.left
+ val right = plan.right
+ logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
+ BackendsApiManager.getSparkPlanExecApiInstance
+ .genShuffledHashJoinExecTransformer(
+ plan.leftKeys,
+ plan.rightKeys,
+ plan.joinType,
+ plan.buildSide,
+ plan.condition,
+ left,
+ right,
+ plan.isSkewJoin)
+ case plan: SortMergeJoinExec =>
+ val left = plan.left
+ val right = plan.right
+ logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
+ SortMergeJoinExecTransformer(
+ plan.leftKeys,
+ plan.rightKeys,
+ plan.joinType,
+ plan.condition,
+ left,
+ right,
+ plan.isSkewJoin)
+ case plan: BroadcastHashJoinExec =>
+ val left = plan.left
+ val right = plan.right
+ BackendsApiManager.getSparkPlanExecApiInstance
+ .genBroadcastHashJoinExecTransformer(
+ plan.leftKeys,
+ plan.rightKeys,
+ plan.joinType,
+ plan.buildSide,
+ plan.condition,
+ left,
+ right,
+ isNullAwareAntiJoin = plan.isNullAwareAntiJoin)
+ case plan: CartesianProductExec =>
+ val left = plan.left
+ val right = plan.right
+ BackendsApiManager.getSparkPlanExecApiInstance
+ .genCartesianProductExecTransformer(left, right, plan.condition)
+ case plan: BroadcastNestedLoopJoinExec =>
+ val left = plan.left
+ val right = plan.right
+ BackendsApiManager.getSparkPlanExecApiInstance
+ .genBroadcastNestedLoopJoinExecTransformer(
+ left,
+ right,
+ plan.buildSide,
+ plan.joinType,
+ plan.condition)
+ case other => other
+ }
+ }
+
+}
+
+// Filter transformation.
+case class ImplementFilter() extends ImplementSingleNode with LogLevelUtil {
+ import ImplementOthers._
+ private val replace = new ReplaceSingleNode()
+
+ override def impl(plan: SparkPlan): SparkPlan = plan match {
+ case filter: FilterExec =>
+ genFilterExec(filter)
+ case other => other
+ }
+
+ /**
+ * Generate a plan for filter.
+ *
+ * @param plan
+ * : the original Spark plan.
+ * @return
+ * the plan actually used for execution.
+ */
+ private def genFilterExec(plan: FilterExec): SparkPlan = {
+ if (TransformHints.isNotTransformable(plan)) {
+ return plan
+ }
+
+ // FIXME: Filter push-down should be better done by Vanilla Spark's planner or by
+ // an individual rule.
+ val scan = plan.child
+ // Push down the remaining conditions in the Filter into the FileSourceScan.
+ val newChild: SparkPlan = scan match {
+ case _: FileSourceScanExec | _: BatchScanExec =>
+ if (TransformHints.isTransformable(scan)) {
+ val newScan = FilterHandler.applyFilterPushdownToScan(plan)
+ newScan match {
+ case ts: TransformSupport if ts.doValidate().isValid => ts
+ // TODO remove the call
+ case _ => replace.doReplace(scan)
+ }
+ } else {
+ replace.doReplace(scan)
+ }
+ case _ => replace.doReplace(plan.child)
+ }
+ logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
+ BackendsApiManager.getSparkPlanExecApiInstance
+ .genFilterExecTransformer(plan.condition, newChild)
+ }
+}
+
+// Other transformations.
+case class ImplementOthers() extends ImplementSingleNode with LogLevelUtil {
+ import ImplementOthers._
+ private val replace = new ReplaceSingleNode()
+
+ override def impl(plan: SparkPlan): SparkPlan = replace.doReplace(plan)
+}
+
+object ImplementOthers {
+ // Utility to replace a single node with a transformed Gluten node.
+ // Children are preserved as-is as children of the output node.
+ //
+ // Do not look up children of the input node in this rule. Otherwise
+ // it may break ACBO, which groups all the possible input nodes to
+ // search for valid candidates.
+ class ReplaceSingleNode() extends LogLevelUtil with Logging {
+
+ def doReplace(p: SparkPlan): SparkPlan = {
+ val plan = p
+ if (TransformHints.isNotTransformable(plan)) {
+ logDebug(s"Columnar Processing for ${plan.getClass} is under row guard.")
+ plan match {
+ case plan: BatchScanExec =>
+ return applyScanNotTransformable(plan)
+ case plan: FileSourceScanExec =>
+ return applyScanNotTransformable(plan)
+ case plan if HiveTableScanExecTransformer.isHiveTableScan(plan) =>
+ return applyScanNotTransformable(plan)
+ case p =>
+ return p
+ }
+ }
+ plan match {
+ case plan: BatchScanExec =>
+ applyScanTransformer(plan)
+ case plan: FileSourceScanExec =>
+ applyScanTransformer(plan)
+ case plan if HiveTableScanExecTransformer.isHiveTableScan(plan) =>
+ applyScanTransformer(plan)
+ case plan: CoalesceExec =>
+ logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
+ CoalesceExecTransformer(plan.numPartitions, plan.child)
+ case plan: ProjectExec =>
+ val columnarChild = plan.child
+ logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
+ ProjectExecTransformer(plan.projectList, columnarChild)
+ case plan: SortAggregateExec =>
+ logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
+ BackendsApiManager.getSparkPlanExecApiInstance
+ .genHashAggregateExecTransformer(
+ plan.requiredChildDistributionExpressions,
+ plan.groupingExpressions,
+ plan.aggregateExpressions,
+ plan.aggregateAttributes,
+ plan.initialInputBufferOffset,
+ plan.resultExpressions,
+ plan.child match {
+ case sort: SortExecTransformer if !sort.global =>
+ sort.child
+ case sort: SortExec if !sort.global =>
+ sort.child
+ case _ => plan.child
+ }
+ )
+ case plan: ObjectHashAggregateExec =>
+ val child = plan.child
+ logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
+ BackendsApiManager.getSparkPlanExecApiInstance
+ .genHashAggregateExecTransformer(
+ plan.requiredChildDistributionExpressions,
+ plan.groupingExpressions,
+ plan.aggregateExpressions,
+ plan.aggregateAttributes,
+ plan.initialInputBufferOffset,
+ plan.resultExpressions,
+ child
+ )
+ case plan: UnionExec =>
+ val children = plan.children
+ logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
+ ColumnarUnionExec(children)
+ case plan: ExpandExec =>
+ val child = plan.child
+ logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
+ ExpandExecTransformer(plan.projections, plan.output, child)
+ case plan: WriteFilesExec =>
+ val child = plan.child
+ logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
+ val writeTransformer = WriteFilesExecTransformer(
+ child,
+ plan.fileFormat,
+ plan.partitionColumns,
+ plan.bucketSpec,
+ plan.options,
+ plan.staticPartitions)
+ BackendsApiManager.getSparkPlanExecApiInstance.createColumnarWriteFilesExec(
+ writeTransformer,
+ plan.fileFormat,
+ plan.partitionColumns,
+ plan.bucketSpec,
+ plan.options,
+ plan.staticPartitions
+ )
+ case plan: SortExec =>
+ val child = plan.child
+ logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
+ SortExecTransformer(plan.sortOrder, plan.global, child, plan.testSpillFrequency)
+ case plan: TakeOrderedAndProjectExec =>
+ logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
+ val child = plan.child
+ val (limit, offset) = SparkShimLoader.getSparkShims.getLimitAndOffsetFromTopK(plan)
+ TakeOrderedAndProjectExecTransformer(
+ limit,
+ plan.sortOrder,
+ plan.projectList,
+ child,
+ offset)
+ case plan: WindowExec =>
+ WindowExecTransformer(
+ plan.windowExpression,
+ plan.partitionSpec,
+ plan.orderSpec,
+ plan.child)
+ case plan: GlobalLimitExec =>
+ logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
+ val child = plan.child
+ val (limit, offset) =
+ SparkShimLoader.getSparkShims.getLimitAndOffsetFromGlobalLimit(plan)
+ LimitTransformer(child, offset, limit)
+ case plan: LocalLimitExec =>
+ logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
+ val child = plan.child
+ LimitTransformer(child, 0L, plan.limit)
+ case plan: GenerateExec =>
+ logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
+ val child = plan.child
+ BackendsApiManager.getSparkPlanExecApiInstance.genGenerateTransformer(
+ plan.generator,
+ plan.requiredChildOutput,
+ plan.outer,
+ plan.generatorOutput,
+ child)
+ case plan: EvalPythonExec =>
+ logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
+ val child = plan.child
+ EvalPythonExecTransformer(plan.udfs, plan.resultAttrs, child)
+ case p if !p.isInstanceOf[GlutenPlan] =>
+ logDebug(s"Transformation for ${p.getClass} is currently not supported.")
+ val children = plan.children
+ p.withNewChildren(children)
+ case other => other
+ }
+ }
+
+ private def applyScanNotTransformable(plan: SparkPlan): SparkPlan = plan match {
+ case plan: FileSourceScanExec =>
+ val newPartitionFilters =
+ ExpressionConverter.transformDynamicPruningExpr(plan.partitionFilters)
+ val newSource = plan.copy(partitionFilters = newPartitionFilters)
+ if (plan.logicalLink.nonEmpty) {
+ newSource.setLogicalLink(plan.logicalLink.get)
+ }
+ TransformHints.tag(newSource, TransformHints.getHint(plan))
+ newSource
+ case plan: BatchScanExec =>
+ val newPartitionFilters: Seq[Expression] = plan.scan match {
+ case scan: FileScan =>
+ ExpressionConverter.transformDynamicPruningExpr(scan.partitionFilters)
+ case _ =>
+ ExpressionConverter.transformDynamicPruningExpr(plan.runtimeFilters)
+ }
+ val newSource = plan.copy(runtimeFilters = newPartitionFilters)
+ if (plan.logicalLink.nonEmpty) {
+ newSource.setLogicalLink(plan.logicalLink.get)
+ }
+ TransformHints.tag(newSource, TransformHints.getHint(plan))
+ newSource
+ case plan if HiveTableScanExecTransformer.isHiveTableScan(plan) =>
+ val newPartitionFilters: Seq[Expression] =
+ ExpressionConverter.transformDynamicPruningExpr(
+ HiveTableScanExecTransformer.getPartitionFilters(plan))
+ val newSource = HiveTableScanExecTransformer.copyWith(plan, newPartitionFilters)
+ if (plan.logicalLink.nonEmpty) {
+ newSource.setLogicalLink(plan.logicalLink.get)
+ }
+ TransformHints.tag(newSource, TransformHints.getHint(plan))
+ newSource
+ case other =>
+ throw new UnsupportedOperationException(s"${other.getClass.toString} is not supported.")
+ }
+
+ /**
+ * Apply scan transformer for file source and batch source:
+ * 1. create a new filter and scan transformer; 2. validate, tagging the new scan as
+ * unsupported if validation fails; 3. return the new source.
+ */
+ private def applyScanTransformer(plan: SparkPlan): SparkPlan = plan match {
+ case plan: FileSourceScanExec =>
+ val transformer = ScanTransformerFactory.createFileSourceScanTransformer(plan)
+ val validationResult = transformer.doValidate()
+ if (validationResult.isValid) {
+ logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
+ transformer
+ } else {
+ logDebug(s"Columnar Processing for ${plan.getClass} is currently unsupported.")
+ val newSource = plan.copy(partitionFilters = transformer.getPartitionFilters())
+ TransformHints.tagNotTransformable(newSource, validationResult.reason.get)
+ newSource
+ }
+ case plan: BatchScanExec =>
+ ScanTransformerFactory.createBatchScanTransformer(plan)
+
+ case plan if HiveTableScanExecTransformer.isHiveTableScan(plan) =>
+ // TODO: Add DynamicPartitionPruningHiveScanSuite.scala
+ val newPartitionFilters: Seq[Expression] =
+ ExpressionConverter.transformDynamicPruningExpr(
+ HiveTableScanExecTransformer.getPartitionFilters(plan))
+ val hiveTableScanExecTransformer =
+ BackendsApiManager.getSparkPlanExecApiInstance.genHiveTableScanExecTransformer(plan)
+ val validateResult = hiveTableScanExecTransformer.doValidate()
+ if (validateResult.isValid) {
+ logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
+ return hiveTableScanExecTransformer
+ }
+ logDebug(s"Columnar Processing for ${plan.getClass} is currently unsupported.")
+ val newSource = HiveTableScanExecTransformer.copyWith(plan, newPartitionFilters)
+ TransformHints.tagNotTransformable(newSource, validateResult.reason.get)
+ newSource
+ case other =>
+ throw new GlutenNotSupportException(s"${other.getClass.toString} is not supported.")
+ }
+ }
+}
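
Every Implement* rule above opens with a TransformHints guard: a plan tagged as not transformable passes through unchanged. A self-contained toy of that guard pattern (stand-in types, not Gluten's API):

    object GuardSketch {
      // Toy stand-ins for TransformHints tagging.
      private val tagged = scala.collection.mutable.Set[String]()
      def tagNotTransformable(plan: String, reason: String): Unit = tagged += plan
      def isNotTransformable(plan: String): Boolean = tagged(plan)

      // Each rule checks the tag first and leaves tagged plans untouched.
      def impl(plan: String): String =
        if (isNotTransformable(plan)) plan else plan + "Transformer"

      def main(args: Array[String]): Unit = {
        tagNotTransformable("HashAggregateExec", "validation failed upstream")
        println(impl("HashAggregateExec")) // HashAggregateExec -- unchanged
        println(impl("ProjectExec"))       // ProjectExecTransformer
      }
    }
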
diff --git a/gluten-core/src/main/scala/io/glutenproject/planner/GlutenOptimization.scala b/gluten-core/src/main/scala/io/glutenproject/planner/GlutenOptimization.scala
new file mode 100644
index 000000000000..e1540e553d60
--- /dev/null
+++ b/gluten-core/src/main/scala/io/glutenproject/planner/GlutenOptimization.scala
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.glutenproject.planner
+
+import io.glutenproject.cbo.{CboExplain, Optimization}
+import io.glutenproject.cbo.rule.CboRule
+import io.glutenproject.planner.cost.GlutenCostModel
+import io.glutenproject.planner.plan.GlutenPlanModel
+import io.glutenproject.planner.property.GlutenPropertyModel
+import io.glutenproject.planner.rule.GlutenRules
+
+import org.apache.spark.sql.execution.SparkPlan
+
+object GlutenOptimization {
+ private object GlutenExplain extends CboExplain[SparkPlan] {
+ override def describeNode(node: SparkPlan): String = node.nodeName
+ }
+
+ def apply(): Optimization[SparkPlan] = {
+ Optimization[SparkPlan](
+ GlutenCostModel(),
+ GlutenPlanModel(),
+ GlutenPropertyModel(),
+ GlutenExplain,
+ CboRule.Factory.reuse(GlutenRules()))
+ }
+
+ def apply(rules: Seq[CboRule[SparkPlan]]): Optimization[SparkPlan] = {
+ Optimization[SparkPlan](
+ GlutenCostModel(),
+ GlutenPlanModel(),
+ GlutenPropertyModel(),
+ GlutenExplain,
+ CboRule.Factory.reuse(rules))
+ }
+}
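
A minimal sketch of the second factory overload with a custom rule. It uses only APIs visible in this patch (CboRule.shift/shape, Shapes.fixedHeight); NoopRule is a hypothetical placeholder:

    import io.glutenproject.cbo.rule.{CboRule, Shape, Shapes}
    import io.glutenproject.planner.GlutenOptimization
    import org.apache.spark.sql.execution.SparkPlan

    object CustomRuleSketch {
      // Hypothetical rule that proposes no alternative plans for any node.
      case class NoopRule() extends CboRule[SparkPlan] {
        override def shift(node: SparkPlan): Iterable[SparkPlan] = List.empty
        override def shape(): Shape[SparkPlan] = Shapes.fixedHeight(1) // single-node matches
      }

      val optimization = GlutenOptimization(Seq(NoopRule()))
    }
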
diff --git a/gluten-core/src/main/scala/io/glutenproject/planner/cost/GlutenCost.scala b/gluten-core/src/main/scala/io/glutenproject/planner/cost/GlutenCost.scala
new file mode 100644
index 000000000000..2d85352031e4
--- /dev/null
+++ b/gluten-core/src/main/scala/io/glutenproject/planner/cost/GlutenCost.scala
@@ -0,0 +1,21 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.glutenproject.planner.cost
+
+import io.glutenproject.cbo.Cost
+
+case class GlutenCost(value: Long) extends Cost
diff --git a/gluten-core/src/main/scala/io/glutenproject/planner/cost/GlutenCostModel.scala b/gluten-core/src/main/scala/io/glutenproject/planner/cost/GlutenCostModel.scala
new file mode 100644
index 000000000000..013d4ed3ca6b
--- /dev/null
+++ b/gluten-core/src/main/scala/io/glutenproject/planner/cost/GlutenCostModel.scala
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.glutenproject.planner.cost
+
+import io.glutenproject.cbo.{Cost, CostModel}
+import io.glutenproject.extension.columnar.ColumnarTransitions
+import io.glutenproject.planner.plan.GlutenPlanModel.GroupLeafExec
+import io.glutenproject.utils.PlanUtil
+
+import org.apache.spark.sql.execution.{ColumnarToRowExec, RowToColumnarExec, SparkPlan}
+
+class GlutenCostModel {}
+
+object GlutenCostModel {
+ def apply(): CostModel[SparkPlan] = {
+ RoughCostModel
+ }
+
+ private object RoughCostModel extends CostModel[SparkPlan] {
+ override def costOf(node: SparkPlan): GlutenCost = node match {
+ case _: GroupLeafExec => throw new IllegalStateException()
+ case _ => GlutenCost(longCostOf(node))
+ }
+
+ private def longCostOf(node: SparkPlan): Long = node match {
+ case n =>
+ val selfCost = selfLongCostOf(n)
+
+ // Saturating addition: clamp to Long.MaxValue on overflow.
+ def safeSum(a: Long, b: Long): Long = {
+ assert(a >= 0)
+ assert(b >= 0)
+ val sum = a + b
+ if (sum < a || sum < b) Long.MaxValue else sum
+ }
+
+ (n.children.map(longCostOf).toList :+ selfCost).reduce(safeSum)
+ }
+
+ // A very rough estimation as of now.
+ private def selfLongCostOf(node: SparkPlan): Long = node match {
+ case ColumnarToRowExec(child) => 3L
+ case RowToColumnarExec(child) => 3L
+ case ColumnarTransitions.ColumnarToRowLike(child) => 3L
+ case ColumnarTransitions.RowToColumnarLike(child) => 3L
+ case p if PlanUtil.isGlutenColumnarOp(p) => 2L
+ case p if PlanUtil.isVanillaColumnarOp(p) => 3L
+ // Other row ops. Usually a vanilla row op.
+ case _ => 5L
+ }
+
+ override def costComparator(): Ordering[Cost] = Ordering.Long.on {
+ case GlutenCost(value) => value
+ case _ => throw new IllegalStateException("Unexpected cost type")
+ }
+
+ override def makeInfCost(): Cost = GlutenCost(Long.MaxValue)
+ }
+}
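
The rough model only distinguishes operator classes, and a plan's total cost is the saturating sum of its per-node costs. A self-contained sketch of that arithmetic, with constants copied from selfLongCostOf above:

    object RoughCostSketch {
      // Saturating addition, as in RoughCostModel.safeSum.
      def safeSum(a: Long, b: Long): Long = {
        assert(a >= 0 && b >= 0)
        val sum = a + b
        if (sum < a || sum < b) Long.MaxValue else sum
      }

      def main(args: Array[String]): Unit = {
        // A Gluten columnar op (2) under a columnar-to-row transition (3):
        println(List(2L, 3L).reduce(safeSum)) // 5
        // Saturates at Long.MaxValue instead of wrapping around:
        println(safeSum(Long.MaxValue, 1L)) // 9223372036854775807
      }
    }
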
diff --git a/gluten-core/src/main/scala/io/glutenproject/planner/plan/GlutenPlanModel.scala b/gluten-core/src/main/scala/io/glutenproject/planner/plan/GlutenPlanModel.scala
new file mode 100644
index 000000000000..3e9b4777ae86
--- /dev/null
+++ b/gluten-core/src/main/scala/io/glutenproject/planner/plan/GlutenPlanModel.scala
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.glutenproject.planner.plan
+
+import io.glutenproject.cbo.PlanModel
+import io.glutenproject.cbo.property.PropertySet
+import io.glutenproject.planner.property.GlutenProperties
+import io.glutenproject.planner.property.GlutenProperties.Conventions
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.execution.{LeafExecNode, SparkPlan}
+
+import java.util.Objects
+
+object GlutenPlanModel {
+ def apply(): PlanModel[SparkPlan] = {
+ PlanModelImpl
+ }
+
+ case class GroupLeafExec(groupId: Int, propertySet: PropertySet[SparkPlan]) extends LeafExecNode {
+ override protected def doExecute(): RDD[InternalRow] = throw new IllegalStateException()
+ override def output: Seq[Attribute] = propertySet.get(GlutenProperties.SCHEMA_DEF).output
+ override def supportsColumnar: Boolean =
+ propertySet.get(GlutenProperties.CONVENTION_DEF) match {
+ case Conventions.ROW_BASED => false
+ case Conventions.VANILLA_COLUMNAR => true
+ case Conventions.GLUTEN_COLUMNAR => true
+ case Conventions.ANY => true
+ }
+ }
+
+ private object PlanModelImpl extends PlanModel[SparkPlan] {
+ override def childrenOf(node: SparkPlan): Seq[SparkPlan] = node.children
+
+ override def withNewChildren(node: SparkPlan, children: Seq[SparkPlan]): SparkPlan = {
+ node.withNewChildren(children)
+ }
+
+ override def hashCode(node: SparkPlan): Int = Objects.hashCode(node)
+
+ override def equals(one: SparkPlan, other: SparkPlan): Boolean = Objects.equals(one, other)
+
+ override def newGroupLeaf(groupId: Int, propSet: PropertySet[SparkPlan]): SparkPlan =
+ GroupLeafExec(groupId, propSet)
+
+ override def isGroupLeaf(node: SparkPlan): Boolean = node match {
+ case _: GroupLeafExec => true
+ case _ => false
+ }
+
+ override def getGroupId(node: SparkPlan): Int = node match {
+ case gl: GroupLeafExec => gl.groupId
+ case _ => throw new IllegalStateException()
+ }
+ }
+}
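
GroupLeafExec is the memo placeholder: it stands for "any equivalent plan in group N" and answers output/supportsColumnar from the group's property set, so parents can be planned before the group's winner is picked. A toy sketch of that idea (strings instead of plans, not the cbo module's API):

    object GroupLeafSketch {
      // A memo group: equivalent alternatives with rough costs.
      final case class Group(id: Int, alternatives: Map[String, Long]) {
        def winner: String = alternatives.minBy(_._2)._1
      }

      def main(args: Array[String]): Unit = {
        val scans = Group(0, Map("VanillaScan" -> 5L, "NativeScanTransformer" -> 2L))
        val parent = "FilterTransformer(GroupLeaf#0)" // planned against the placeholder
        // Plan extraction substitutes the group's cheapest alternative:
        println(parent.replace("GroupLeaf#0", scans.winner))
        // FilterTransformer(NativeScanTransformer)
      }
    }
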
diff --git a/gluten-core/src/main/scala/io/glutenproject/planner/property/GlutenPropertyModel.scala b/gluten-core/src/main/scala/io/glutenproject/planner/property/GlutenPropertyModel.scala
new file mode 100644
index 000000000000..3b88b89bcbc0
--- /dev/null
+++ b/gluten-core/src/main/scala/io/glutenproject/planner/property/GlutenPropertyModel.scala
@@ -0,0 +1,164 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.glutenproject.planner.property
+
+import io.glutenproject.backendsapi.BackendsApiManager
+import io.glutenproject.cbo._
+import io.glutenproject.cbo.rule.{CboRule, Shape, Shapes}
+import io.glutenproject.extension.columnar.ColumnarTransitions
+import io.glutenproject.planner.plan.GlutenPlanModel.GroupLeafExec
+import io.glutenproject.planner.property.GlutenProperties.{Convention, CONVENTION_DEF, ConventionEnforcerRule, SCHEMA_DEF}
+import io.glutenproject.sql.shims.SparkShimLoader
+import io.glutenproject.utils.PlanUtil
+
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.execution._
+
+object GlutenProperties {
+ val SCHEMA_DEF: PropertyDef[SparkPlan, Schema] = new PropertyDef[SparkPlan, Schema] {
+ override def getProperty(plan: SparkPlan): Schema = plan match {
+ case _: GroupLeafExec => throw new IllegalStateException()
+ case _ => Schema(plan.output)
+ }
+ override def getChildrenConstraints(
+ constraint: Property[SparkPlan],
+ plan: SparkPlan): Seq[Schema] = {
+ plan.children.map(c => Schema(c.output))
+ }
+ }
+
+ val CONVENTION_DEF: PropertyDef[SparkPlan, Convention] = new PropertyDef[SparkPlan, Convention] {
+ // TODO: Should the convention-transparent ops (e.g., AQE shuffle read) support
+ // convention propagation? Probably need to refactor getChildrenPropertyRequirements.
+ override def getProperty(plan: SparkPlan): Convention = plan match {
+ case _: GroupLeafExec => throw new IllegalStateException()
+ case ColumnarToRowExec(child) => Conventions.ROW_BASED
+ case RowToColumnarExec(child) => Conventions.VANILLA_COLUMNAR
+ case ColumnarTransitions.ColumnarToRowLike(child) => Conventions.ROW_BASED
+ case ColumnarTransitions.RowToColumnarLike(child) => Conventions.GLUTEN_COLUMNAR
+ case p if PlanUtil.outputNativeColumnarData(p) => Conventions.GLUTEN_COLUMNAR
+ case p if PlanUtil.isVanillaColumnarOp(p) => Conventions.VANILLA_COLUMNAR
+ case p if SparkShimLoader.getSparkShims.supportsRowBased(p) => Conventions.ROW_BASED
+ case _ => throw new IllegalStateException()
+ }
+
+ override def getChildrenConstraints(
+ constraint: Property[SparkPlan],
+ plan: SparkPlan): Seq[Convention] = plan match {
+ case ColumnarToRowExec(child) => Seq(Conventions.VANILLA_COLUMNAR)
+ case ColumnarTransitions.ColumnarToRowLike(child) => Seq(Conventions.GLUTEN_COLUMNAR)
+ case ColumnarTransitions.RowToColumnarLike(child) => Seq(Conventions.ROW_BASED)
+ case _ =>
+ val conv = getProperty(plan)
+ plan.children.map(_ => conv)
+ }
+ }
+
+ case class ConventionEnforcerRule(reqConv: Convention) extends CboRule[SparkPlan] {
+ override def shift(node: SparkPlan): Iterable[SparkPlan] = {
+ val conv = CONVENTION_DEF.getProperty(node)
+ if (conv == reqConv) {
+ return List.empty
+ }
+ (conv, reqConv) match {
+ case (Conventions.VANILLA_COLUMNAR, Conventions.ROW_BASED) =>
+ List(ColumnarToRowExec(node))
+ case (Conventions.ROW_BASED, Conventions.VANILLA_COLUMNAR) =>
+ List(RowToColumnarExec(node))
+ case (Conventions.GLUTEN_COLUMNAR, Conventions.ROW_BASED) =>
+ List(BackendsApiManager.getSparkPlanExecApiInstance.genColumnarToRowExec(node))
+ case (Conventions.ROW_BASED, Conventions.GLUTEN_COLUMNAR) =>
+ val attempt = BackendsApiManager.getSparkPlanExecApiInstance.genRowToColumnarExec(node)
+ if (attempt.doValidate().isValid) {
+ List(attempt)
+ } else {
+ List.empty
+ }
+ case (Conventions.VANILLA_COLUMNAR, Conventions.GLUTEN_COLUMNAR) =>
+ List(
+ BackendsApiManager.getSparkPlanExecApiInstance.genRowToColumnarExec(
+ ColumnarToRowExec(node)))
+ case (Conventions.GLUTEN_COLUMNAR, Conventions.VANILLA_COLUMNAR) =>
+ List(
+ RowToColumnarExec(
+ BackendsApiManager.getSparkPlanExecApiInstance.genColumnarToRowExec(node)))
+ case _ => List.empty
+ }
+ }
+
+ override def shape(): Shape[SparkPlan] = Shapes.fixedHeight(1)
+ }
+
+ case class Schema(output: Seq[Attribute]) extends Property[SparkPlan] {
+ override def satisfies(other: Property[SparkPlan]): Boolean = other match {
+ case Schemas.ANY => true
+ case Schema(otherOutput) => output == otherOutput
+ case _ => throw new IllegalStateException()
+ }
+
+ override def definition(): PropertyDef[SparkPlan, _ <: Property[SparkPlan]] = {
+ SCHEMA_DEF
+ }
+ }
+
+ object Schemas {
+ val ANY: Property[SparkPlan] = Schema(List())
+ }
+
+ sealed trait Convention extends Property[SparkPlan] {
+ override def definition(): PropertyDef[SparkPlan, _ <: Property[SparkPlan]] = {
+ CONVENTION_DEF
+ }
+
+ override def satisfies(other: Property[SparkPlan]): Boolean = other match {
+ case Conventions.ANY => true
+ case c: Convention => c == this
+ case _ => throw new IllegalStateException()
+ }
+ }
+
+ object Conventions {
+ // FIXME: Velox and CH should have different conventions?
+ case object ROW_BASED extends Convention
+ case object VANILLA_COLUMNAR extends Convention
+ case object GLUTEN_COLUMNAR extends Convention
+ case object ANY extends Convention
+ }
+}
+
+object GlutenPropertyModel {
+
+ def apply(): PropertyModel[SparkPlan] = {
+ PropertyModelImpl
+ }
+
+ private object PropertyModelImpl extends PropertyModel[SparkPlan] {
+ override def propertyDefs: Seq[PropertyDef[SparkPlan, _ <: Property[SparkPlan]]] =
+ Seq(SCHEMA_DEF, CONVENTION_DEF)
+
+ override def newEnforcerRuleFactory(
+ propertyDef: PropertyDef[SparkPlan, _ <: Property[SparkPlan]])
+ : EnforcerRuleFactory[SparkPlan] = (reqProp: Property[SparkPlan]) => {
+ propertyDef match {
+ case SCHEMA_DEF =>
+ Seq()
+ case CONVENTION_DEF =>
+ Seq(ConventionEnforcerRule(reqProp.asInstanceOf[Convention]))
+ }
+ }
+ }
+}
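
ConventionEnforcerRule only bridges the pairs listed in its match; note that vanilla-to-Gluten columnar (and the reverse) detours through a row form, i.e. two transitions. A self-contained toy of the inserted operators, mirroring shift above (backend operator names are placeholders):

    object ConventionTransitionSketch {
      sealed trait Conv
      case object RowBased extends Conv
      case object VanillaColumnar extends Conv
      case object GlutenColumnar extends Conv

      // Operators inserted for (actual, required), innermost first.
      def transitions(from: Conv, to: Conv): List[String] = (from, to) match {
        case (a, b) if a == b => Nil
        case (VanillaColumnar, RowBased) => List("ColumnarToRowExec")
        case (RowBased, VanillaColumnar) => List("RowToColumnarExec")
        case (GlutenColumnar, RowBased) => List("<backend C2R>")
        case (RowBased, GlutenColumnar) => List("<backend R2C>") // only if it validates
        case (VanillaColumnar, GlutenColumnar) => List("ColumnarToRowExec", "<backend R2C>")
        case (GlutenColumnar, VanillaColumnar) => List("<backend C2R>", "RowToColumnarExec")
      }

      def main(args: Array[String]): Unit = {
        // Vanilla columnar -> Gluten columnar takes the row detour:
        println(transitions(VanillaColumnar, GlutenColumnar))
      }
    }
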
diff --git a/gluten-core/src/main/scala/io/glutenproject/planner/rule/GlutenRules.scala b/gluten-core/src/main/scala/io/glutenproject/planner/rule/GlutenRules.scala
new file mode 100644
index 000000000000..9b30f25378c2
--- /dev/null
+++ b/gluten-core/src/main/scala/io/glutenproject/planner/rule/GlutenRules.scala
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.glutenproject.planner.rule
+
+import io.glutenproject.cbo.rule.CboRule
+
+import org.apache.spark.sql.execution.SparkPlan
+
+object GlutenRules {
+ def apply(): Seq[CboRule[SparkPlan]] = {
+ List() // TODO
+ }
+}
diff --git a/pom.xml b/pom.xml
index afe86e56ad56..c606538e90cf 100644
--- a/pom.xml
+++ b/pom.xml
@@ -34,6 +34,7 @@
gluten-ui
package
shims
+ gluten-cbo
diff --git a/shims/common/src/main/scala/io/glutenproject/GlutenConfig.scala b/shims/common/src/main/scala/io/glutenproject/GlutenConfig.scala
index a30f70baac22..4c53c0e1c77b 100644
--- a/shims/common/src/main/scala/io/glutenproject/GlutenConfig.scala
+++ b/shims/common/src/main/scala/io/glutenproject/GlutenConfig.scala
@@ -41,6 +41,8 @@ class GlutenConfig(conf: SQLConf) extends Logging {
def enableGluten: Boolean = conf.getConf(GLUTEN_ENABLED)
+ def enableAdvancedCbo: Boolean = conf.getConf(ADVANCED_CBO_ENABLED)
+
// FIXME the option currently controls both JVM and native validation against a Substrait plan.
def enableNativeValidation: Boolean = conf.getConf(NATIVE_VALIDATION_ENABLED)
@@ -665,6 +667,19 @@ object GlutenConfig {
.booleanConf
.createWithDefault(GLUTEN_ENABLE_BY_DEFAULT)
+ val ADVANCED_CBO_ENABLED =
+ buildConf("spark.gluten.sql.advanced.cbo.enabled")
+ .doc(
+ "Experimental: Enables Gluten's advanced CBO features during physical planning. " +
+ "E.g, More efficient fallback strategy, etc. The option can be turned on and off " +
+ "individually despite vanilla Spark's CBO settings. Note, Gluten's query optimizer " +
+ "may still adopt a subset of its advanced CBO capabilities even this option " +
+ "is off. Enabling it would cause Gluten consider using CBO for optimization " +
+ "more aggressively. Note, this feature is still in development and may not bring " +
+ "performance profits.")
+ .booleanConf
+ .createWithDefault(false)
+
// FIXME the option currently controls both JVM and native validation against a Substrait plan.
val NATIVE_VALIDATION_ENABLED =
buildConf("spark.gluten.sql.enable.native.validation")
diff --git a/shims/common/src/main/scala/io/glutenproject/sql/shims/SparkShims.scala b/shims/common/src/main/scala/io/glutenproject/sql/shims/SparkShims.scala
index 0b390b97993c..e1fcb7b910b3 100644
--- a/shims/common/src/main/scala/io/glutenproject/sql/shims/SparkShims.scala
+++ b/shims/common/src/main/scala/io/glutenproject/sql/shims/SparkShims.scala
@@ -180,4 +180,7 @@ trait SparkShims {
def extractExpressionTimestampAddUnit(timestampAdd: Expression): Option[Seq[String]] =
Option.empty
+
+ def supportsRowBased(plan: SparkPlan): Boolean = !plan.supportsColumnar
+
}
diff --git a/shims/spark33/src/main/scala/io/glutenproject/sql/shims/spark33/Spark33Shims.scala b/shims/spark33/src/main/scala/io/glutenproject/sql/shims/spark33/Spark33Shims.scala
index 91372eb0c444..c0a5f2e68742 100644
--- a/shims/spark33/src/main/scala/io/glutenproject/sql/shims/spark33/Spark33Shims.scala
+++ b/shims/spark33/src/main/scala/io/glutenproject/sql/shims/spark33/Spark33Shims.scala
@@ -272,4 +272,6 @@ class Spark33Shims extends SparkShims {
case _ => Option.empty
}
}
+
+ override def supportsRowBased(plan: SparkPlan): Boolean = plan.supportsRowBased
}
diff --git a/shims/spark34/src/main/scala/io/glutenproject/sql/shims/spark34/Spark34Shims.scala b/shims/spark34/src/main/scala/io/glutenproject/sql/shims/spark34/Spark34Shims.scala
index c98def5daeff..13efff291f5e 100644
--- a/shims/spark34/src/main/scala/io/glutenproject/sql/shims/spark34/Spark34Shims.scala
+++ b/shims/spark34/src/main/scala/io/glutenproject/sql/shims/spark34/Spark34Shims.scala
@@ -344,4 +344,6 @@ class Spark34Shims extends SparkShims {
filteredPartitions.flatten
}
}
+
+ override def supportsRowBased(plan: SparkPlan): Boolean = plan.supportsRowBased
}
diff --git a/shims/spark35/src/main/scala/io/glutenproject/sql/shims/spark35/Spark35Shims.scala b/shims/spark35/src/main/scala/io/glutenproject/sql/shims/spark35/Spark35Shims.scala
index 6a6f3b2c8fd1..8bea501c05f7 100644
--- a/shims/spark35/src/main/scala/io/glutenproject/sql/shims/spark35/Spark35Shims.scala
+++ b/shims/spark35/src/main/scala/io/glutenproject/sql/shims/spark35/Spark35Shims.scala
@@ -309,4 +309,6 @@ class Spark35Shims extends SparkShims {
override def getCommonPartitionValues(batchScan: BatchScanExec): Option[Seq[(InternalRow, Int)]] =
null
+
+ override def supportsRowBased(plan: SparkPlan): Boolean = plan.supportsRowBased
}
diff --git a/tools/gluten-it/common/src/main/java/io/glutenproject/integration/tpc/TpcMixin.java b/tools/gluten-it/common/src/main/java/io/glutenproject/integration/tpc/TpcMixin.java
index 694f44049cfa..74c74ffb68d9 100644
--- a/tools/gluten-it/common/src/main/java/io/glutenproject/integration/tpc/TpcMixin.java
+++ b/tools/gluten-it/common/src/main/java/io/glutenproject/integration/tpc/TpcMixin.java
@@ -43,6 +43,7 @@ public class TpcMixin {
@CommandLine.Option(names = {"--log-level"}, description = "Set log level: 0 for DEBUG, 1 for INFO, 2 for WARN", defaultValue = "2")
private int logLevel;
+
@CommandLine.Option(names = {"--error-on-memleak"}, description = "Fail the test when memory leak is detected by Spark's memory manager", defaultValue = "false")
private boolean errorOnMemLeak;
@@ -152,7 +153,7 @@ public Integer runActions(Action[] actions) {
return 0;
}
- private <K, V> Map<K, V> mergeMapSafe(Map<K, V> conf, Map<? extends K, ? extends V> other) {
+ private <K, V> Map<K, V> mergeMapSafe(Map<K, V> conf, Map<? extends K, ? extends V> other) {
other.keySet().forEach(k -> {
if (conf.containsKey(k)) {
throw new IllegalArgumentException("Key already exists in conf: " + k);
diff --git a/tools/gluten-it/common/src/main/java/io/glutenproject/integration/tpc/command/Parameterized.java b/tools/gluten-it/common/src/main/java/io/glutenproject/integration/tpc/command/Parameterized.java
index afcd8ec82490..52632c778bf8 100644
--- a/tools/gluten-it/common/src/main/java/io/glutenproject/integration/tpc/command/Parameterized.java
+++ b/tools/gluten-it/common/src/main/java/io/glutenproject/integration/tpc/command/Parameterized.java
@@ -49,6 +49,9 @@ public class Parameterized implements Callable {
@CommandLine.Option(names = {"--queries"}, description = "Set a comma-separated list of query IDs to run, run all queries if not specified. Example: --queries=q1,q6", split = ",")
private String[] queries = new String[0];
+ @CommandLine.Option(names = {"--excluded-queries"}, description = "Set a comma-separated list of query IDs to exclude. Example: --excluded-queries=q1,q6", split = ",")
+ private String[] excludedQueries = new String[0];
+
@CommandLine.Option(names = {"--iterations"}, description = "How many iterations to run", defaultValue = "1")
private int iterations;
@@ -119,7 +122,7 @@ public Integer call() throws Exception {
)).collect(Collectors.toList())).asScala();
io.glutenproject.integration.tpc.action.Parameterized parameterized =
- new io.glutenproject.integration.tpc.action.Parameterized(dataGenMixin.getScale(), this.queries, iterations, warmupIterations, parsedDims, metrics);
+ new io.glutenproject.integration.tpc.action.Parameterized(dataGenMixin.getScale(), this.queries, excludedQueries, iterations, warmupIterations, parsedDims, metrics);
return mixin.runActions(ArrayUtils.addAll(dataGenMixin.makeActions(), parameterized));
}
}
diff --git a/tools/gluten-it/common/src/main/java/io/glutenproject/integration/tpc/command/Queries.java b/tools/gluten-it/common/src/main/java/io/glutenproject/integration/tpc/command/Queries.java
index ed675d4a882d..4645045a434b 100644
--- a/tools/gluten-it/common/src/main/java/io/glutenproject/integration/tpc/command/Queries.java
+++ b/tools/gluten-it/common/src/main/java/io/glutenproject/integration/tpc/command/Queries.java
@@ -35,6 +35,9 @@ public class Queries implements Callable {
@CommandLine.Option(names = {"--queries"}, description = "Set a comma-separated list of query IDs to run, run all queries if not specified. Example: --queries=q1,q6", split = ",")
private String[] queries = new String[0];
+ @CommandLine.Option(names = {"--excluded-queries"}, description = "Set a comma-separated list of query IDs to exclude. Example: --excluded-queries=q1,q6", split = ",")
+ private String[] excludedQueries = new String[0];
+
@CommandLine.Option(names = {"--explain"}, description = "Output explain result for queries", defaultValue = "false")
private boolean explain;
@@ -47,7 +50,7 @@ public class Queries implements Callable {
@Override
public Integer call() throws Exception {
io.glutenproject.integration.tpc.action.Queries queries =
- new io.glutenproject.integration.tpc.action.Queries(dataGenMixin.getScale(), this.queries, explain, iterations, randomKillTasks);
+ new io.glutenproject.integration.tpc.action.Queries(dataGenMixin.getScale(), this.queries, this.excludedQueries, explain, iterations, randomKillTasks);
return mixin.runActions(ArrayUtils.addAll(dataGenMixin.makeActions(), queries));
}
}
diff --git a/tools/gluten-it/common/src/main/java/io/glutenproject/integration/tpc/command/QueriesCompare.java b/tools/gluten-it/common/src/main/java/io/glutenproject/integration/tpc/command/QueriesCompare.java
index f0ee9ded3b32..d2e09bc7dc84 100644
--- a/tools/gluten-it/common/src/main/java/io/glutenproject/integration/tpc/command/QueriesCompare.java
+++ b/tools/gluten-it/common/src/main/java/io/glutenproject/integration/tpc/command/QueriesCompare.java
@@ -35,6 +35,9 @@ public class QueriesCompare implements Callable {
@CommandLine.Option(names = {"--queries"}, description = "Set a comma-separated list of query IDs to run, run all queries if not specified. Example: --queries=q1,q6", split = ",")
private String[] queries = new String[0];
+ @CommandLine.Option(names = {"--excluded-queries"}, description = "Set a comma-separated list of query IDs to exclude. Example: --excluded-queries=q1,q6", split = ",")
+ private String[] excludedQueries = new String[0];
+
@CommandLine.Option(names = {"--explain"}, description = "Output explain result for queries", defaultValue = "false")
private boolean explain;
@@ -44,7 +47,7 @@ public class QueriesCompare implements Callable {
@Override
public Integer call() throws Exception {
io.glutenproject.integration.tpc.action.QueriesCompare queriesCompare =
- new io.glutenproject.integration.tpc.action.QueriesCompare(dataGenMixin.getScale(), this.queries, explain, iterations);
+ new io.glutenproject.integration.tpc.action.QueriesCompare(dataGenMixin.getScale(), this.queries, this.excludedQueries, explain, iterations);
return mixin.runActions(ArrayUtils.addAll(dataGenMixin.makeActions(), queriesCompare));
}
}
diff --git a/tools/gluten-it/common/src/main/scala/io/glutenproject/integration/tpc/TpcSuite.scala b/tools/gluten-it/common/src/main/scala/io/glutenproject/integration/tpc/TpcSuite.scala
index 7ac8162a1de7..fa8e87fbb90c 100644
--- a/tools/gluten-it/common/src/main/scala/io/glutenproject/integration/tpc/TpcSuite.scala
+++ b/tools/gluten-it/common/src/main/scala/io/glutenproject/integration/tpc/TpcSuite.scala
@@ -179,3 +179,36 @@ abstract class TpcSuite(
private[tpc] def desc(): String
}
+
+object TpcSuite {
+ implicit class TpcSuiteImplicits(suite: TpcSuite) {
+ def selectQueryIds(queryIds: Array[String], excludedQueryIds: Array[String]): Array[String] = {
+ if (queryIds.nonEmpty && excludedQueryIds.nonEmpty) {
+ throw new IllegalArgumentException(
+ "Should not specify queries and excluded queries at the same time")
+ }
+ val all = suite.allQueryIds()
+ val allSet = all.toSet
+ if (queryIds.nonEmpty) {
+ assert(
+ queryIds.forall(id => allSet.contains(id)),
+ "Invalid query ID: " + queryIds.collectFirst {
+ case id if !allSet.contains(id) =>
+ id
+ }.get)
+ return queryIds
+ }
+ if (excludedQueryIds.nonEmpty) {
+ assert(
+ excludedQueryIds.forall(id => allSet.contains(id)),
+ "Invalid query ID to exclude: " + excludedQueryIds.collectFirst {
+ case id if !allSet.contains(id) =>
+ id
+ }.get)
+ val excludedSet = excludedQueryIds.toSet
+ return all.filterNot(excludedSet.contains)
+ }
+ all
+ }
+ }
+}
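To make the selection semantics of the new helper concrete, here is a hypothetical walk-through assuming a suite whose allQueryIds() returns Array("q1", "q2", "q3"); the IDs are stand-ins:

    // Hypothetical illustration of TpcSuite.selectQueryIds, not part of the patch.
    suite.selectQueryIds(Array("q1"), Array.empty)  // => Array("q1")
    suite.selectQueryIds(Array.empty, Array("q2"))  // => Array("q1", "q3")
    suite.selectQueryIds(Array.empty, Array.empty)  // => Array("q1", "q2", "q3")
    suite.selectQueryIds(Array("q1"), Array("q2"))  // throws IllegalArgumentException
    suite.selectQueryIds(Array("q9"), Array.empty)  // AssertionError: Invalid query ID: q9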
diff --git a/tools/gluten-it/common/src/main/scala/io/glutenproject/integration/tpc/action/Parameterized.scala b/tools/gluten-it/common/src/main/scala/io/glutenproject/integration/tpc/action/Parameterized.scala
index e8e989427c0e..cd3d97baa40e 100644
--- a/tools/gluten-it/common/src/main/scala/io/glutenproject/integration/tpc/action/Parameterized.scala
+++ b/tools/gluten-it/common/src/main/scala/io/glutenproject/integration/tpc/action/Parameterized.scala
@@ -31,6 +31,7 @@ import scala.collection.mutable.ArrayBuffer
class Parameterized(
scale: Double,
queryIds: Array[String],
+ excludedQueryIds: Array[String],
iterations: Int,
warmupIterations: Int,
configDimensions: Seq[Dim],
@@ -104,19 +105,7 @@ class Parameterized(
sessionSwitcher.registerSession(coordinate.toString, conf)
}
- val runQueryIds = queryIds match {
- case Array() =>
- allQueries
- case _ =>
- queryIds
- }
- val allQueriesSet = allQueries.toSet
- runQueryIds.foreach {
- queryId =>
- if (!allQueriesSet.contains(queryId)) {
- throw new IllegalArgumentException(s"Query ID doesn't exist: $queryId")
- }
- }
+ val runQueryIds = tpcSuite.selectQueryIds(queryIds, excludedQueryIds)
// warm up
(0 until warmupIterations).foreach {
diff --git a/tools/gluten-it/common/src/main/scala/io/glutenproject/integration/tpc/action/Queries.scala b/tools/gluten-it/common/src/main/scala/io/glutenproject/integration/tpc/action/Queries.scala
index e5a6e546f970..632e6489a3af 100644
--- a/tools/gluten-it/common/src/main/scala/io/glutenproject/integration/tpc/action/Queries.scala
+++ b/tools/gluten-it/common/src/main/scala/io/glutenproject/integration/tpc/action/Queries.scala
@@ -19,32 +19,32 @@ package io.glutenproject.integration.tpc.action
import io.glutenproject.integration.stat.RamStat
import io.glutenproject.integration.tpc.{TpcRunner, TpcSuite}
-import org.apache.spark.sql.SparkSessionSwitcher
-
import org.apache.commons.lang3.exception.ExceptionUtils
-case class Queries(scale: Double, queryIds: Array[String], explain: Boolean, iterations: Int, randomKillTasks: Boolean)
+case class Queries(
+ scale: Double,
+ queryIds: Array[String],
+ excludedQueryIds: Array[String],
+ explain: Boolean,
+ iterations: Int,
+ randomKillTasks: Boolean)
extends Action {
override def execute(tpcSuite: TpcSuite): Boolean = {
+ val runQueryIds = tpcSuite.selectQueryIds(queryIds, excludedQueryIds)
val runner: TpcRunner = new TpcRunner(tpcSuite.queryResource(), tpcSuite.dataWritePath(scale))
- val allQueries = tpcSuite.allQueryIds()
val results = (0 until iterations).flatMap {
iteration =>
println(s"Running tests (iteration $iteration)...")
- val runQueryIds = queryIds match {
- case Array() =>
- allQueries
- case _ =>
- queryIds
- }
- val allQueriesSet = allQueries.toSet
runQueryIds.map {
queryId =>
- if (!allQueriesSet.contains(queryId)) {
- throw new IllegalArgumentException(s"Query ID doesn't exist: $queryId")
- }
- Queries.runTpcQuery(runner, tpcSuite.sessionSwitcher, queryId, tpcSuite.desc(), explain, randomKillTasks)
+ Queries.runTpcQuery(
+ runner,
+ tpcSuite.sessionSwitcher,
+ queryId,
+ tpcSuite.desc(),
+ explain,
+ randomKillTasks)
}
}.toList
@@ -147,13 +147,24 @@ object Queries {
)))
}
- private def runTpcQuery(runner: _root_.io.glutenproject.integration.tpc.TpcRunner, sessionSwitcher: _root_.org.apache.spark.sql.SparkSessionSwitcher, id: _root_.java.lang.String, desc: _root_.java.lang.String, explain: Boolean, randomKillTasks: Boolean) = {
+ private def runTpcQuery(
+ runner: _root_.io.glutenproject.integration.tpc.TpcRunner,
+ sessionSwitcher: _root_.org.apache.spark.sql.SparkSessionSwitcher,
+ id: _root_.java.lang.String,
+ desc: _root_.java.lang.String,
+ explain: Boolean,
+ randomKillTasks: Boolean) = {
println(s"Running query: $id...")
try {
val testDesc = "Gluten Spark %s %s".format(desc, id)
sessionSwitcher.useSession("test", testDesc)
runner.createTables(sessionSwitcher.spark())
- val result = runner.runTpcQuery(sessionSwitcher.spark(), testDesc, id, explain = explain, randomKillTasks = randomKillTasks)
+ val result = runner.runTpcQuery(
+ sessionSwitcher.spark(),
+ testDesc,
+ id,
+ explain = explain,
+ randomKillTasks = randomKillTasks)
val resultRows = result.rows
println(
s"Successfully ran query $id. " +
diff --git a/tools/gluten-it/common/src/main/scala/io/glutenproject/integration/tpc/action/QueriesCompare.scala b/tools/gluten-it/common/src/main/scala/io/glutenproject/integration/tpc/action/QueriesCompare.scala
index 0e6beaa0fde6..47ba9eb54e89 100644
--- a/tools/gluten-it/common/src/main/scala/io/glutenproject/integration/tpc/action/QueriesCompare.scala
+++ b/tools/gluten-it/common/src/main/scala/io/glutenproject/integration/tpc/action/QueriesCompare.scala
@@ -23,27 +23,22 @@ import org.apache.spark.sql.{SparkSessionSwitcher, TestUtils}
import org.apache.commons.lang3.exception.ExceptionUtils
-case class QueriesCompare(scale: Double, queryIds: Array[String], explain: Boolean, iterations: Int)
+case class QueriesCompare(
+ scale: Double,
+ queryIds: Array[String],
+ excludedQueryIds: Array[String],
+ explain: Boolean,
+ iterations: Int)
extends Action {
override def execute(tpcSuite: TpcSuite): Boolean = {
val runner: TpcRunner = new TpcRunner(tpcSuite.queryResource(), tpcSuite.dataWritePath(scale))
- val allQueries = tpcSuite.allQueryIds()
+ val runQueryIds = tpcSuite.selectQueryIds(queryIds, excludedQueryIds)
val results = (0 until iterations).flatMap {
iteration =>
println(s"Running tests (iteration $iteration)...")
- val runQueryIds = queryIds match {
- case Array() =>
- allQueries
- case _ =>
- queryIds
- }
- val allQueriesSet = allQueries.toSet
runQueryIds.map {
queryId =>
- if (!allQueriesSet.contains(queryId)) {
- throw new IllegalArgumentException(s"Query ID doesn't exist: $queryId")
- }
QueriesCompare.runTpcQuery(
queryId,
explain,
@@ -194,6 +189,7 @@ object QueriesCompare {
val result = runner.runTpcQuery(sessionSwitcher.spark(), testDesc, id, explain = explain)
val resultRows = result.rows
val error = TestUtils.compareAnswers(resultRows, expectedRows, sort = true)
+ // FIXME: This is too hacky
// A list of query ids whose corresponding query results can differ because of order.
val unorderedQueries = Seq("q65")
if (error.isEmpty || unorderedQueries.contains(id)) {
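The new FIXME flags the allow-list: any mismatch on q65 is accepted wholesale. If the mismatches are pure row-order permutations, a hypothetical order-insensitive check (a sketch, not part of this patch) could replace the allow-list:

    import org.apache.spark.sql.Row

    // Hypothetical sketch: treat both result sets as multisets, ignoring row
    // order; assumes Row.toString is a faithful enough key for these results.
    def sameUnordered(actual: Seq[Row], expected: Seq[Row]): Boolean =
      actual.map(_.toString).sorted == expected.map(_.toString).sorted

If instead the query's trailing LIMIT selects different-but-valid rows across tied sort keys, even this would not suffice, which is presumably why the allow-list exists.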
diff --git a/tools/gluten-it/common/src/main/scala/org/apache/spark/sql/QueryRunner.scala b/tools/gluten-it/common/src/main/scala/org/apache/spark/sql/QueryRunner.scala
index 400c49f6703d..a4044c925a31 100644
--- a/tools/gluten-it/common/src/main/scala/org/apache/spark/sql/QueryRunner.scala
+++ b/tools/gluten-it/common/src/main/scala/org/apache/spark/sql/QueryRunner.scala
@@ -81,10 +81,10 @@ object QueryRunner {
val sql = resourceToString(queryPath)
val prev = System.nanoTime()
val df = spark.sql(sql)
+ val rows = df.collect()
if (explain) {
df.explain(extended = true)
}
- val rows = df.collect()
val millis = (System.nanoTime() - prev) / 1000000L
val collectedMetrics = metrics.map(name => (name, em.getMetricValue(name))).toMap
RunResult(rows, millis, collectedMetrics)
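Why move collect() ahead of explain()? A plausible reading (an assumption; the patch does not say) is that with adaptive query execution, explaining before any action prints a plan whose AdaptiveSparkPlan is not yet finalized, while explaining after materialization reflects the final, re-optimized plan. A minimal sketch:

    // Minimal sketch; assumes an active SparkSession named spark with AQE enabled.
    val df = spark.sql("SELECT id % 10 AS k, count(*) AS c FROM range(1000) GROUP BY id % 10")
    df.collect()                 // materialize first, as the patch now does
    df.explain(extended = true)  // shows the final plan (isFinalPlan=true under AQE)

The timing window is unchanged either way: millis still spans query parsing through collection, plus the optional explain.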