diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala
index 7d2ab84d3a574..379d9c0ba0c9d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala
@@ -159,6 +159,15 @@ object UnwrapCastInBinaryComparison extends Rule[LogicalPlan] {
       None
     }
 
+    // Handle a string-to-date cast compared against a date literal:
+    //   CAST(string_col AS DATE) cmp date_literal  ==>  string_col cmp CAST(date_literal AS STRING)
+    // NOTE(review): only sound when string_col holds canonical `yyyy-MM-dd` values;
+    // see the caveat on unwrapStringToDate before enabling this unconditionally.
+    case be @ BinaryComparison(
+        Cast(fromExp, DateType, _, _), date @ Literal(value, DateType))
+          if fromExp.dataType == StringType && value != null =>
+      Some(unwrapStringToDate(be, fromExp, date))
+
     // As the analyzer makes sure that the list of In is already of the same data type, then the
     // rule can simply check the first literal in `in.list` can implicitly cast to `toType` or not,
     // and note that:
@@ -412,6 +421,35 @@ object UnwrapCastInBinaryComparison extends Rule[LogicalPlan] {
     }
   }
 
+  /**
+   * Moves the cast to the literal side for a string-to-date comparison:
+   * `CAST(string_col AS DATE) op date_literal` becomes
+   * `string_col op CAST(date_literal AS STRING)`, which lets the predicate be
+   * pushed down to data sources and avoids casting every row.
+   *
+   * NOTE(review): this rewrite is only equivalent when the column holds canonical
+   * `yyyy-MM-dd` strings. Parseable-but-non-canonical values (e.g. "2023-1-1") and
+   * unparseable values (NULL under non-ANSI cast, an error under ANSI cast) compare
+   * differently after the rewrite -- confirm the intended gating before relying on it.
+   */
+  private def unwrapStringToDate(
+      exp: BinaryComparison,
+      fromExp: Expression,
+      date: Literal): Expression = {
+    // Constant-foldable: the cast is evaluated once by the optimizer, not per row.
+    val dateAsString = Cast(date, StringType)
+    exp match {
+      case _: GreaterThan => GreaterThan(fromExp, dateAsString)
+      case _: GreaterThanOrEqual => GreaterThanOrEqual(fromExp, dateAsString)
+      case _: EqualTo => EqualTo(fromExp, dateAsString)
+      case _: EqualNullSafe => EqualNullSafe(fromExp, dateAsString)
+      case _: LessThan => LessThan(fromExp, dateAsString)
+      case _: LessThanOrEqual => LessThanOrEqual(fromExp, dateAsString)
+      // Be conservative for any comparison we do not explicitly handle.
+      case _ => exp
+    }
+  }
+
   private def simplifyIn[IN <: Expression](
       fromExp: Expression,
       toType: NumericType,
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala
index 9e8e31c69c3c2..138b8053239f1 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala
@@ -43,7 +43,7 @@ class UnwrapCastInBinaryComparisonSuite extends PlanTest with ExpressionEvalHelp
   }
 
   val testRelation: LocalRelation = LocalRelation($"a".short, $"b".float,
-    $"c".decimal(5, 2), $"d".boolean, $"e".timestamp, $"f".timestampNTZ, $"g".date)
+    $"c".decimal(5, 2), $"d".boolean, $"e".timestamp, $"f".timestampNTZ, $"g".date, $"h".string)
   val f: BoundReference = $"a".short.canBeNull.at(0)
   val f2: BoundReference = $"b".float.canBeNull.at(1)
   val f3: BoundReference = $"c".decimal(5, 2).canBeNull.at(2)
@@ -51,6 +51,7 @@ class UnwrapCastInBinaryComparisonSuite extends PlanTest with ExpressionEvalHelp
   val f5: BoundReference = $"e".timestamp.notNull.at(4)
   val f6: BoundReference = $"f".timestampNTZ.canBeNull.at(5)
   val f7: BoundReference = $"g".date.canBeNull.at(6)
+  val f8: BoundReference = $"h".string.canBeNull.at(7)
 
   test("unwrap casts when literal == max") {
     val v = Short.MaxValue
@@ -502,6 +503,51 @@ class UnwrapCastInBinaryComparisonSuite extends PlanTest with ExpressionEvalHelp
     doTest(tsNtzLit4)
   }
 
+  test("Support unwrap String to Date cast") {
+    val date = java.sql.Date.valueOf("2023-01-01")
+    val dateLit = Literal.create(date, DateType)
+
+    // Helper to check that the cast has been moved from the column to the literal.
+    def checkOptimized(original: Expression): Unit = {
+      val plan = testRelation.where(original).analyze
+      val optimized = Optimize.execute(plan)
+
+      // Extract the filter condition from the optimized plan.
+      val filter = optimized.collectFirst {
+        case f: org.apache.spark.sql.catalyst.plans.logical.Filter => f.condition
+      }.get
+
+      // The optimized expression should:
+      // 1. Not have a Cast on the string column (f8)
+      // 2. Have the cast (or its constant-folded result) on the date-literal side
+      filter match {
+        case bc: BinaryComparison =>
+          // Left side should be the string column without a cast.
+          assert(bc.left == f8, s"Expected left side to be f8, but got ${bc.left}")
+          // Right side should be a Cast from Date to String.
+          bc.right match {
+            case Cast(lit: Literal, StringType, _, _) =>
+              assert(lit.dataType == DateType,
+                s"Expected Cast from DateType, but got ${lit.dataType}")
+            // The batch may constant-fold the cast into a plain string literal.
+            case lit: Literal if lit.dataType == StringType =>
+            case other =>
+              fail(s"Expected Cast on right side, but got $other")
+          }
+        case other =>
+          fail(s"Expected BinaryComparison, but got $other")
+      }
+    }
+
+    // Test all comparison operators.
+    checkOptimized(castStringToDate(f8) > dateLit)
+    checkOptimized(castStringToDate(f8) >= dateLit)
+    checkOptimized(castStringToDate(f8) === dateLit)
+    checkOptimized(castStringToDate(f8) <=> dateLit)
+    checkOptimized(castStringToDate(f8) < dateLit)
+    checkOptimized(castStringToDate(f8) <= dateLit)
+  }
+
   private val ts1 = LocalDateTime.of(2023, 1, 1, 23, 59, 59, 99999000)
   private val ts2 = LocalDateTime.of(2023, 1, 1, 23, 59, 59, 999998000)
   private val ts3 = LocalDateTime.of(9999, 12, 31, 23, 59, 59, 999999999)
@@ -518,6 +564,7 @@ class UnwrapCastInBinaryComparisonSuite extends PlanTest with ExpressionEvalHelp
     Cast(e, TimestampType, Some(conf.sessionLocalTimeZone))
   private def castTimestampNTZ(e: Expression): Expression =
     Cast(e, TimestampNTZType, Some(conf.sessionLocalTimeZone))
+  private def castStringToDate(e: Expression): Expression = Cast(e, DateType)
   private def decimal(v: Decimal): Decimal = Decimal(v.toJavaBigDecimal, 5, 2)
   private def decimal2(v: BigDecimal): Decimal = Decimal(v, 10, 4)
 
@@ -530,16 +577,19 @@ class UnwrapCastInBinaryComparisonSuite extends PlanTest with ExpressionEvalHelp
 
     if (evaluate) {
       Seq(
-        (100.toShort, 3.14.toFloat, decimal2(100), true, ts1, ts1, dt1),
-        (-300.toShort, 3.1415927.toFloat, decimal2(-3000.50), false, ts2, ts2, dt1),
-        (null, Float.NaN, decimal2(12345.6789), null, null, null, null),
-        (null, null, null, null, null, null, null),
-        (Short.MaxValue, Float.PositiveInfinity, decimal2(Short.MaxValue), true, ts3, ts3, dt1),
-        (Short.MinValue, Float.NegativeInfinity, decimal2(Short.MinValue), false, ts4, ts4, dt1),
-        (0.toShort, Float.MaxValue, decimal2(0), null, null, null, null),
-        (0.toShort, Float.MinValue, decimal2(0.01), null, null, null, null)
+        (100.toShort, 3.14.toFloat, decimal2(100), true, ts1, ts1, dt1, "2023-01-01"),
+        (-300.toShort, 3.1415927.toFloat, decimal2(-3000.50), false, ts2, ts2, dt1,
+          "2023-12-31"),
+        (null, Float.NaN, decimal2(12345.6789), null, null, null, null, null),
+        (null, null, null, null, null, null, null, null),
+        (Short.MaxValue, Float.PositiveInfinity, decimal2(Short.MaxValue), true, ts3, ts3, dt1,
+          "2023-06-15"),
+        (Short.MinValue, Float.NegativeInfinity, decimal2(Short.MinValue), false, ts4, ts4, dt1,
+          "2000-01-01"),
+        (0.toShort, Float.MaxValue, decimal2(0), null, null, null, null, "2023-01-01"),
+        (0.toShort, Float.MinValue, decimal2(0.01), null, null, null, null, "1999-12-31")
       ).foreach(v => {
-        val row = create_row(v._1, v._2, v._3, v._4, v._5, v._6, v._7)
+        val row = create_row(v._1, v._2, v._3, v._4, v._5, v._6, v._7, v._8)
         checkEvaluation(e1, e2.eval(row), row)
       })
     }