[SPARK-52060][SQL] Make OneRowRelationExec node #50849

Open · wants to merge 1 commit into base: master

@@ -319,3 +319,51 @@ case class RDDScanExec(

   override def getStream: Option[SparkDataStream] = stream
 }
+
+/**
+ * A special case of RDDScanExec that is used to represent a scan without a `FROM` clause,
+ * for example `SELECT version()`.
+ *
+ * We do not extend `RDDScanExec` in order to avoid complexity due to `TreeNode.makeCopy` and
+ * `TreeNode`'s general use of reflection.
+ */
+case class OneRowRelationExec() extends LeafExecNode
+  with StreamSourceAwareSparkPlan
+  with InputRDDCodegen {
+
+  override val nodeName: String = "Scan OneRowRelation"
+
+  override val output: Seq[Attribute] = Nil
+
+  private val rdd = session.sparkContext.parallelize(Seq(InternalRow()), 1)
+
+  override lazy val metrics = Map(
+    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
+
+  protected override def doExecute(): RDD[InternalRow] = {
+    val numOutputRows = longMetric("numOutputRows")
+    rdd.mapPartitionsWithIndexInternal { (index, iter) =>
+      val proj = UnsafeProjection.create(schema)
+      proj.initialize(index)
+      iter.map { r =>
+        numOutputRows += 1
+        proj(r)
+      }
+    }
+  }
+
+  override def simpleString(maxFields: Int): String = {
+    s"$nodeName${truncatedString(output, "[", ",", "]", maxFields)}"
+  }
+
+  override def inputRDD: RDD[InternalRow] = rdd
+
+  // Input can be InternalRow; it has to be turned into UnsafeRows.
+  override protected val createUnsafeProjection: Boolean = true
+
+  override protected def doCanonicalize(): SparkPlan = {
+    super.doCanonicalize().asInstanceOf[OneRowRelationExec].copy()
+  }
+
+  override def getStream: Option[SparkDataStream] = None
+}
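
As a quick illustration of what this node covers end to end, here is a minimal spark-shell sketch (assuming the usual `spark` SparkSession; exact plan text varies by version and codegen settings):

val df = spark.sql("select version()")
df.explain()
// == Physical Plan ==
// The leaf is now the dedicated OneRowRelationExec node (still rendered as
// "Scan OneRowRelation", wrapped in WholeStageCodegen when codegen is
// enabled) rather than a generic RDDScanExec.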
@@ -690,8 +690,6 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
     }
   }

-  protected lazy val singleRowRdd = session.sparkContext.parallelize(Seq(InternalRow()), 1)
-
   object InMemoryScans extends Strategy {
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
       case PhysicalOperation(projectList, filters, mem: InMemoryRelation) =>
@@ -1040,7 +1038,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
           generator, g.requiredChildOutput, outer,
           g.qualifiedGeneratorOutput, planLater(child)) :: Nil
       case _: logical.OneRowRelation =>
-        execution.RDDScanExec(Nil, singleRowRdd, "OneRowRelation") :: Nil
+        execution.OneRowRelationExec() :: Nil
       case r: logical.Range =>
         execution.RangeExec(r) :: Nil
       case r: logical.RepartitionByExpression =>
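
For reference, the logical side of this mapping (a spark-shell sketch; the attribute id `1#0` is illustrative): a query without a `FROM` clause parses to a Project over the logical OneRowRelation, which the strategy above now plans into the dedicated physical node.

spark.sql("select 1").queryExecution.optimizedPlan
// Project [1 AS 1#0]
// +- OneRowRelation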
@@ -56,6 +56,7 @@ trait CodegenSupport extends SparkPlan {
     case _: SortMergeJoinExec => "smj"
     case _: BroadcastNestedLoopJoinExec => "bnlj"
     case _: RDDScanExec => "rdd"
+    case _: OneRowRelationExec => "orr"
     case _: DataSourceScanExec => "scan"
     case _: InMemoryTableScanExec => "memoryScan"
     case _: WholeStageCodegenExec => "wholestagecodegen"
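
A note on the `variablePrefix` entry above: the prefix is used for fresh variable names in whole-stage-generated code, which can be inspected with the `debug` helpers (a sketch; the `orr_...` names shown are illustrative, not exact):

import org.apache.spark.sql.execution.debug._
spark.sql("select version()").debugCodegen()
// Fresh variables produced for the OneRowRelationExec stage should carry the
// "orr" prefix (e.g. orr_value_0) instead of the generic "rdd" prefix.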
14 changes: 13 additions & 1 deletion sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -38,7 +38,7 @@ import org.apache.spark.sql.catalyst.optimizer.{ConvertToLocalRelation, NestedCo
 import org.apache.spark.sql.catalyst.parser.ParseException
 import org.apache.spark.sql.catalyst.plans.logical.{LocalLimit, Project, RepartitionByExpression, Sort}
 import org.apache.spark.sql.connector.catalog.CatalogManager.SESSION_CATALOG_NAME
-import org.apache.spark.sql.execution.{CommandResultExec, UnionExec}
+import org.apache.spark.sql.execution.{CommandResultExec, OneRowRelationExec, UnionExec}
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.execution.aggregate._
 import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
@@ -4962,6 +4962,18 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
       parameters = Map("plan" -> "'Aggregate [groupingsets(Vector(0), posexplode(array(col)))]")
     )
   }
+
+  Seq(true, false).foreach { codegenEnabled =>
+    test(s"SPARK-52060: one row relation - codegen enabled: $codegenEnabled") {
+      withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> codegenEnabled.toString) {
+        val df = spark.sql("select 'test' stringCol")
+        checkAnswer(df, Row("test"))
+        val plan = df.queryExecution.executedPlan
+        val oneRowRelationExists = plan.find(_.isInstanceOf[OneRowRelationExec]).isDefined
+        assert(oneRowRelationExists)
+      }
+    }
+  }
 }

 case class Foo(bar: Option[String])
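
To run the new test locally, the standard Spark sbt workflow should work (a sketch; `-z` filters ScalaTest cases by name substring, so this picks up both codegen variants):

build/sbt "sql/testOnly org.apache.spark.sql.SQLQuerySuite -- -z SPARK-52060"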