Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,16 @@ object CollationTypeCoercion extends SQLConfHelper {
None
}

// ElementAt extracts an element from a container: look up the collation context of
// the container (left child) and merge the winner between ElementAt's own result
// type and the container's element type — the map VALUE type for maps, the element
// type for arrays. Any other (or missing) context yields None so the caller falls
// back to its default resolution. NOTE(review): assumes `mergeWinner` picks the
// dominant collation of the two types — confirm against its definition above.
case elementAt: ElementAt =>
findCollationContext(elementAt.left) match {
case Some(MapType(_, valueType, _)) =>
mergeWinner(elementAt.dataType, valueType)
case Some(ArrayType(elementType, _)) =>
mergeWinner(elementAt.dataType, elementType)
case _ =>
None
}

case struct: CreateNamedStruct =>
val childrenContexts = struct.valExprs.map(findCollationContext)
if (childrenContexts.isEmpty) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
package org.apache.spark.sql.catalyst.parser

import org.apache.spark.sql.catalyst.expressions.{Expression, Literal}
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.util.SchemaUtils

/**
* Handler for parameter substitution across different Spark SQL contexts.
Expand Down Expand Up @@ -107,14 +109,37 @@ object ParameterHandler {
* @param expr The expression to convert (must be a Literal)
* @return SQL string representation
*/
private def convertToSql(expr: Expression): String = expr match {
case lit: Literal => lit.sql
case other =>
throw new IllegalArgumentException(
s"ParameterHandler only accepts resolved Literal expressions. " +
s"Received: ${other.getClass.getSimpleName}. " +
s"All parameters must be resolved using SparkSession.resolveAndValidateParameters " +
s"before being passed to the pre-parser.")
/**
 * Renders a resolved literal parameter as a SQL string.
 *
 * For types carrying non-default collations (or CHAR/VARCHAR), `Literal.sql` would emit
 * `'value' COLLATE collationName`, which re-parses with EXPLICIT collation strength. To keep
 * the collation IMPLICIT, nested literals are first rewritten to plain (uncollated) string
 * types and the whole expression is wrapped in a CAST back to its original data type.
 *
 * @param expr the expression to convert; must be a resolved [[Literal]]
 * @return the SQL string representation of the literal
 * @throws IllegalArgumentException if `expr` is not a `Literal`
 */
private def convertToSql(expr: Expression): String = {
  // Produces SQL that round-trips with implicit collation strength.
  def castWithImplicitCollation(e: Expression): String = {
    if (DataTypeUtils.hasNonDefaultStringCharOrVarcharType(e.dataType)) {
      // Strip collations from every nested literal, then CAST the stripped form
      // back to the original (collated) type to restore it implicitly.
      val uncollated = e.transform {
        case l: Literal if DataTypeUtils.hasNonDefaultStringCharOrVarcharType(l.dataType) =>
          Literal.create(l.value, SchemaUtils.replaceCollatedStringWithString(l.dataType))
      }
      s"CAST(${uncollated.sql} AS ${e.dataType.sql})"
    } else {
      e.sql
    }
  }

  expr match {
    // NULL literals need no collation handling; emit them directly.
    case nullLiteral: Literal if nullLiteral.value == null =>
      nullLiteral.sql
    case literal: Literal =>
      castWithImplicitCollation(literal)
    case other =>
      throw new IllegalArgumentException(
        s"ParameterHandler only accepts resolved Literal expressions. " +
        s"Received: ${other.getClass.getSimpleName}. " +
        s"All parameters must be resolved using SparkSession.resolveAndValidateParameters " +
        s"before being passed to the pre-parser.")
  }
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,17 @@ object DataTypeUtils {
}
}

/**
 * Checks whether `dataType` contains — at any nesting depth — a STRING, CHAR or VARCHAR
 * whose collation was specified explicitly (an explicit `UTF8_BINARY` counts as well).
 *
 * @param dataType the type to inspect recursively
 * @return true if any nested string-like type has a non-default (explicit) collation
 */
def hasNonDefaultStringCharOrVarcharType(dataType: DataType): Boolean = {
  dataType.existsRecursively { dt =>
    dt match {
      // Only string-like leaves can carry a collation; everything else is a non-match.
      case stringLike: StringType => !isDefaultStringCharOrVarcharType(stringLike)
      case _ => false
    }
  }
}

/**
* Recursively replaces all STRING, CHAR and VARCHAR types that do not have an explicit collation
* with the same type but with explicit `UTF8_BINARY` collation.
Expand Down
Loading