@@ -37,21 +37,23 @@ object ResolveMergeIntoSchemaEvolution extends Rule[LogicalPlan] {
     case m @ MergeIntoTable(_, _, _, _, _, _, _)
         if m.needSchemaEvolution =>
       val newTarget = m.targetTable.transform {
-        case r : DataSourceV2Relation => performSchemaEvolution(r, m.sourceTable)
+        case r : DataSourceV2Relation => performSchemaEvolution(r, m)
       }
       m.copy(targetTable = newTarget)
   }
 
-  private def performSchemaEvolution(relation: DataSourceV2Relation, source: LogicalPlan)
+  private def performSchemaEvolution(relation: DataSourceV2Relation, m: MergeIntoTable)
       : DataSourceV2Relation = {
     (relation.catalog, relation.identifier) match {
       case (Some(c: TableCatalog), Some(i)) =>
-        val changes = MergeIntoTable.schemaChanges(relation.schema, source.schema)
+        val referencedSourceSchema = MergeIntoTable.sourceSchemaForSchemaEvolution(m)
+
+        val changes = MergeIntoTable.schemaChanges(relation.schema, referencedSourceSchema)
         c.alterTable(i, changes: _*)
         val newTable = c.loadTable(i)
         val newSchema = CatalogV2Util.v2ColumnsToStructType(newTable.columns())
         // Check if there are any remaining changes not applied.
-        val remainingChanges = MergeIntoTable.schemaChanges(newSchema, source.schema)
+        val remainingChanges = MergeIntoTable.schemaChanges(newSchema, referencedSourceSchema)
         if (remainingChanges.nonEmpty) {
           throw QueryCompilationErrors.unsupportedTableChangesInAutoSchemaEvolutionError(
             remainingChanges, i.toQualifiedNameParts(c))
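
The practical effect of passing the whole MergeIntoTable node to performSchemaEvolution is that evolution is now driven only by the part of the source schema the MERGE statement actually references. A minimal sketch of the intended behavior, assuming a hypothetical target table (id INT, name STRING) and a source table (id INT, name STRING, extra STRING), using the WITH SCHEMA EVOLUTION clause that sets withSchemaEvolution:

// Hypothetical example; table names and schemas are illustrative.
spark.sql("""
  MERGE WITH SCHEMA EVOLUTION INTO target t
  USING source s
  ON t.id = s.id
  WHEN MATCHED THEN UPDATE SET name = s.name
""")
// Before this change, schemaChanges compared the target schema against the
// full source schema, so the unreferenced source column `extra` would have
// been added to the target. With the pruned schema, only directly assigned
// columns (here `name`) are considered, and no table change is issued.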
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.plans.logical
 
 import org.apache.spark.{SparkIllegalArgumentException, SparkUnsupportedOperationException}
 import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.analysis.{AnalysisContext, AssignmentUtils, EliminateSubqueryAliases, FieldName, NamedRelation, PartitionSpec, ResolvedIdentifier, ResolvedProcedure, TypeCheckResult, UnresolvedException, UnresolvedProcedure, ViewSchemaMode}
+import org.apache.spark.sql.catalyst.analysis.{AnalysisContext, AssignmentUtils, EliminateSubqueryAliases, FieldName, NamedRelation, PartitionSpec, ResolvedIdentifier, ResolvedProcedure, TypeCheckResult, UnresolvedAttribute, UnresolvedException, UnresolvedProcedure, ViewSchemaMode}
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, TypeCheckSuccess}
 import org.apache.spark.sql.catalyst.catalog.{FunctionResource, RoutineLanguage}
 import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec

@@ -860,7 +860,8 @@ case class MergeIntoTable(
     matchedActions: Seq[MergeAction],
     notMatchedActions: Seq[MergeAction],
     notMatchedBySourceActions: Seq[MergeAction],
-    withSchemaEvolution: Boolean) extends BinaryCommand with SupportsSubquery {
+    withSchemaEvolution: Boolean)
+  extends BinaryCommand with SupportsSubquery {
Review comment (Contributor): unnecessary change
 
   lazy val aligned: Boolean = {
     val actions = matchedActions ++ notMatchedActions ++ notMatchedBySourceActions

@@ -892,9 +893,13 @@ case class MergeIntoTable(
     case _ => false
   }
 
-  lazy val needSchemaEvolution: Boolean =
+  private lazy val sourceSchemaForEvolution: StructType =
+    MergeIntoTable.sourceSchemaForSchemaEvolution(this)
+
+  lazy val needSchemaEvolution: Boolean = {
     schemaEvolutionEnabled &&
-    MergeIntoTable.schemaChanges(targetTable.schema, sourceTable.schema).nonEmpty
+      MergeIntoTable.schemaChanges(targetTable.schema, sourceSchemaForEvolution).nonEmpty
+  }
 
   private def schemaEvolutionEnabled: Boolean = withSchemaEvolution && {
     EliminateSubqueryAliases(targetTable) match {
@@ -911,6 +916,7 @@ case class MergeIntoTable(
   }
 
 object MergeIntoTable {
+
   def getWritePrivileges(
       matchedActions: Iterable[MergeAction],
       notMatchedActions: Iterable[MergeAction],

@@ -990,6 +996,79 @@ object MergeIntoTable {
       CaseInsensitiveMap(fieldMap)
     }
   }
+
+  // A pruned version of the source schema that only contains columns/nested fields
+  // explicitly and directly assigned to a target counterpart in MERGE INTO actions.
+  // New columns/nested fields that do not exist in the target will be added for
+  // schema evolution.
+  def sourceSchemaForSchemaEvolution(merge: MergeIntoTable): StructType = {
+
+    val actions = merge.matchedActions ++ merge.notMatchedActions
+    val assignments = actions.collect {
+      case a: UpdateAction => a.assignments
+      case a: InsertAction => a.assignments
+    }.flatten
+
+    val containsStarAction = actions.exists {
+      case _: UpdateStarAction => true
+      case _: InsertStarAction => true
+      case _ => false
+    }
+
+    def filterSchema(sourceSchema: StructType, basePath: Seq[String]): StructType =
+      StructType(sourceSchema.flatMap { field =>
+        val fieldPath = basePath :+ field.name
+
+        field.dataType match {
+          // Specifically assigned to in one clause:
+          // always keep, including all nested attributes.
+          case _ if assignments.exists(isEqual(_, fieldPath)) => Some(field)
+          // If this is a struct and one of its children is assigned to in a merge
+          // clause, keep it and continue filtering children.
+          case struct: StructType if assignments.exists(assign =>
+              isPrefix(fieldPath, extractFieldPath(assign.key))) =>
+            Some(field.copy(dataType = filterSchema(struct, fieldPath)))
+          // The field isn't assigned to directly or indirectly (i.e. via its
+          // children) in any non-* clause. Check if it should be kept due to a
+          // * action.
+          case struct: StructType if containsStarAction =>
+            Some(field.copy(dataType = filterSchema(struct, fieldPath)))
+          case _ if containsStarAction => Some(field)
+          // The field and its children are not assigned to in any * or non-*
+          // action, so drop it.
+          case _ => None
+        }
+      })
+
+    filterSchema(merge.sourceTable.schema, Seq.empty)
+  }
+
+  // Helper method to extract a field path from an expression.
+  private def extractFieldPath(expr: Expression): Seq[String] = expr match {
+    case UnresolvedAttribute(nameParts) => nameParts
+    case a: AttributeReference => Seq(a.name)
+    case GetStructField(child, ordinal, nameOpt) =>
+      extractFieldPath(child) :+ nameOpt.getOrElse(s"col$ordinal")
+    case _ => Seq.empty
+  }
+
+  // Helper method to check if a given field path is a prefix of another path.
+  private def isPrefix(prefix: Seq[String], path: Seq[String]): Boolean =
+    prefix.length <= path.length && prefix.zip(path).forall {
+      case (prefixNamePart, pathNamePart) =>
+        SQLConf.get.resolver(prefixNamePart, pathNamePart)
+    }
+
+  // Helper method to check if a given field path is a suffix of another path.
+  private def isSuffix(suffix: Seq[String], path: Seq[String]): Boolean =
+    isPrefix(suffix.reverse, path.reverse)
+
+  // Helper method to check if an assignment key is equal to a source column path
+  // and if the assignment value refers to the corresponding source column directly.
+  private def isEqual(assignment: Assignment, path: Seq[String]): Boolean = {
+    val assignmentKeyExpr = extractFieldPath(assignment.key)
+    val assignmentValueExpr = extractFieldPath(assignment.value)
+    // Valid assignments are: col = s.col or col.nestedField = s.col.nestedField
+    assignmentKeyExpr.length == path.length && isPrefix(assignmentKeyExpr, path) &&
+      isSuffix(path, assignmentValueExpr)
+  }
 }
 
 sealed abstract class MergeAction extends Expression with Unevaluable {
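
To make the pruning concrete, here is a rough walkthrough of what sourceSchemaForSchemaEvolution is expected to produce; the schema and clause are illustrative, not taken from the patch:

// Assumed source schema:
//   STRUCT<id: INT, profile: STRUCT<age: INT, city: STRING>, extra: STRING>
// with a single clause: WHEN MATCHED THEN UPDATE SET profile.age = s.profile.age
//
// filterSchema visits each source field:
//   id      -> not assigned, no star action                  -> dropped
//   profile -> a child (profile.age) is assigned, so recurse:
//     profile.age  -> isEqual matches the assignment         -> kept
//     profile.city -> not assigned, no star action           -> dropped
//   extra   -> not assigned, no star action                  -> dropped
//
// Pruned result: STRUCT<profile: STRUCT<age: INT>>. Adding a
// WHEN NOT MATCHED THEN INSERT * clause would instead keep every source
// field, since star actions retain fields that are not explicitly assigned.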
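
The path helpers encode the "col = s.col" pattern noted in isEqual. A short sketch of their expected behavior on hypothetical inputs:

// Key a.b and value s.a.b give extractFieldPath results
//   Seq("a", "b") and Seq("s", "a", "b") respectively.
// Checked against the source field path Seq("a", "b"):
//   the key length matches, isPrefix(Seq("a", "b"), Seq("a", "b")) holds, and
//   isSuffix(Seq("a", "b"), Seq("s", "a", "b")) holds; name comparison goes
//   through SQLConf.get.resolver, so it is case-insensitive by default.
// => isEqual is true and the field is kept.
//
// For a computed value such as a.b = s.a.b + 1, extractFieldPath falls through
// to Seq.empty, the suffix check fails, and the field is not treated as a
// direct source copy for schema evolution.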