Skip to content

Commit a8111b2

Browse files
committed
[SPARK-52875][SQL] Simplify V2 expression translation if the input is context-independent-foldable
### What changes were proposed in this pull request?

If the input to V2 expression translation is context-independent and foldable (see [PR #51282](#51282)), we perform constant folding on the input and use the evaluated result. For example:

* `1 + 1` becomes `2`
* `a < log2(8)` becomes `a < 3.0`

### Why are the changes needed?

This change broadens the coverage of V2 expression translation.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

New unit tests (UT).

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #51569 from gengliangwang/v2Fold.

Authored-by: Gengliang Wang <[email protected]>
Signed-off-by: Gengliang Wang <[email protected]>
1 parent 5b622c5 commit a8111b2

File tree

7 files changed

+276
-188
lines changed

7 files changed

+276
-188
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,8 @@ case class CaseWhen(
200200
branches.exists(_._2.nullable) || elseValue.map(_.nullable).getOrElse(true)
201201
}
202202

203+
override def contextIndependentFoldable: Boolean = children.forall(_.contextIndependentFoldable)
204+
203205
override def checkInputDataTypes(): TypeCheckResult = {
204206
if (TypeCoercion.haveSameType(inputTypesForMerging)) {
205207
// Make sure all branch conditions are boolean types.

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ object ConstantFolding extends Rule[LogicalPlan] {
5959
case _ => false
6060
}
6161

62-
private def constantFolding(
62+
private[sql] def constantFolding(
6363
e: Expression,
6464
isConditionalBranch: Boolean = false): Expression = e match {
6565
case c: ConditionalExpression if !c.foldable =>

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import org.apache.spark.internal.LogKeys.EXPR
2222
import org.apache.spark.sql.catalyst.expressions._
2323
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, Complete}
2424
import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, StaticInvoke}
25+
import org.apache.spark.sql.catalyst.optimizer.ConstantFolding
2526
import org.apache.spark.sql.connector.catalog.functions.ScalarFunction
2627
import org.apache.spark.sql.connector.expressions.{Cast => V2Cast, Expression => V2Expression, Extract => V2Extract, FieldReference, GeneralScalarExpression, LiteralValue, NullOrdering, SortDirection, SortValue, UserDefinedScalarFunc}
2728
import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Avg, Count, CountStar, GeneralAggregateFunc, Max, Min, Sum, UserDefinedAggregateFunc}
@@ -83,11 +84,20 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) extends L
8384
case _ => false
8485
}
8586

87+
private def translateLiteral(l: Literal): V2Expression = l match {
88+
case Literal(true, BooleanType) => new AlwaysTrue()
89+
case Literal(false, BooleanType) => new AlwaysFalse()
90+
case other => LiteralValue(other.value, other.dataType)
91+
}
92+
8693
private def generateExpression(
8794
expr: Expression, isPredicate: Boolean = false): Option[V2Expression] = expr match {
88-
case Literal(true, BooleanType) => Some(new AlwaysTrue())
89-
case Literal(false, BooleanType) => Some(new AlwaysFalse())
90-
case Literal(value, dataType) => Some(LiteralValue(value, dataType))
95+
case literal: Literal => Some(translateLiteral(literal))
96+
case _ if expr.contextIndependentFoldable =>
97+
// If the expression is context independent foldable, we can convert it to a literal.
98+
// This is useful for increasing the coverage of V2 expressions.
99+
val constantExpr = ConstantFolding.constantFolding(expr)
100+
generateExpression(constantExpr, isPredicate)
91101
case col @ ColumnOrField(nameParts) =>
92102
val ref = FieldReference(nameParts)
93103
if (isPredicate && col.dataType.isInstanceOf[BooleanType]) {

sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala

Lines changed: 18 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@ import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException
2828
import org.apache.spark.sql.catalyst.plans.logical.{AppendData, CreateTableAsSelect, LogicalPlan, ReplaceTableAsSelect}
2929
import org.apache.spark.sql.connector.catalog.{Column, ColumnDefaultValue, DefaultValue, Identifier, InMemoryTableCatalog, TableInfo}
3030
import org.apache.spark.sql.connector.catalog.TableChange.{AddColumn, UpdateColumnDefaultValue}
31-
import org.apache.spark.sql.connector.expressions.{ApplyTransform, Cast => V2Cast, GeneralScalarExpression, LiteralValue, Transform}
31+
import org.apache.spark.sql.connector.expressions.{ApplyTransform, GeneralScalarExpression, LiteralValue, Transform}
32+
import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse, AlwaysTrue}
3233
import org.apache.spark.sql.execution.{QueryExecution, SparkPlan}
3334
import org.apache.spark.sql.execution.ExplainUtils.stripAQEPlan
3435
import org.apache.spark.sql.execution.datasources.v2.{AlterTableExec, CreateTableExec, DataSourceV2Relation, ReplaceTableExec}
@@ -371,21 +372,15 @@ class DataSourceV2DataFrameSuite
371372
null,
372373
new ColumnDefaultValue(
373374
"(100 + 23)",
374-
new GeneralScalarExpression(
375-
"+",
376-
Array(LiteralValue(100, IntegerType), LiteralValue(23, IntegerType))),
375+
LiteralValue(123, IntegerType),
377376
LiteralValue(123, IntegerType)),
378377
new ColumnDefaultValue(
379378
"('h' || 'r')",
380-
new GeneralScalarExpression(
381-
"CONCAT",
382-
Array(
383-
LiteralValue(UTF8String.fromString("h"), StringType),
384-
LiteralValue(UTF8String.fromString("r"), StringType))),
379+
LiteralValue(UTF8String.fromString("hr"), StringType),
385380
LiteralValue(UTF8String.fromString("hr"), StringType)),
386381
new ColumnDefaultValue(
387382
"CAST(1 AS BOOLEAN)",
388-
new V2Cast(LiteralValue(1, IntegerType), IntegerType, BooleanType),
383+
new AlwaysTrue,
389384
LiteralValue(true, BooleanType))))
390385

391386
val df1 = Seq(1).toDF("id")
@@ -420,21 +415,15 @@ class DataSourceV2DataFrameSuite
420415
null,
421416
new ColumnDefaultValue(
422417
"(50 * 2)",
423-
new GeneralScalarExpression(
424-
"*",
425-
Array(LiteralValue(50, IntegerType), LiteralValue(2, IntegerType))),
418+
LiteralValue(100, IntegerType),
426419
LiteralValue(100, IntegerType)),
427420
new ColumnDefaultValue(
428421
"('un' || 'known')",
429-
new GeneralScalarExpression(
430-
"CONCAT",
431-
Array(
432-
LiteralValue(UTF8String.fromString("un"), StringType),
433-
LiteralValue(UTF8String.fromString("known"), StringType))),
422+
LiteralValue(UTF8String.fromString("unknown"), StringType),
434423
LiteralValue(UTF8String.fromString("unknown"), StringType)),
435424
new ColumnDefaultValue(
436425
"CAST(0 AS BOOLEAN)",
437-
new V2Cast(LiteralValue(0, IntegerType), IntegerType, BooleanType),
426+
new AlwaysFalse,
438427
LiteralValue(false, BooleanType))))
439428

440429
val df3 = Seq(1).toDF("id")
@@ -469,21 +458,15 @@ class DataSourceV2DataFrameSuite
469458
Array(
470459
new ColumnDefaultValue(
471460
"(100 + 23)",
472-
new GeneralScalarExpression(
473-
"+",
474-
Array(LiteralValue(100, IntegerType), LiteralValue(23, IntegerType))),
461+
LiteralValue(123, IntegerType),
475462
LiteralValue(123, IntegerType)),
476463
new ColumnDefaultValue(
477464
"('h' || 'r')",
478-
new GeneralScalarExpression(
479-
"CONCAT",
480-
Array(
481-
LiteralValue(UTF8String.fromString("h"), StringType),
482-
LiteralValue(UTF8String.fromString("r"), StringType))),
465+
LiteralValue(UTF8String.fromString("hr"), StringType),
483466
LiteralValue(UTF8String.fromString("hr"), StringType)),
484467
new ColumnDefaultValue(
485468
"CAST(1 AS BOOLEAN)",
486-
new V2Cast(LiteralValue(1, IntegerType), IntegerType, BooleanType),
469+
new AlwaysTrue,
487470
LiteralValue(true, BooleanType))))
488471
}
489472
}
@@ -514,19 +497,13 @@ class DataSourceV2DataFrameSuite
514497
Array(
515498
new DefaultValue(
516499
"(123 + 56)",
517-
new GeneralScalarExpression(
518-
"+",
519-
Array(LiteralValue(123, IntegerType), LiteralValue(56, IntegerType)))),
500+
LiteralValue(179, IntegerType)),
520501
new DefaultValue(
521502
"('r' || 'l')",
522-
new GeneralScalarExpression(
523-
"CONCAT",
524-
Array(
525-
LiteralValue(UTF8String.fromString("r"), StringType),
526-
LiteralValue(UTF8String.fromString("l"), StringType)))),
503+
LiteralValue(UTF8String.fromString("rl"), StringType)),
527504
new DefaultValue(
528505
"CAST(0 AS BOOLEAN)",
529-
new V2Cast(LiteralValue(0, IntegerType), IntegerType, BooleanType))))
506+
new AlwaysFalse)))
530507
}
531508
}
532509

@@ -692,7 +669,7 @@ class DataSourceV2DataFrameSuite
692669
LiteralValue(1542490413000000L, TimestampType)),
693670
new ColumnDefaultValue(
694671
"1",
695-
new V2Cast(LiteralValue(1, IntegerType), IntegerType, DoubleType),
672+
LiteralValue(1.0, DoubleType),
696673
LiteralValue(1.0, DoubleType))))
697674

698675
val replaceExec = executeAndKeepPhysicalPlan[ReplaceTableExec] {
@@ -714,11 +691,7 @@ class DataSourceV2DataFrameSuite
714691
LiteralValue(1645624555000000L, TimestampType)),
715692
new ColumnDefaultValue(
716693
"(1 + 1)",
717-
new V2Cast(
718-
new GeneralScalarExpression("+", Array(LiteralValue(1, IntegerType),
719-
LiteralValue(1, IntegerType))),
720-
IntegerType,
721-
DoubleType),
694+
LiteralValue(2.0, DoubleType),
722695
LiteralValue(2.0, DoubleType))))
723696
}
724697
}
@@ -746,7 +719,7 @@ class DataSourceV2DataFrameSuite
746719
LiteralValue(1542490413000000L, TimestampType)),
747720
new ColumnDefaultValue(
748721
"1",
749-
new V2Cast(LiteralValue(1, IntegerType), IntegerType, DoubleType),
722+
LiteralValue(1.0, DoubleType),
750723
LiteralValue(1.0, DoubleType))))
751724

752725
val alterCol1 = executeAndKeepPhysicalPlan[AlterTableExec] {
@@ -764,11 +737,7 @@ class DataSourceV2DataFrameSuite
764737
LiteralValue(1645624555000000L, TimestampType)),
765738
new DefaultValue(
766739
"(1 + 1)",
767-
new V2Cast(
768-
new GeneralScalarExpression("+", Array(LiteralValue(1, IntegerType),
769-
LiteralValue(1, IntegerType))),
770-
IntegerType,
771-
DoubleType))))
740+
LiteralValue(2.0, DoubleType))))
772741
}
773742
}
774743

sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils
3838
import org.apache.spark.sql.connector.catalog.{Column => ColumnV2, _}
3939
import org.apache.spark.sql.connector.catalog.CatalogManager.SESSION_CATALOG_NAME
4040
import org.apache.spark.sql.connector.catalog.CatalogV2Util.withDefaultOwnership
41-
import org.apache.spark.sql.connector.expressions.{Cast => V2Cast, Expression, GeneralScalarExpression, LiteralValue, Transform}
41+
import org.apache.spark.sql.connector.expressions.{LiteralValue, Transform}
4242
import org.apache.spark.sql.errors.QueryErrorsBase
4343
import org.apache.spark.sql.execution.FilterExec
4444
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
@@ -655,12 +655,7 @@ class DataSourceV2SQLSuiteV1Filter
655655
null, /* no comment */
656656
new ColumnDefaultValue(
657657
"41 + 1",
658-
new V2Cast(
659-
new GeneralScalarExpression(
660-
"+",
661-
Array[Expression](LiteralValue(41, IntegerType), LiteralValue(1, IntegerType))),
662-
IntegerType,
663-
LongType),
658+
LiteralValue(42L, LongType),
664659
LiteralValue(42L, LongType)),
665660
null /* no metadata */)
666661
assert(actual === expected,

sql/core/src/test/scala/org/apache/spark/sql/connector/PushablePredicateSuite.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ import org.apache.spark.sql.connector.expressions.filter.{AlwaysTrue, Predicate
2424
import org.apache.spark.sql.execution.datasources.v2.PushablePredicate
2525
import org.apache.spark.sql.internal.SQLConf
2626
import org.apache.spark.sql.test.SharedSparkSession
27-
import org.apache.spark.sql.types.BooleanType
27+
import org.apache.spark.sql.types.{BooleanType, TimestampType}
2828

2929
class PushablePredicateSuite extends QueryTest with SharedSparkSession {
3030

@@ -70,7 +70,8 @@ class PushablePredicateSuite extends QueryTest with SharedSparkSession {
7070
withSQLConf(
7171
SQLConf.DATA_SOURCE_ALWAYS_CREATE_V2_PREDICATE.key -> createV2Predicate.toString,
7272
SQLConf.DATA_SOURCE_DONT_ASSERT_ON_PREDICATE.key -> noAssert.toString) {
73-
val catalystExpr = Cast(Literal.create("true"), BooleanType)
73+
val catalystExpr =
74+
Cast(Cast(Literal.create("2025-01-01 00:00:00"), TimestampType), BooleanType)
7475
if (createV2Predicate) {
7576
val pushable = PushablePredicate.unapply(catalystExpr)
7677
assert(pushable.isDefined)

0 commit comments

Comments
 (0)