Skip to content

Commit 707459a

Browse files
committed
temp
1 parent 426accf commit 707459a

4 files changed

Lines changed: 440 additions & 8 deletions

File tree

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCoercion.scala

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,13 @@ object CollationTypeCoercion extends SQLConfHelper {
8181
getMap
8282
}
8383

84+
case elemAt @ ElementAt(left, key @ Literal(_, _: StringType), _, _)
85+
if left.dataType.isInstanceOf[MapType] &&
86+
left.dataType.asInstanceOf[MapType].keyType != key.dataType =>
87+
elemAt.copy(right =
88+
Cast(key, left.dataType.asInstanceOf[MapType].keyType,
89+
timeZoneId = Some(conf.sessionLocalTimeZone)))
90+
8491
case otherExpr @ (_: In | _: InSubquery | _: CreateArray | _: ArrayJoin | _: Concat |
8592
_: Greatest | _: Least | _: Coalesce | _: ArrayContains | _: ArrayExcept | _: ConcatWs |
8693
_: Mask | _: StringReplace | _: StringTranslate | _: StringTrim | _: StringTrimLeft |
@@ -288,6 +295,16 @@ object CollationTypeCoercion extends SQLConfHelper {
288295
None
289296
}
290297

298+
case elementAt: ElementAt =>
299+
findCollationContext(elementAt.left) match {
300+
case Some(MapType(_, valueType, _)) =>
301+
mergeWinner(elementAt.dataType, valueType)
302+
case Some(ArrayType(elementType, _)) =>
303+
mergeWinner(elementAt.dataType, elementType)
304+
case _ =>
305+
None
306+
}
307+
291308
case struct: CreateNamedStruct =>
292309
val childrenContexts = struct.valExprs.map(findCollationContext)
293310
if (childrenContexts.isEmpty) {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParameterHandler.scala

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
package org.apache.spark.sql.catalyst.parser
1818

1919
import org.apache.spark.sql.catalyst.expressions.{Expression, Literal}
20+
import org.apache.spark.sql.catalyst.types.DataTypeUtils
21+
import org.apache.spark.sql.util.SchemaUtils
2022

2123
/**
2224
* Handler for parameter substitution across different Spark SQL contexts.
@@ -107,14 +109,37 @@ object ParameterHandler {
107109
* @param expr The expression to convert (must be a Literal)
108110
* @return SQL string representation
109111
*/
110-
private def convertToSql(expr: Expression): String = {
  // Converts an expression to its SQL representation. If the expression's type contains
  // collated string types, strips collations from nested literals and wraps the whole
  // expression in a CAST so the collation is re-parsed with implicit strength. Without
  // this, Literal.sql produces `'value' COLLATE collationName`, which re-parses with
  // explicit strength.
  def toSqlWithImplicitCollation(e: Expression): String = {
    if (!DataTypeUtils.hasNonDefaultStringCharOrVarcharType(e.dataType)) {
      // No non-default collation anywhere in the type: plain SQL rendering is safe.
      e.sql
    } else {
      // Strip explicit collations from every nested literal so the inner SQL carries
      // plain strings, then CAST the result back to the original collated type.
      val stripped = e.transform {
        case lit: Literal
            if DataTypeUtils.hasNonDefaultStringCharOrVarcharType(lit.dataType) =>
          Literal.create(
            lit.value, SchemaUtils.replaceCollatedStringWithString(lit.dataType))
      }
      s"CAST(${stripped.sql} AS ${e.dataType.sql})"
    }
  }

  expr match {
    // NULL literals: Literal.sql already renders a typed NULL; no CAST wrapping needed.
    case lit: Literal if lit.value == null =>
      lit.sql
    case lit: Literal =>
      toSqlWithImplicitCollation(lit)
    case other =>
      // Only resolved literals are accepted here; anything else indicates the caller
      // skipped parameter resolution.
      throw new IllegalArgumentException(
        s"ParameterHandler only accepts resolved Literal expressions. " +
        s"Received: ${other.getClass.getSimpleName}. " +
        s"All parameters must be resolved using SparkSession.resolveAndValidateParameters " +
        s"before being passed to the pre-parser.")
  }
}
119144

120145
/**

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/DataTypeUtils.scala

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,17 @@ object DataTypeUtils {
302302
}
303303
}
304304

305+
/**
 * Returns true if the given data type contains any STRING/CHAR/VARCHAR with explicit collation
 * (including explicit `UTF8_BINARY`), recursively checking nested types.
 *
 * @param dataType the type to inspect; may be arbitrarily nested (struct/array/map)
 * @return true if any nested string type carries a non-default collation
 */
def hasNonDefaultStringCharOrVarcharType(dataType: DataType): Boolean = {
  dataType.existsRecursively {
    // A string type is non-default when it is not the plain STRING/CHAR/VARCHAR type,
    // i.e. it carries an explicit collation.
    case st: StringType => !isDefaultStringCharOrVarcharType(st)
    case _ => false
  }
}
315+
305316
/**
306317
* Recursively replaces all STRING, CHAR and VARCHAR types that do not have an explicit collation
307318
* with the same type but with explicit `UTF8_BINARY` collation.

0 commit comments

Comments (0)