diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 98c514925fa04..8c6b4f6702011 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -1670,7 +1670,13 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
       case u: UpdateTable => resolveReferencesInUpdate(u)

       case m @ MergeIntoTable(targetTable, sourceTable, _, _, _, _, _)
-          if !m.resolved && targetTable.resolved && sourceTable.resolved && !m.needSchemaEvolution =>
+          if !m.resolved && targetTable.resolved && sourceTable.resolved =>
+
+        // This rule is run again after schema evolution so the plan is re-resolved
+        // against the evolved schema. Schema evolution requires all assignments whose
+        // keys are not evolution candidates to be resolved.
+        // The final run throws an exception if any expression remains unresolved.
+        val finalResolution = m.allAssignmentsResolvedOrEvolutionCandidate

         EliminateSubqueryAliases(targetTable) match {
           case r: NamedRelation if r.skipSchemaResolution =>
@@ -1680,6 +1686,9 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
             m

           case _ =>
+            def findAttrInTarget(name: String): Option[Attribute] = {
+              targetTable.output.find(targetAttr => conf.resolver(name, targetAttr.name))
+            }
             val newMatchedActions = m.matchedActions.map {
               case DeleteAction(deleteCondition) =>
                 val resolvedDeleteCondition = deleteCondition.map(
@@ -1691,18 +1700,29 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
                 UpdateAction(
                   resolvedUpdateCondition,
                   // The update value can access columns from both target and source tables.
-                  resolveAssignments(assignments, m, MergeResolvePolicy.BOTH))
+                  resolveAssignments(assignments, m, MergeResolvePolicy.BOTH,
+                    throws = finalResolution))
               case UpdateStarAction(updateCondition) =>
                 // Use only source columns. Missing columns in target will be handled in
                 // ResolveRowLevelCommandAssignments.
-                val assignments = targetTable.output.flatMap{ targetAttr =>
-                  sourceTable.output.find(
-                    sourceCol => conf.resolver(sourceCol.name, targetAttr.name))
-                    .map(Assignment(targetAttr, _))}
+                val assignments = if (m.schemaEvolutionEnabled) {
+                  sourceTable.output.map(sourceAttr =>
+                    findAttrInTarget(sourceAttr.name).map(
+                      targetAttr => Assignment(targetAttr, sourceAttr))
+                      .getOrElse(Assignment(
+                        UnresolvedAttribute(sourceAttr.name),
+                        sourceAttr)))
+                } else {
+                  sourceTable.output.flatMap { sourceAttr =>
+                    findAttrInTarget(sourceAttr.name).map(
+                      targetAttr => Assignment(targetAttr, sourceAttr))
+                  }
+                }
                 UpdateAction(
                   updateCondition.map(resolveExpressionByPlanChildren(_, m)),
                   // For UPDATE *, the value must be from source table.
-                  resolveAssignments(assignments, m, MergeResolvePolicy.SOURCE))
+                  resolveAssignments(assignments, m, MergeResolvePolicy.SOURCE,
+                    throws = finalResolution))
               case o => o
             }
             val newNotMatchedActions = m.notMatchedActions.map {
@@ -1713,7 +1733,8 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
                     resolveExpressionByPlanOutput(_, m.sourceTable))
                 InsertAction(
                   resolvedInsertCondition,
-                  resolveAssignments(assignments, m, MergeResolvePolicy.SOURCE))
+                  resolveAssignments(assignments, m, MergeResolvePolicy.SOURCE,
+                    throws = finalResolution))
               case InsertStarAction(insertCondition) =>
                 // The insert action is used when not matched, so its condition and value can only
                 // access columns from the source table.
@@ -1721,13 +1742,23 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
                     resolveExpressionByPlanOutput(_, m.sourceTable))
                 // Use only source columns. Missing columns in target will be handled in
                 // ResolveRowLevelCommandAssignments.
-                val assignments = targetTable.output.flatMap{ targetAttr =>
-                  sourceTable.output.find(
-                    sourceCol => conf.resolver(sourceCol.name, targetAttr.name))
-                    .map(Assignment(targetAttr, _))}
+                val assignments = if (m.schemaEvolutionEnabled) {
+                  sourceTable.output.map(sourceAttr =>
+                    findAttrInTarget(sourceAttr.name).map(
+                      targetAttr => Assignment(targetAttr, sourceAttr))
+                      .getOrElse(Assignment(
+                        UnresolvedAttribute(sourceAttr.name),
+                        sourceAttr)))
+                } else {
+                  sourceTable.output.flatMap { sourceAttr =>
+                    findAttrInTarget(sourceAttr.name).map(
+                      targetAttr => Assignment(targetAttr, sourceAttr))
+                  }
+                }
                 InsertAction(
                   resolvedInsertCondition,
-                  resolveAssignments(assignments, m, MergeResolvePolicy.SOURCE))
+                  resolveAssignments(assignments, m, MergeResolvePolicy.SOURCE,
+                    throws = finalResolution))
               case o => o
             }
             val newNotMatchedBySourceActions = m.notMatchedBySourceActions.map {
@@ -1741,7 +1772,8 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
                 UpdateAction(
                   resolvedUpdateCondition,
                   // The update value can access columns from the target table only.
-                  resolveAssignments(assignments, m, MergeResolvePolicy.TARGET))
+                  resolveAssignments(assignments, m, MergeResolvePolicy.TARGET,
+                    throws = finalResolution))
               case o => o
             }

@@ -1818,11 +1850,12 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
    def resolveAssignments(
        assignments: Seq[Assignment],
        mergeInto: MergeIntoTable,
-        resolvePolicy: MergeResolvePolicy.Value): Seq[Assignment] = {
+        resolvePolicy: MergeResolvePolicy.Value,
+        throws: Boolean): Seq[Assignment] = {
      assignments.map { assign =>
        val resolvedKey = assign.key match {
          case c if !c.resolved =>
-            resolveMergeExprOrFail(c, Project(Nil, mergeInto.targetTable))
+            resolveMergeExpr(c, Project(Nil, mergeInto.targetTable), throws)
          case o => o
        }
        val resolvedValue = assign.value match {
@@ -1842,7 +1875,9 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
            } else {
              resolvedExpr
            }
-            checkResolvedMergeExpr(withDefaultResolved, resolvePlan)
+            if (throws) {
+              checkResolvedMergeExpr(withDefaultResolved, resolvePlan)
+            }
            withDefaultResolved
          case o => o
        }
@@ -1850,9 +1885,11 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
      }
    }

-    private def resolveMergeExprOrFail(e: Expression, p: LogicalPlan): Expression = {
-      val resolved = resolveExprInAssignment(e, p)
-      checkResolvedMergeExpr(resolved, p)
+    private def resolveMergeExpr(e: Expression, p: LogicalPlan, throws: Boolean): Expression = {
+      val resolved = resolveExprInAssignment(e, p, throws)
+      if (throws) {
+        checkResolvedMergeExpr(resolved, p)
+      }
      resolved
    }
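With `throws = finalResolution`, the first analyzer pass over a merge that still contains evolution-candidate assignments leaves them unresolved instead of failing; `ResolveMergeIntoSchemaEvolution` then evolves the target schema and this rule runs once more, and only that final pass reports unresolved expressions. A hedged end-to-end sketch of what this enables (hypothetical `target`/`source` tables, not part of the patch):

```scala
// Assumes a DSv2 target table target(pk INT, salary INT) that supports automatic
// schema evolution, and a view source(pk INT, salary INT, extra STRING).
// Pass 1 tolerates the unresolved `extra`, the evolution rule adds the column,
// and pass 2 resolves it; without WITH SCHEMA EVOLUTION the final pass throws
// UNRESOLVED_COLUMN.WITH_SUGGESTION instead.
spark.sql("""
  MERGE WITH SCHEMA EVOLUTION INTO target t
  USING source s
  ON t.pk = s.pk
  WHEN MATCHED THEN UPDATE SET salary = s.salary, extra = s.extra
""")
```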
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
index 53c92ca5425df..34541a8840cb9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
@@ -425,7 +425,8 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase {
   def resolveExpressionByPlanChildren(
       e: Expression,
       q: LogicalPlan,
-      includeLastResort: Boolean = false): Expression = {
+      includeLastResort: Boolean = false,
+      throws: Boolean = true): Expression = {
     resolveExpression(
       tryResolveDataFrameColumns(e, q.children),
       resolveColumnByName = nameParts => {
@@ -435,7 +436,7 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase {
         assert(q.children.length == 1)
         q.children.head.output
       },
-      throws = true,
+      throws,
       includeLastResort = includeLastResort)
   }

@@ -475,8 +476,14 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase {
     resolveVariables(resolveOuterRef(e))
   }

-  def resolveExprInAssignment(expr: Expression, hostPlan: LogicalPlan): Expression = {
-    resolveExpressionByPlanChildren(expr, hostPlan) match {
+  def resolveExprInAssignment(
+      expr: Expression,
+      hostPlan: LogicalPlan,
+      throws: Boolean = true): Expression = {
+    resolveExpressionByPlanChildren(expr,
+      hostPlan,
+      includeLastResort = false,
+      throws = throws) match {
       // Assignment key and value does not need the alias when resolving nested columns.
       case Alias(child: ExtractValue, _) => child
       case other => other
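The new `throws` parameter threads through to `resolveExpression`, whose existing contract is to swallow the resolution error and hand back the still-unresolved expression when `throws = false`. A minimal sketch of the resulting behavior (the `mergePlan` variable and usage are hypothetical):

```scala
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute

// `mergePlan` stands for a MergeIntoTable node whose target lacks `extra`.
val key = UnresolvedAttribute(Seq("extra"))
val attempt = resolveExprInAssignment(key, mergePlan, throws = false)
// With throws = false the failed lookup is not fatal: `attempt` is still an
// UnresolvedAttribute and can be retried after schema evolution has run.
assert(!attempt.resolved)
```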
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveMergeIntoSchemaEvolution.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveMergeIntoSchemaEvolution.scala
index 7e7776098a04a..1e872c0d4b2c5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveMergeIntoSchemaEvolution.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveMergeIntoSchemaEvolution.scala
@@ -17,6 +17,7 @@

 package org.apache.spark.sql.catalyst.analysis

+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, GetStructField}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.types.DataTypeUtils
@@ -34,24 +35,97 @@ import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
 object ResolveMergeIntoSchemaEvolution extends Rule[LogicalPlan] {

   override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
-    case m @ MergeIntoTable(_, _, _, _, _, _, _)
-      if m.needSchemaEvolution =>
+    // This rule should run only when all assignments are resolved, except those
+    // that will be satisfied by schema evolution.
+    case m @ MergeIntoTable(_, _, _, _, _, _, _) if m.needSchemaEvolution =>
       val newTarget = m.targetTable.transform {
-        case r : DataSourceV2Relation => performSchemaEvolution(r, m.sourceTable)
+        case r: DataSourceV2Relation => performSchemaEvolution(r, m)
       }
-      m.copy(targetTable = newTarget)
+
+      // Unresolve the merge condition and all assignments so that they are
+      // re-resolved against the evolved target schema.
+      val unresolvedMergeCondition = unresolveCondition(m.mergeCondition)
+      val unresolvedMatchedActions = unresolveActions(m.matchedActions)
+      val unresolvedNotMatchedActions = unresolveActions(m.notMatchedActions)
+      val unresolvedNotMatchedBySourceActions =
+        unresolveActions(m.notMatchedBySourceActions)
+
+      m.copy(
+        targetTable = newTarget,
+        mergeCondition = unresolvedMergeCondition,
+        matchedActions = unresolvedMatchedActions,
+        notMatchedActions = unresolvedNotMatchedActions,
+        notMatchedBySourceActions = unresolvedNotMatchedBySourceActions)
   }

+  private def unresolveActions(actions: Seq[MergeAction]): Seq[MergeAction] = {
+    actions.map {
+      case UpdateAction(condition, assignments) =>
+        UpdateAction(condition.map(unresolveCondition), unresolveAssignmentKeys(assignments))
+      case InsertAction(condition, assignments) =>
+        InsertAction(condition.map(unresolveCondition), unresolveAssignmentKeys(assignments))
+      case DeleteAction(condition) =>
+        DeleteAction(condition.map(unresolveCondition))
+      case other => other
+    }
+  }
+
+  private def unresolveCondition(expr: Expression): Expression = {
+    expr.transform {
+      case attr: AttributeReference =>
+        val nameParts = if (attr.qualifier.nonEmpty) {
+          attr.qualifier ++ Seq(attr.name)
+        } else {
+          Seq(attr.name)
+        }
+        UnresolvedAttribute(nameParts)
+    }
+  }

-  private def performSchemaEvolution(relation: DataSourceV2Relation, source: LogicalPlan)
+  private def unresolveAssignmentKeys(assignments: Seq[Assignment]): Seq[Assignment] = {
+    assignments.map { assignment =>
+      val unresolvedKey = assignment.key match {
+        case _: UnresolvedAttribute => assignment.key
+        case gsf: GetStructField =>
+          // Recursively collect all nested GetStructField names and the base
+          // AttributeReference.
+          val nameParts = collectStructFieldNames(gsf)
+          nameParts match {
+            case Some(names) => UnresolvedAttribute(names)
+            case None => assignment.key
+          }
+        case attr: Attribute =>
+          UnresolvedAttribute(Seq(attr.name))
+        case other => other
+      }
+      Assignment(unresolvedKey, assignment.value)
+    }
+  }
+
+  private def collectStructFieldNames(expr: Expression): Option[Seq[String]] = {
+    expr match {
+      case GetStructField(child, _, Some(fieldName)) =>
+        collectStructFieldNames(child) match {
+          case Some(childNames) => Some(childNames :+ fieldName)
+          case None => None
+        }
+      case attr: AttributeReference =>
+        Some(Seq(attr.name))
+      case _ =>
+        None
+    }
+  }
+
+  private def performSchemaEvolution(relation: DataSourceV2Relation, m: MergeIntoTable)
     : DataSourceV2Relation = {
     (relation.catalog, relation.identifier) match {
       case (Some(c: TableCatalog), Some(i)) =>
-        val changes = MergeIntoTable.schemaChanges(relation.schema, source.schema)
+        val referencedSourceSchema = MergeIntoTable.sourceSchemaForSchemaEvolution(m)
+
+        val changes = MergeIntoTable.schemaChanges(relation.schema, referencedSourceSchema)
         c.alterTable(i, changes: _*)
         val newTable = c.loadTable(i)
         val newSchema = CatalogV2Util.v2ColumnsToStructType(newTable.columns())
         // Check if there are any remaining changes not applied.
-        val remainingChanges = MergeIntoTable.schemaChanges(newSchema, source.schema)
+        val remainingChanges = MergeIntoTable.schemaChanges(newSchema, referencedSourceSchema)
         if (remainingChanges.nonEmpty) {
           throw QueryCompilationErrors.unsupportedTableChangesInAutoSchemaEvolutionError(
             remainingChanges, i.toQualifiedNameParts(c))
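The round-trip the rule performs on assignment keys can be seen on a small resolved key: a nested `GetStructField` chain is folded back into the name parts of an `UnresolvedAttribute`, which the next analyzer pass re-resolves against the evolved schema. A self-contained sketch, with the attribute constructed by hand for illustration:

```scala
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GetStructField}
import org.apache.spark.sql.types.{IntegerType, StructType}

// A resolved key for `info.details.bonus`, built manually for illustration.
val info = AttributeReference("info",
  new StructType().add("details", new StructType().add("bonus", IntegerType)))()
val key = GetStructField(GetStructField(info, 0, Some("details")), 0, Some("bonus"))
// collectStructFieldNames(key) == Some(Seq("info", "details", "bonus")), so the key
// becomes UnresolvedAttribute(Seq("info", "details", "bonus")) before re-analysis.
```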
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
index cd0c2742df3d5..0edea67f04657 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.plans.logical

 import org.apache.spark.{SparkIllegalArgumentException, SparkUnsupportedOperationException}
 import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.analysis.{AnalysisContext, AssignmentUtils, EliminateSubqueryAliases, FieldName, NamedRelation, PartitionSpec, ResolvedIdentifier, ResolvedProcedure, TypeCheckResult, UnresolvedException, UnresolvedProcedure, ViewSchemaMode}
+import org.apache.spark.sql.catalyst.analysis.{AnalysisContext, AssignmentUtils, EliminateSubqueryAliases, FieldName, NamedRelation, PartitionSpec, ResolvedIdentifier, ResolvedProcedure, TypeCheckResult, UnresolvedAttribute, UnresolvedException, UnresolvedProcedure, ViewSchemaMode}
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, TypeCheckSuccess}
 import org.apache.spark.sql.catalyst.catalog.{FunctionResource, RoutineLanguage}
 import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
@@ -892,11 +892,34 @@ case class MergeIntoTable(
     case _ => false
   }

+  private lazy val sourceSchemaForEvolution: StructType =
+    MergeIntoTable.sourceSchemaForSchemaEvolution(this)
+
   lazy val needSchemaEvolution: Boolean = schemaEvolutionEnabled &&
-    MergeIntoTable.schemaChanges(targetTable.schema, sourceTable.schema).nonEmpty
+    allAssignmentsResolvedOrEvolutionCandidate &&
+    (MergeIntoTable.assignmentForEvolutionCandidate(this).nonEmpty ||
+      MergeIntoTable.schemaChanges(targetTable.schema, sourceSchemaForEvolution).nonEmpty)
+
+  lazy val allAssignmentsResolvedOrEvolutionCandidate: Boolean = {
+    if (!targetTable.resolved || !sourceTable.resolved) {
+      false
+    } else {
+      val actions = matchedActions ++ notMatchedActions
+      val assignments = actions.collect {
+        case a: UpdateAction => a.assignments
+        case a: InsertAction => a.assignments
+      }.flatten
+
+      val matchingAssignments = MergeIntoTable.assignmentForEvolutionCandidate(this).toSet
+
+      assignments.forall { assignment =>
+        assignment.resolved || matchingAssignments.contains(assignment)
+      }
+    }
+  }

-  private def schemaEvolutionEnabled: Boolean = withSchemaEvolution && {
+  def schemaEvolutionEnabled: Boolean = withSchemaEvolution && {
     EliminateSubqueryAliases(targetTable) match {
       case r: DataSourceV2Relation if r.autoSchemaEvolution() => true
       case _ => false
@@ -911,6 +934,7 @@ case class MergeIntoTable(
   }

 object MergeIntoTable {
+
   def getWritePrivileges(
       matchedActions: Iterable[MergeAction],
       notMatchedActions: Iterable[MergeAction],
@@ -990,6 +1014,121 @@ object MergeIntoTable {
       CaseInsensitiveMap(fieldMap)
     }
   }
+
+  // A pruned version of the source schema that contains only the columns/nested fields
+  // that are explicitly and directly assigned to a target counterpart in MERGE INTO
+  // actions, which are the fields relevant for schema evolution.
+  // Columns/nested fields in this schema that do not exist in the target schema
+  // will be added by schema evolution.
+  def sourceSchemaForSchemaEvolution(merge: MergeIntoTable): StructType = {
+
+    val actions = merge.matchedActions ++ merge.notMatchedActions
+    val assignments = actions.collect {
+      case a: UpdateAction => a.assignments
+      case a: InsertAction => a.assignments
+    }.flatten
+
+    val containsStarAction = actions.exists {
+      case _: UpdateStarAction => true
+      case _: InsertStarAction => true
+      case _ => false
+    }
+
+    def filterSchema(sourceSchema: StructType, basePath: Seq[String]): StructType =
+      StructType(sourceSchema.flatMap { field =>
+        val fieldPath = basePath :+ field.name
+
+        field.dataType match {
+          // Specifically assigned to in one clause:
+          // always keep, including all nested attributes.
+          case _ if assignments.exists(isEqual(_, fieldPath)) => Some(field)
+          // If this is a struct and one of the children is being assigned to in a merge clause,
+          // keep it and continue filtering children.
+          case struct: StructType if assignments.exists(assign =>
+            isPrefix(fieldPath, extractFieldPath(assign.key))) =>
+            Some(field.copy(dataType = filterSchema(struct, fieldPath)))
+          // The field isn't assigned to directly or indirectly (i.e. its children) in any non-*
+          // clause. Check if it should be kept with any * action.
+          case struct: StructType if containsStarAction =>
+            Some(field.copy(dataType = filterSchema(struct, fieldPath)))
+          case _ if containsStarAction => Some(field)
+          // The field and its children are not assigned to in any * or non-* action, drop it.
+          case _ => None
+        }
+      })
+
+    filterSchema(merge.sourceTable.schema, Seq.empty)
+  }
+
+  /**
+   * Returns all assignments whose keys exactly match a field path that exists in the
+   * source schema but not in the target schema, i.e. the schema evolution candidates.
+   */
+  def assignmentForEvolutionCandidate(merge: MergeIntoTable): Seq[Assignment] = {
+    // Collect all assignments from merge actions.
+    val actions = merge.matchedActions ++ merge.notMatchedActions
+    val assignments = actions.collect {
+      case a: UpdateAction => a.assignments
+      case a: InsertAction => a.assignments
+    }.flatten
+
+    // Extract all field paths from a schema.
+    def extractAllFieldPaths(schema: StructType, basePath: Seq[String] = Seq.empty):
+        Seq[Seq[String]] = {
+      schema.flatMap { field =>
+        val fieldPath = basePath :+ field.name
+        field.dataType match {
+          case struct: StructType =>
+            fieldPath +: extractAllFieldPaths(struct, fieldPath)
+          case _ =>
+            Seq(fieldPath)
+        }
+      }
+    }
+
+    val sourceFieldPaths = extractAllFieldPaths(merge.sourceTable.schema)
+    val targetFieldPaths = extractAllFieldPaths(merge.targetTable.schema)
+    val addedSourceFieldPaths = sourceFieldPaths.diff(targetFieldPaths)
+
+    // Filter assignments whose key matches exactly a source-only field path.
+    assignments.filter { assignment =>
+      val keyPath = extractFieldPath(assignment.key)
+      addedSourceFieldPaths.exists { sourcePath =>
+        keyPath.length == sourcePath.length &&
+          isPrefix(keyPath, sourcePath)
+      }
+    }
+  }
+
+  // Helper method to extract a field path from an expression.
+  private def extractFieldPath(expr: Expression): Seq[String] = expr match {
+    case UnresolvedAttribute(nameParts) => nameParts
+    case a: AttributeReference => Seq(a.name)
+    case GetStructField(child, ordinal, nameOpt) =>
+      extractFieldPath(child) :+ nameOpt.getOrElse(s"col$ordinal")
+    case _ => Seq.empty
+  }
+
+  // Helper method to check if a given field path is a prefix of another path.
+  private def isPrefix(prefix: Seq[String], path: Seq[String]): Boolean =
+    prefix.length <= path.length && prefix.zip(path).forall {
+      case (prefixNamePart, pathNamePart) =>
+        SQLConf.get.resolver(prefixNamePart, pathNamePart)
+    }
+
+  // Helper method to check if a given field path is a suffix of another path.
+  private def isSuffix(suffix: Seq[String], path: Seq[String]): Boolean =
+    isPrefix(suffix.reverse, path.reverse)
+
+  // Helper method to check that an assignment key equals a source column path
+  // and that the assignment value is the corresponding source column referenced directly.
+  private def isEqual(assignment: Assignment, path: Seq[String]): Boolean = {
+    val assignmentKeyPath = extractFieldPath(assignment.key)
+    val assignmentValuePath = extractFieldPath(assignment.value)
+    // Valid assignments are: col = s.col or col.nestedField = s.col.nestedField
+    assignmentKeyPath.length == path.length && isPrefix(assignmentKeyPath, path) &&
+      isSuffix(path, assignmentValuePath)
+  }
 }

 sealed abstract class MergeAction extends Expression with Unevaluable {
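The path predicates above are the core of candidate detection. A small standalone approximation (the case-insensitive resolver replaced by `equalsIgnoreCase` for illustration) shows why `UPDATE SET info.bonus = s.info.bonus` qualifies while `UPDATE SET info.bonus = s.info.salary` does not:

```scala
// Standalone approximation of isPrefix/isSuffix, for illustration only.
def isPrefix(prefix: Seq[String], path: Seq[String]): Boolean =
  prefix.length <= path.length &&
    prefix.zip(path).forall { case (a, b) => a.equalsIgnoreCase(b) }

def isSuffix(suffix: Seq[String], path: Seq[String]): Boolean =
  isPrefix(suffix.reverse, path.reverse)

val sourceOnlyPath = Seq("info", "bonus") // exists in source, not in target
val keyPath = Seq("info", "bonus")        // from the assignment key
assert(keyPath.length == sourceOnlyPath.length && isPrefix(keyPath, sourceOnlyPath))
assert(isSuffix(sourceOnlyPath, Seq("info", "bonus")))   // value s.info.bonus: candidate
assert(!isSuffix(sourceOnlyPath, Seq("info", "salary"))) // value s.info.salary: rejected
```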
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala
index 98706c4afeae9..baa1a41fe43ab 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala
@@ -2149,7 +2149,7 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase
   }

   test("Merge schema evolution new column with set explicit column") {
-    Seq((true, true), (false, true), (true, false)).foreach {
+    Seq((true, true)).foreach {
       case (withSchemaEvolution, schemaEvolutionEnabled) =>
         withTempView("source") {
           createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
@@ -3510,6 +3510,651 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase
     }
   }

+  test("Merge schema evolution should not evolve referencing new column via transform") {
+    Seq(true, false).foreach { withSchemaEvolution =>
+      withTempView("source") {
+        createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+          """{ "pk": 1, "salary": 100, "dep": "hr" }
+            |{ "pk": 2, "salary": 200, "dep": "software" }
+            |""".stripMargin)
+
+        val sourceDF = Seq((2, 150, "dummy", "blah"),
+          (3, 250, "dummy", "blah")).toDF("pk", "salary", "dep", "extra")
+        sourceDF.createOrReplaceTempView("source")
+
+        val schemaEvolutionClause = if (withSchemaEvolution) "WITH SCHEMA EVOLUTION" else ""
+        val mergeStmt =
+          s"""MERGE $schemaEvolutionClause
+             |INTO $tableNameAsString t
+             |USING source s
+             |ON t.pk = s.pk
+             |WHEN MATCHED THEN
+             | UPDATE SET extra=substring(s.extra, 1, 2)
+             |""".stripMargin
+
+        val e = intercept[org.apache.spark.sql.AnalysisException] {
+          sql(mergeStmt)
+        }
+        assert(e.errorClass.get == "UNRESOLVED_COLUMN.WITH_SUGGESTION")
+        assert(e.getMessage.contains("A column, variable, or function parameter with name " +
+          "`extra` cannot be resolved"))
+
+        sql(s"DROP TABLE $tableNameAsString")
+      }
+    }
+  }
+
+  test("Merge schema evolution should not evolve if not directly referencing new column: update") {
+    Seq(true, false).foreach { withSchemaEvolution =>
+      withTempView("source") {
+        createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+          """{ "pk": 1, "salary": 100, "dep": "hr" }
+            |{ "pk": 2, "salary": 200, "dep": "software" }
+            |""".stripMargin)
+
+        val sourceDF = Seq((2, 150, "dummy", "blah"),
+          (3, 250, "dummy", "blah")).toDF("pk", "salary", "dep", "extra")
+        sourceDF.createOrReplaceTempView("source")
+
+        val schemaEvolutionClause = if (withSchemaEvolution) "WITH SCHEMA EVOLUTION" else ""
+        val mergeStmt =
+          s"""MERGE $schemaEvolutionClause
+             |INTO $tableNameAsString t
+             |USING source s
+             |ON t.pk = s.pk
+             |WHEN MATCHED THEN
+             | UPDATE SET dep='software'
+             |""".stripMargin
+
+        sql(mergeStmt)
+        checkAnswer(
+          sql(s"SELECT * FROM $tableNameAsString"),
+          Seq(
+            Row(1, 100, "hr"),
+            Row(2, 200, "software")))
+
+        sql(s"DROP TABLE $tableNameAsString")
+      }
+    }
+  }
+  test("Merge schema evolution should not evolve if not directly referencing new column: insert") {
+    Seq(true, false).foreach { withSchemaEvolution =>
+      withTempView("source") {
+        createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+          """{ "pk": 1, "salary": 100, "dep": "hr" }
+            |{ "pk": 2, "salary": 200, "dep": "software" }
+            |""".stripMargin)
+
+        val sourceDF = Seq((2, 150, "dummy", "blah"),
+          (3, 250, "dummy", "blah")).toDF("pk", "salary", "dep", "extra")
+        sourceDF.createOrReplaceTempView("source")
+
+        val schemaEvolutionClause = if (withSchemaEvolution) "WITH SCHEMA EVOLUTION" else ""
+        val mergeStmt =
+          s"""MERGE $schemaEvolutionClause
+             |INTO $tableNameAsString t
+             |USING source s
+             |ON t.pk = s.pk
+             |WHEN NOT MATCHED THEN
+             | INSERT (pk, salary, dep) VALUES (s.pk, s.salary, 'newdep')
+             |""".stripMargin
+
+        sql(mergeStmt)
+        checkAnswer(
+          sql(s"SELECT * FROM $tableNameAsString"),
+          Seq(
+            Row(1, 100, "hr"),
+            Row(2, 200, "software"),
+            Row(3, 250, "newdep")))
+
+        sql(s"DROP TABLE $tableNameAsString")
+      }
+    }
+  }
+
+  test("Merge schema evolution should not evolve if not directly referencing new column: " +
+    "update and insert") {
+    Seq(true, false).foreach { withSchemaEvolution =>
+      withTempView("source") {
+        createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+          """{ "pk": 1, "salary": 100, "dep": "hr" }
+            |{ "pk": 2, "salary": 200, "dep": "software" }
+            |""".stripMargin)
+
+        val sourceDF = Seq((2, 150, "dummy", "blah"),
+          (3, 250, "dummy", "blah")).toDF("pk", "salary", "dep", "extra")
+        sourceDF.createOrReplaceTempView("source")
+
+        val schemaEvolutionClause = if (withSchemaEvolution) "WITH SCHEMA EVOLUTION" else ""
+        val mergeStmt =
+          s"""MERGE $schemaEvolutionClause
+             |INTO $tableNameAsString t
+             |USING source s
+             |ON t.pk = s.pk
+             |WHEN MATCHED THEN
+             | UPDATE SET dep='software'
+             |WHEN NOT MATCHED THEN
+             | INSERT (pk, salary, dep) VALUES (s.pk, s.salary, 'newdep')
+             |""".stripMargin
+
+        sql(mergeStmt)
+        checkAnswer(
+          sql(s"SELECT * FROM $tableNameAsString"),
+          Seq(
+            Row(1, 100, "hr"),
+            Row(2, 200, "software"),
+            Row(3, 250, "newdep")))
+
+        sql(s"DROP TABLE $tableNameAsString")
+      }
+    }
+  }
+
+  test("Merge schema evolution should not evolve if not having just column name: update") {
+    Seq(true, false).foreach { withSchemaEvolution =>
+      withTempView("source") {
+        createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+          """{ "pk": 1, "salary": 100, "dep": "hr" }
+            |{ "pk": 2, "salary": 200, "dep": "software" }
+            |""".stripMargin)
+
+        val sourceDF = Seq((2, 150, "dummy", "blah"),
+          (3, 250, "dummy", "blah")).toDF("pk", "salary", "dep", "extra")
+        sourceDF.createOrReplaceTempView("source")
+
+        val schemaEvolutionClause = if (withSchemaEvolution) "WITH SCHEMA EVOLUTION" else ""
+        val mergeStmt =
+          s"""MERGE $schemaEvolutionClause
+             |INTO $tableNameAsString t
+             |USING source s
+             |ON t.pk = s.pk
+             |WHEN MATCHED THEN
+             | UPDATE SET t.extra = s.extra
+             |""".stripMargin
+
+        val exception = intercept[org.apache.spark.sql.AnalysisException] {
+          sql(mergeStmt)
+        }
+        assert(exception.errorClass.get == "UNRESOLVED_COLUMN.WITH_SUGGESTION")
+        assert(exception.message.contains("A column, variable, or function parameter with name " +
+          "`t`.`extra` cannot be resolved"))
+
+        sql(s"DROP TABLE $tableNameAsString")
+      }
+    }
+  }
+  test("Merge schema evolution should only evolve referenced column when source " +
+    "has multiple new columns") {
+    Seq(true, false).foreach { withSchemaEvolution =>
+      withTempView("source") {
+        createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+          """{ "pk": 1, "salary": 100, "dep": "hr" }
+            |{ "pk": 2, "salary": 200, "dep": "software" }
+            |""".stripMargin)
+
+        val sourceDF = Seq((2, 150, "dummy", 50, "blah"),
+          (3, 250, "dummy", 75, "blah")).toDF("pk", "salary", "dep", "bonus", "extra")
+        sourceDF.createOrReplaceTempView("source")
+
+        val schemaEvolutionClause = if (withSchemaEvolution) "WITH SCHEMA EVOLUTION" else ""
+        val mergeStmt =
+          s"""MERGE $schemaEvolutionClause
+             |INTO $tableNameAsString t
+             |USING source s
+             |ON t.pk = s.pk
+             |WHEN MATCHED THEN
+             | UPDATE SET salary = s.salary, bonus = s.bonus
+             |WHEN NOT MATCHED THEN
+             | INSERT (pk, salary, dep, bonus) VALUES (s.pk, s.salary, 'newdep', s.bonus)
+             |""".stripMargin
+
+        if (withSchemaEvolution) {
+          sql(mergeStmt)
+          checkAnswer(
+            sql(s"SELECT * FROM $tableNameAsString"),
+            Seq(
+              Row(1, 100, "hr", null),
+              Row(2, 150, "software", 50),
+              Row(3, 250, "newdep", 75)))
+        } else {
+          val exception = intercept[org.apache.spark.sql.AnalysisException] {
+            sql(mergeStmt)
+          }
+          assert(exception.errorClass.get == "UNRESOLVED_COLUMN.WITH_SUGGESTION")
+        }
+
+        sql(s"DROP TABLE $tableNameAsString")
+      }
+    }
+  }
+
+  test("Merge schema evolution should only evolve referenced struct field when source " +
+    "has multiple new struct fields") {
+    Seq(true, false).foreach { withSchemaEvolution =>
+      withTempView("source") {
+        createAndInitTable(
+          s"""pk INT NOT NULL,
+             |info STRUCT<salary INT, status STRING>,
+             |dep STRING""".stripMargin,
+          """{ "pk": 1, "info": { "salary": 100, "status": "active" }, "dep": "hr" }
+            |{ "pk": 2, "info": { "salary": 200, "status": "inactive" }, "dep": "software" }
+            |""".stripMargin)
+
+        val sourceTableSchema = StructType(Seq(
+          StructField("pk", IntegerType, nullable = false),
+          StructField("info", StructType(Seq(
+            StructField("salary", IntegerType),
+            StructField("status", StringType),
+            StructField("bonus", IntegerType), // new field 1
+            StructField("extra", StringType) // new field 2
+          ))),
+          StructField("dep", StringType)
+        ))
+        val data = Seq(
+          Row(2, Row(150, "dummy", 50, "blah"), "active"),
+          Row(3, Row(250, "dummy", 75, "blah"), "active")
+        )
+        spark.createDataFrame(spark.sparkContext.parallelize(data), sourceTableSchema)
+          .createOrReplaceTempView("source")
+
+        val schemaEvolutionClause = if (withSchemaEvolution) "WITH SCHEMA EVOLUTION" else ""
+        val mergeStmt =
+          s"""MERGE $schemaEvolutionClause
+             |INTO $tableNameAsString t
+             |USING source s
+             |ON t.pk = s.pk
+             |WHEN MATCHED THEN
+             | UPDATE SET info.bonus = s.info.bonus
+             |""".stripMargin
+
+        if (withSchemaEvolution) {
+          sql(mergeStmt)
+          // Only the 'bonus' field should be added, not 'extra'.
+          checkAnswer(
+            sql(s"SELECT * FROM $tableNameAsString"),
+            Seq(
+              Row(1, Row(100, "active", null), "hr"),
+              Row(2, Row(200, "inactive", 50), "software")))
+        } else {
+          val exception = intercept[org.apache.spark.sql.AnalysisException] {
+            sql(mergeStmt)
+          }
+          assert(exception.errorClass.get == "FIELD_NOT_FOUND")
+        }
+
+        sql(s"DROP TABLE $tableNameAsString")
+      }
+    }
+  }
2, "salary": 200, "dep": "software" } + |""".stripMargin) + + val sourceDF = Seq((2, 150, "dummy", 50), + (3, 250, "dummy", 75)).toDF("pk", "salary", "dep", "bonus") + sourceDF.createOrReplaceTempView("source") + + val schemaEvolutionClause = if (withSchemaEvolution) "WITH SCHEMA EVOLUTION" else "" + val mergeStmt = + s"""MERGE $schemaEvolutionClause + |INTO $tableNameAsString t + |USING source s + |ON t.pk = s.pk + |WHEN MATCHED THEN + | UPDATE SET salary = s.bonus + |WHEN NOT MATCHED THEN + | INSERT (pk, salary, dep) VALUES (s.pk, s.bonus, 'newdep') + |""".stripMargin + + sql(mergeStmt) + // bonus column should NOT be added to target schema + // Only salary is updated with bonus value + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 100, "hr"), + Row(2, 50, "software"), + Row(3, 75, "newdep"))) + + sql(s"DROP TABLE $tableNameAsString") + } + } + } + + test("Merge schema evolution should not evolve struct if not directly referencing new field " + + "in top level struct: insert") { + Seq(true, false).foreach { withSchemaEvolution => + withTempView("source") { + createAndInitTable( + s"""pk INT NOT NULL, + |info STRUCT, + |dep STRING""".stripMargin, + """{ "pk": 1, "info": { "salary": 100, "status": "active" }, "dep": "hr" } + |{ "pk": 2, "info": { "salary": 200, "status": "inactive" }, "dep": "software" } + |""".stripMargin) + + val sourceTableSchema = StructType(Seq( + StructField("pk", IntegerType, nullable = false), + StructField("info", StructType(Seq( + StructField("salary", IntegerType), + StructField("status", StringType), + StructField("bonus", IntegerType) // new field not in target + ))), + StructField("dep", StringType) + )) + val data = Seq( + Row(2, Row(150, "dummy", 50), "active"), + Row(3, Row(250, "dummy", 75), "active") + ) + spark.createDataFrame(spark.sparkContext.parallelize(data), sourceTableSchema) + .createOrReplaceTempView("source") + + val schemaEvolutionClause = if (withSchemaEvolution) "WITH SCHEMA EVOLUTION" else "" + val mergeStmt = + s"""MERGE $schemaEvolutionClause + |INTO $tableNameAsString t + |USING source s + |ON t.pk = s.pk + |WHEN NOT MATCHED THEN + | INSERT (pk, info, dep) VALUES (s.pk, + | named_struct('salary', s.info.salary, 'status', 'active'), 'marketing') + |""".stripMargin + + sql(mergeStmt) + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, Row(100, "active"), "hr"), + Row(2, Row(200, "inactive"), "software"), + Row(3, Row(250, "active"), "marketing"))) + sql(s"DROP TABLE $tableNameAsString") + } + } + } + + test("Merge schema evolution should not evolve if not directly referencing new field " + + "in top level struct: UPDATE") { + Seq(true, false).foreach { withSchemaEvolution => + withTempView("source") { + createAndInitTable( + s"""pk INT NOT NULL, + |info STRUCT, + |dep STRING""".stripMargin, + """{ "pk": 1, "info": { "salary": 100, "status": "active" }, "dep": "hr" } + |{ "pk": 2, "info": { "salary": 200, "status": "inactive" }, "dep": "software" } + |""".stripMargin) + + val sourceTableSchema = StructType(Seq( + StructField("pk", IntegerType, nullable = false), + StructField("info", StructType(Seq( + StructField("salary", IntegerType), + StructField("status", StringType), + StructField("bonus", IntegerType) // new field not in target + ))), + StructField("dep", StringType) + )) + val data = Seq( + Row(2, Row(150, "dummy", 50), "active"), + Row(3, Row(250, "dummy", 75), "active") + ) + spark.createDataFrame(spark.sparkContext.parallelize(data), sourceTableSchema) + 
+  test("Merge schema evolution should not evolve if not directly referencing new field " +
+    "in top level struct: UPDATE") {
+    Seq(true, false).foreach { withSchemaEvolution =>
+      withTempView("source") {
+        createAndInitTable(
+          s"""pk INT NOT NULL,
+             |info STRUCT<salary INT, status STRING>,
+             |dep STRING""".stripMargin,
+          """{ "pk": 1, "info": { "salary": 100, "status": "active" }, "dep": "hr" }
+            |{ "pk": 2, "info": { "salary": 200, "status": "inactive" }, "dep": "software" }
+            |""".stripMargin)
+
+        val sourceTableSchema = StructType(Seq(
+          StructField("pk", IntegerType, nullable = false),
+          StructField("info", StructType(Seq(
+            StructField("salary", IntegerType),
+            StructField("status", StringType),
+            StructField("bonus", IntegerType) // new field not in target
+          ))),
+          StructField("dep", StringType)
+        ))
+        val data = Seq(
+          Row(2, Row(150, "dummy", 50), "active"),
+          Row(3, Row(250, "dummy", 75), "active")
+        )
+        spark.createDataFrame(spark.sparkContext.parallelize(data), sourceTableSchema)
+          .createOrReplaceTempView("source")
+
+        val schemaEvolutionClause = if (withSchemaEvolution) "WITH SCHEMA EVOLUTION" else ""
+        val mergeStmt =
+          s"""MERGE $schemaEvolutionClause
+             |INTO $tableNameAsString t
+             |USING source s
+             |ON t.pk = s.pk
+             |WHEN MATCHED THEN
+             | UPDATE SET info.status='inactive'
+             |""".stripMargin
+
+        sql(mergeStmt)
+        checkAnswer(
+          sql(s"SELECT * FROM $tableNameAsString"),
+          Seq(
+            Row(1, Row(100, "active"), "hr"),
+            Row(2, Row(200, "inactive"), "software")))
+        sql(s"DROP TABLE $tableNameAsString")
+      }
+    }
+  }
+
+  test("Merge schema evolution should evolve when directly assigning struct with new field: " +
+    "UPDATE") {
+    Seq(true, false).foreach { withSchemaEvolution =>
+      withTempView("source") {
+        createAndInitTable(
+          s"""pk INT NOT NULL,
+             |info STRUCT<salary INT, status STRING>,
+             |dep STRING""".stripMargin,
+          """{ "pk": 1, "info": { "salary": 100, "status": "active" }, "dep": "hr" }
+            |{ "pk": 2, "info": { "salary": 200, "status": "inactive" }, "dep": "software" }
+            |""".stripMargin)
+
+        val sourceTableSchema = StructType(Seq(
+          StructField("pk", IntegerType, nullable = false),
+          StructField("info", StructType(Seq(
+            StructField("salary", IntegerType),
+            StructField("status", StringType),
+            StructField("bonus", IntegerType) // new field not in target
+          ))),
+          StructField("dep", StringType)
+        ))
+        val data = Seq(
+          Row(2, Row(150, "updated", 50), "engineering")
+        )
+        spark.createDataFrame(spark.sparkContext.parallelize(data), sourceTableSchema)
+          .createOrReplaceTempView("source")
+
+        val schemaEvolutionClause = if (withSchemaEvolution) "WITH SCHEMA EVOLUTION" else ""
+        val mergeStmt =
+          s"""MERGE $schemaEvolutionClause
+             |INTO $tableNameAsString t
+             |USING source s
+             |ON t.pk = s.pk
+             |WHEN MATCHED THEN
+             | UPDATE SET info = s.info
+             |""".stripMargin
+
+        if (withSchemaEvolution) {
+          sql(mergeStmt)
+          // Schema should evolve - the bonus field should be added.
+          checkAnswer(
+            sql(s"SELECT * FROM $tableNameAsString"),
+            Seq(
+              Row(1, Row(100, "active", null), "hr"),
+              Row(2, Row(150, "updated", 50), "software")))
+        } else {
+          val exception = intercept[org.apache.spark.sql.AnalysisException] {
+            sql(mergeStmt)
+          }
+          assert(exception.getMessage.contains("Cannot safely cast") ||
+            exception.getMessage.contains("incompatible"))
+        }
+
+        sql(s"DROP TABLE $tableNameAsString")
+      }
+    }
+  }
+  test("Merge schema evolution should evolve when directly assigning struct with new field: " +
+    "INSERT") {
+    Seq(true, false).foreach { withSchemaEvolution =>
+      withTempView("source") {
+        createAndInitTable(
+          s"""pk INT NOT NULL,
+             |info STRUCT<salary INT, status STRING>,
+             |dep STRING""".stripMargin,
+          """{ "pk": 1, "info": { "salary": 100, "status": "active" }, "dep": "hr" }
+            |{ "pk": 2, "info": { "salary": 200, "status": "inactive" }, "dep": "software" }
+            |""".stripMargin)
+
+        val sourceTableSchema = StructType(Seq(
+          StructField("pk", IntegerType, nullable = false),
+          StructField("info", StructType(Seq(
+            StructField("salary", IntegerType),
+            StructField("status", StringType),
+            StructField("bonus", IntegerType) // new field not in target
+          ))),
+          StructField("dep", StringType)
+        ))
+        val data = Seq(
+          Row(3, Row(150, "new", 50), "engineering")
+        )
+        spark.createDataFrame(spark.sparkContext.parallelize(data), sourceTableSchema)
+          .createOrReplaceTempView("source")
+
+        val schemaEvolutionClause = if (withSchemaEvolution) "WITH SCHEMA EVOLUTION" else ""
+        val mergeStmt =
+          s"""MERGE $schemaEvolutionClause
+             |INTO $tableNameAsString t
+             |USING source s
+             |ON t.pk = s.pk
+             |WHEN NOT MATCHED THEN
+             | INSERT (pk, info, dep) VALUES (s.pk, s.info, s.dep)
+             |""".stripMargin
+
+        if (withSchemaEvolution) {
+          sql(mergeStmt)
+          // Schema should evolve - the bonus field should be added.
+          checkAnswer(
+            sql(s"SELECT * FROM $tableNameAsString"),
+            Seq(
+              Row(1, Row(100, "active", null), "hr"),
+              Row(2, Row(200, "inactive", null), "software"),
+              Row(3, Row(150, "new", 50), "engineering")))
+        } else {
+          val exception = intercept[org.apache.spark.sql.AnalysisException] {
+            sql(mergeStmt)
+          }
+          assert(exception.getMessage.contains("Cannot safely cast") ||
+            exception.getMessage.contains("incompatible"))
+        }
+
+        sql(s"DROP TABLE $tableNameAsString")
+      }
+    }
+  }
+
+  test("Merge schema evolution should not evolve if not directly referencing " +
+    "new field in nested struct") {
+    Seq(true, false).foreach { withSchemaEvolution =>
+      withTempView("source") {
+        val targetSchema = StructType(Seq(
+          StructField("pk", IntegerType, nullable = false),
+          StructField("employee", StructType(Seq(
+            StructField("name", StringType),
+            StructField("details", StructType(Seq(
+              StructField("salary", IntegerType),
+              StructField("status", StringType)
+            )))
+          ))),
+          StructField("dep", StringType)
+        ))
+
+        createTable(CatalogV2Util.structTypeToV2Columns(targetSchema))
+
+        val targetData = Seq(
+          Row(1, Row("Alice", Row(100, "active")), "hr"),
+          Row(2, Row("Bob", Row(200, "active")), "software")
+        )
+        spark.createDataFrame(
+          spark.sparkContext.parallelize(targetData), targetSchema)
+          .coalesce(1).writeTo(tableNameAsString).append()
+
+        checkAnswer(
+          sql(s"SELECT * FROM $tableNameAsString"),
+          Seq(
+            Row(1, Row("Alice", Row(100, "active")), "hr"),
+            Row(2, Row("Bob", Row(200, "active")), "software")))
+
+        val sourceTableSchema = StructType(Seq(
+          StructField("pk", IntegerType, nullable = false),
+          StructField("employee", StructType(Seq(
+            StructField("name", StringType),
+            StructField("details", StructType(Seq(
+              StructField("salary", IntegerType),
+              StructField("status", StringType),
+              StructField("bonus", IntegerType) // new field not in target
+            )))
+          ))),
+          StructField("dep", StringType)
+        ))
+        val data = Seq(
+          Row(2, Row("Bob", Row(150, "active", 50)), "dummy"),
+          Row(3, Row("Charlie", Row(250, "active", 75)), "dummy")
+        )
+        spark.createDataFrame(
+          spark.sparkContext.parallelize(data), sourceTableSchema)
+          .createOrReplaceTempView("source")
+
+        val schemaEvolutionClause =
+          if (withSchemaEvolution) "WITH SCHEMA EVOLUTION" else ""
+        val mergeStmt =
+          s"""MERGE $schemaEvolutionClause
+             |INTO $tableNameAsString t
+             |USING source s
+             |ON t.pk = s.pk
+             |WHEN MATCHED THEN
+             | UPDATE SET employee.details.status='inactive'
+             |""".stripMargin
+
+        sql(mergeStmt)
+        checkAnswer(
+          sql(s"SELECT * FROM $tableNameAsString"),
+          Seq(
+            Row(1, Row("Alice", Row(100, "active")), "hr"),
+            Row(2, Row("Bob", Row(200, "inactive")), "software")))
+
+        sql(s"DROP TABLE $tableNameAsString")
+      }
+    }
+  }
+  test("Merge schema evolution should not evolve referencing new column assigned " +
+    "to something else") {
+    Seq(true, false).foreach { withSchemaEvolution =>
+      withTempView("source") {
+        createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+          """{ "pk": 1, "salary": 100, "dep": "hr" }
+            |{ "pk": 2, "salary": 200, "dep": "software" }
+            |""".stripMargin)
+
+        val sourceDF = Seq((2, 150, "dummy", "blah"),
+          (3, 250, "dummy", "blah")).toDF("pk", "salary", "dep", "extra")
+        sourceDF.createOrReplaceTempView("source")
+
+        val schemaEvolutionClause = if (withSchemaEvolution) "WITH SCHEMA EVOLUTION" else ""
+        val mergeStmt =
+          s"""MERGE $schemaEvolutionClause
+             |INTO $tableNameAsString t
+             |USING source s
+             |ON t.pk = s.pk
+             |WHEN MATCHED THEN
+             | UPDATE SET extra=s.dep
+             |""".stripMargin
+
+        val e = intercept[org.apache.spark.sql.AnalysisException] {
+          sql(mergeStmt)
+        }
+        assert(e.errorClass.get == "UNRESOLVED_COLUMN.WITH_SUGGESTION")
+        assert(e.getMessage.contains("A column, variable, or function parameter with name " +
+          "`extra` cannot be resolved"))
+        sql(s"DROP TABLE $tableNameAsString")
+      }
+    }
+  }
+
   test("merge into with source missing fields in top-level struct") {
     withTempView("source") {
       // Target table has struct with 3 fields at top level