Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,13 @@ object UnwrapCastInBinaryComparison extends Rule[LogicalPlan] {
None
}

// Handle String to Date cast: CAST(string_col AS DATE) = date_literal
// Transform to: string_col = CAST(date_literal AS STRING)
case be @ BinaryComparison(
Cast(fromExp, DateType, _, _), date @ Literal(value, DateType))
if fromExp.dataType == StringType && value != null =>
Some(unwrapStringToDate(be, fromExp, date))

// As the analyzer makes sure that the list of In is already of the same data type, then the
// rule can simply check the first literal in `in.list` can implicitly cast to `toType` or not,
// and note that:
Expand Down Expand Up @@ -412,6 +419,28 @@ object UnwrapCastInBinaryComparison extends Rule[LogicalPlan] {
}
}

/**
* Move the cast to the literal side for String to Date comparisons.
* Transform CAST(string_col AS DATE) op date_literal to
* string_col op CAST(date_literal AS STRING).
* This allows the comparison to be pushed down to data sources and avoids casting every row.
*/
private def unwrapStringToDate(
exp: BinaryComparison,
fromExp: Expression,
date: Literal): Expression = {
val dateAsString = Cast(date, StringType)
exp match {
case _: GreaterThan => GreaterThan(fromExp, dateAsString)
case _: GreaterThanOrEqual => GreaterThanOrEqual(fromExp, dateAsString)
case _: EqualTo => EqualTo(fromExp, dateAsString)
case _: EqualNullSafe => EqualNullSafe(fromExp, dateAsString)
case _: LessThan => LessThan(fromExp, dateAsString)
case _: LessThanOrEqual => LessThanOrEqual(fromExp, dateAsString)
case _ => exp
}
}

private def simplifyIn[IN <: Expression](
fromExp: Expression,
toType: NumericType,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,15 @@ class UnwrapCastInBinaryComparisonSuite extends PlanTest with ExpressionEvalHelp
}

val testRelation: LocalRelation = LocalRelation($"a".short, $"b".float,
$"c".decimal(5, 2), $"d".boolean, $"e".timestamp, $"f".timestampNTZ, $"g".date)
$"c".decimal(5, 2), $"d".boolean, $"e".timestamp, $"f".timestampNTZ, $"g".date, $"h".string)
val f: BoundReference = $"a".short.canBeNull.at(0)
val f2: BoundReference = $"b".float.canBeNull.at(1)
val f3: BoundReference = $"c".decimal(5, 2).canBeNull.at(2)
val f4: BoundReference = $"d".boolean.canBeNull.at(3)
val f5: BoundReference = $"e".timestamp.notNull.at(4)
val f6: BoundReference = $"f".timestampNTZ.canBeNull.at(5)
val f7: BoundReference = $"g".date.canBeNull.at(6)
val f8: BoundReference = $"h".string.canBeNull.at(7)

test("unwrap casts when literal == max") {
val v = Short.MaxValue
Expand Down Expand Up @@ -502,6 +503,49 @@ class UnwrapCastInBinaryComparisonSuite extends PlanTest with ExpressionEvalHelp
doTest(tsNtzLit4)
}

test("Support unwrap String to Date cast") {
val date = java.sql.Date.valueOf("2023-01-01")
val dateLit = Literal.create(date, DateType)

// Helper to check that cast has been moved from column to literal
def checkOptimized(original: Expression): Unit = {
val plan = testRelation.where(original).analyze
val optimized = Optimize.execute(plan)

// Extract the filter condition from the optimized plan
val filter = optimized.collectFirst {
case f: org.apache.spark.sql.catalyst.plans.logical.Filter => f.condition
}.get

// The optimized expression should:
// 1. Not have a Cast on the string column (f8)
// 2. Have a Cast on the date literal side
filter match {
case bc: BinaryComparison =>
// Left side should be the string column without cast
assert(bc.left == f8, s"Expected left side to be f8, but got ${bc.left}")
// Right side should be a Cast from Date to String
bc.right match {
case Cast(lit: Literal, StringType, _, _) =>
assert(lit.dataType == DateType,
s"Expected Cast from DateType, but got ${lit.dataType}")
case other =>
fail(s"Expected Cast on right side, but got $other")
}
case other =>
fail(s"Expected BinaryComparison, but got $other")
}
}

// Test all comparison operators
checkOptimized(castStringToDate(f8) > dateLit)
checkOptimized(castStringToDate(f8) >= dateLit)
checkOptimized(castStringToDate(f8) === dateLit)
checkOptimized(castStringToDate(f8) <=> dateLit)
checkOptimized(castStringToDate(f8) < dateLit)
checkOptimized(castStringToDate(f8) <= dateLit)
}

private val ts1 = LocalDateTime.of(2023, 1, 1, 23, 59, 59, 99999000)
private val ts2 = LocalDateTime.of(2023, 1, 1, 23, 59, 59, 999998000)
private val ts3 = LocalDateTime.of(9999, 12, 31, 23, 59, 59, 999999999)
Expand All @@ -518,6 +562,7 @@ class UnwrapCastInBinaryComparisonSuite extends PlanTest with ExpressionEvalHelp
Cast(e, TimestampType, Some(conf.sessionLocalTimeZone))
private def castTimestampNTZ(e: Expression): Expression =
Cast(e, TimestampNTZType, Some(conf.sessionLocalTimeZone))
private def castStringToDate(e: Expression): Expression = Cast(e, DateType)

private def decimal(v: Decimal): Decimal = Decimal(v.toJavaBigDecimal, 5, 2)
private def decimal2(v: BigDecimal): Decimal = Decimal(v, 10, 4)
Expand All @@ -530,16 +575,19 @@ class UnwrapCastInBinaryComparisonSuite extends PlanTest with ExpressionEvalHelp

if (evaluate) {
Seq(
(100.toShort, 3.14.toFloat, decimal2(100), true, ts1, ts1, dt1),
(-300.toShort, 3.1415927.toFloat, decimal2(-3000.50), false, ts2, ts2, dt1),
(null, Float.NaN, decimal2(12345.6789), null, null, null, null),
(null, null, null, null, null, null, null),
(Short.MaxValue, Float.PositiveInfinity, decimal2(Short.MaxValue), true, ts3, ts3, dt1),
(Short.MinValue, Float.NegativeInfinity, decimal2(Short.MinValue), false, ts4, ts4, dt1),
(0.toShort, Float.MaxValue, decimal2(0), null, null, null, null),
(0.toShort, Float.MinValue, decimal2(0.01), null, null, null, null)
(100.toShort, 3.14.toFloat, decimal2(100), true, ts1, ts1, dt1, "2023-01-01"),
(-300.toShort, 3.1415927.toFloat, decimal2(-3000.50), false, ts2, ts2, dt1,
"2023-12-31"),
(null, Float.NaN, decimal2(12345.6789), null, null, null, null, null),
(null, null, null, null, null, null, null, null),
(Short.MaxValue, Float.PositiveInfinity, decimal2(Short.MaxValue), true, ts3, ts3, dt1,
"2023-06-15"),
(Short.MinValue, Float.NegativeInfinity, decimal2(Short.MinValue), false, ts4, ts4, dt1,
"2000-01-01"),
(0.toShort, Float.MaxValue, decimal2(0), null, null, null, null, "2023-01-01"),
(0.toShort, Float.MinValue, decimal2(0.01), null, null, null, null, "1999-12-31")
).foreach(v => {
val row = create_row(v._1, v._2, v._3, v._4, v._5, v._6, v._7)
val row = create_row(v._1, v._2, v._3, v._4, v._5, v._6, v._7, v._8)
checkEvaluation(e1, e2.eval(row), row)
})
}
Expand Down