Skip to content

Commit 6c6de51

Browse files
committed
Refactor and add more test
1 parent 3daaf68 commit 6c6de51

File tree

4 files changed

+143
-70
lines changed

4 files changed

+143
-70
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1749,7 +1749,8 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
17491749
m.copy(mergeCondition = resolvedMergeCondition,
17501750
matchedActions = newMatchedActions,
17511751
notMatchedActions = newNotMatchedActions,
1752-
notMatchedBySourceActions = newNotMatchedBySourceActions)
1752+
notMatchedBySourceActions = newNotMatchedBySourceActions,
1753+
originalSourceActions = newMatchedActions ++ newNotMatchedActions)
17531754
}
17541755

17551756
// UnresolvedHaving can host grouping expressions and aggregate functions. We should resolve

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveRowLevelCommandAssignments.scala

Lines changed: 28 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -60,18 +60,11 @@ object ResolveRowLevelCommandAssignments extends Rule[LogicalPlan] {
6060
notMatchedActions = alignActions(m.targetTable.output, m.notMatchedActions,
6161
coerceNestedTypes),
6262
notMatchedBySourceActions = alignActions(m.targetTable.output, m.notMatchedBySourceActions,
63-
coerceNestedTypes),
64-
preservedSourceActions = Some(m.matchedActions ++ m.notMatchedActions)
65-
)
63+
coerceNestedTypes))
6664

6765
case m: MergeIntoTable if !m.skipSchemaResolution && m.resolved && !m.aligned
6866
&& !m.needSchemaEvolution =>
69-
m.copy(
70-
matchedActions = m.notMatchedActions.map(resolveMergeAction),
71-
notMatchedActions = m.notMatchedActions.map(resolveMergeAction),
72-
notMatchedBySourceActions = m.matchedActions.map(resolveMergeAction),
73-
preservedSourceActions = Some(m.matchedActions ++ m.notMatchedActions)
74-
)
67+
resolveAssignments(m)
7568
}
7669

7770
private def validateStoreAssignmentPolicy(): Unit = {
@@ -90,51 +83,33 @@ object ResolveRowLevelCommandAssignments extends Rule[LogicalPlan] {
9083

9184
private def resolveAssignments(p: LogicalPlan): LogicalPlan = {
9285
p.transformExpressions {
93-
case assignment: Assignment => resolveAssignment(assignment)
94-
}
95-
}
96-
97-
private def resolveMergeAction(mergeAction: MergeAction) = {
98-
mergeAction match {
99-
case u @ UpdateAction(_, assignments) =>
100-
u.copy(assignments = assignments.map(resolveAssignment))
101-
case i @ InsertAction(_, assignments) =>
102-
i.copy(assignments = assignments.map(resolveAssignment))
103-
case d: DeleteAction =>
104-
d
105-
case other =>
106-
throw new AnalysisException(
107-
errorClass = "_LEGACY_ERROR_TEMP_3053",
108-
messageParameters = Map("other" -> other.toString))
109-
}
110-
}
111-
112-
private def resolveAssignment(assignment: Assignment) = {
113-
val nullHandled = if (!assignment.key.nullable && assignment.value.nullable) {
114-
AssertNotNull(assignment.value)
115-
} else {
116-
assignment.value
117-
}
118-
val casted = if (assignment.key.dataType != nullHandled.dataType) {
119-
val cast = Cast(nullHandled, assignment.key.dataType, ansiEnabled = true)
120-
cast.setTagValue(Cast.BY_TABLE_INSERTION, ())
121-
cast
122-
} else {
123-
nullHandled
124-
}
125-
val rawKeyType = assignment.key.transform {
126-
case a: AttributeReference =>
127-
CharVarcharUtils.getRawType(a.metadata).map(a.withDataType).getOrElse(a)
128-
}.dataType
129-
val finalValue = if (CharVarcharUtils.hasCharVarchar(rawKeyType)) {
130-
CharVarcharUtils.stringLengthCheck(casted, rawKeyType)
131-
} else {
132-
casted
133-
}
134-
val cleanedKey = assignment.key.transform {
135-
case a: AttributeReference => CharVarcharUtils.cleanAttrMetadata(a)
86+
case assignment: Assignment =>
87+
val nullHandled = if (!assignment.key.nullable && assignment.value.nullable) {
88+
AssertNotNull(assignment.value)
89+
} else {
90+
assignment.value
91+
}
92+
val casted = if (assignment.key.dataType != nullHandled.dataType) {
93+
val cast = Cast(nullHandled, assignment.key.dataType, ansiEnabled = true)
94+
cast.setTagValue(Cast.BY_TABLE_INSERTION, ())
95+
cast
96+
} else {
97+
nullHandled
98+
}
99+
val rawKeyType = assignment.key.transform {
100+
case a: AttributeReference =>
101+
CharVarcharUtils.getRawType(a.metadata).map(a.withDataType).getOrElse(a)
102+
}.dataType
103+
val finalValue = if (CharVarcharUtils.hasCharVarchar(rawKeyType)) {
104+
CharVarcharUtils.stringLengthCheck(casted, rawKeyType)
105+
} else {
106+
casted
107+
}
108+
val cleanedKey = assignment.key.transform {
109+
case a: AttributeReference => CharVarcharUtils.cleanAttrMetadata(a)
110+
}
111+
Assignment(cleanedKey, finalValue)
136112
}
137-
Assignment(cleanedKey, finalValue)
138113
}
139114

140115
private def alignActions(

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala

Lines changed: 29 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -862,7 +862,7 @@ case class MergeIntoTable(
862862
notMatchedBySourceActions: Seq[MergeAction],
863863
withSchemaEvolution: Boolean,
864864
// Preserves original pre-aligned actions for source matches
865-
preservedSourceActions: Option[Seq[MergeAction]] = None)
865+
originalSourceActions: Seq[MergeAction])
866866
extends BinaryCommand with SupportsSubquery {
867867

868868
lazy val aligned: Boolean = {
@@ -895,12 +895,14 @@ case class MergeIntoTable(
895895
case _ => false
896896
}
897897

898-
private lazy val migrationSchema: StructType =
898+
// a pruned version of source schema that only contains columns/nested fields
899+
// explicitly assigned by MERGE INTO actions
900+
private lazy val referencedSourceSchema: StructType =
899901
MergeIntoTable.referencedSourceSchema(this)
900902

901903
lazy val needSchemaEvolution: Boolean = {
902904
schemaEvolutionEnabled &&
903-
MergeIntoTable.schemaChanges(targetTable.schema, migrationSchema).nonEmpty
905+
MergeIntoTable.schemaChanges(targetTable.schema, referencedSourceSchema).nonEmpty
904906
}
905907

906908
private def schemaEvolutionEnabled: Boolean = withSchemaEvolution && {
@@ -918,6 +920,26 @@ case class MergeIntoTable(
918920
}
919921

920922
object MergeIntoTable {
923+
924+
def apply(
925+
targetTable: LogicalPlan,
926+
sourceTable: LogicalPlan,
927+
mergeCondition: Expression,
928+
matchedActions: Seq[MergeAction],
929+
notMatchedActions: Seq[MergeAction],
930+
notMatchedBySourceActions: Seq[MergeAction],
931+
withSchemaEvolution: Boolean): MergeIntoTable = {
932+
MergeIntoTable(
933+
targetTable,
934+
sourceTable,
935+
mergeCondition,
936+
matchedActions,
937+
notMatchedActions,
938+
notMatchedBySourceActions,
939+
withSchemaEvolution,
940+
matchedActions ++ notMatchedActions)
941+
}
942+
921943
def getWritePrivileges(
922944
matchedActions: Iterable[MergeAction],
923945
notMatchedActions: Iterable[MergeAction],
@@ -955,12 +977,11 @@ object MergeIntoTable {
955977
case currentField: StructField if newFieldMap.contains(currentField.name) =>
956978
schemaChanges(currentField.dataType, newFieldMap(currentField.name).dataType,
957979
originalTarget, originalSource, fieldPath ++ Seq(currentField.name))
958-
}
959-
}.flatten
980+
}}.flatten
960981

961982
// Identify the newly added fields and append to the end
962983
val currentFieldMap = toFieldMap(currentFields)
963-
val adds = newFields.filterNot(f => currentFieldMap.contains(f.name))
984+
val adds = newFields.filterNot (f => currentFieldMap.contains(f.name))
964985
.map(f => TableChange.addColumn(fieldPath ++ Set(f.name), f.dataType))
965986

966987
updates ++ adds
@@ -1003,17 +1024,12 @@ object MergeIntoTable {
10031024
// by at least one merge action
10041025
def referencedSourceSchema(merge: MergeIntoTable): StructType = {
10051026

1006-
val actions = merge.preservedSourceActions match {
1007-
case Some(preserved) => preserved
1008-
case None => merge.matchedActions ++ merge.notMatchedActions
1009-
}
1010-
1011-
val assignments = actions.collect {
1027+
val assignments = merge.originalSourceActions.collect {
10121028
case a: UpdateAction => a.assignments.map(_.key)
10131029
case a: InsertAction => a.assignments.map(_.key)
10141030
}.flatten
10151031

1016-
val containsStarAction = actions.exists {
1032+
val containsStarAction = merge.originalSourceActions.exists {
10171033
case _: UpdateStarAction => true
10181034
case _: InsertStarAction => true
10191035
case _ => false
@@ -1042,8 +1058,6 @@ object MergeIntoTable {
10421058
}
10431059
})
10441060

1045-
val sourceSchema = merge.sourceTable.schema
1046-
val targetSchema = merge.targetTable.schema
10471061
val res = filterSchema(merge.sourceTable.schema, Seq.empty)
10481062
res
10491063
}
@@ -1072,7 +1086,6 @@ object MergeIntoTable {
10721086
}
10731087
}
10741088

1075-
10761089
sealed abstract class MergeAction extends Expression with Unevaluable {
10771090
def condition: Option[Expression]
10781091
override def nullable: Boolean = false

sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3721,6 +3721,46 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase
37213721
}
37223722
}
37233723

3724+
test("Merge schema evolution should not evolve when assigning existing target column " +
3725+
"from source column that does not exist in target") {
3726+
Seq(true, false).foreach { withSchemaEvolution =>
3727+
withTempView("source") {
3728+
createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
3729+
"""{ "pk": 1, "salary": 100, "dep": "hr" }
3730+
|{ "pk": 2, "salary": 200, "dep": "software" }
3731+
|""".stripMargin)
3732+
3733+
val sourceDF = Seq((2, 150, "dummy", 50),
3734+
(3, 250, "dummy", 75)).toDF("pk", "salary", "dep", "bonus")
3735+
sourceDF.createOrReplaceTempView("source")
3736+
3737+
val schemaEvolutionClause = if (withSchemaEvolution) "WITH SCHEMA EVOLUTION" else ""
3738+
val mergeStmt =
3739+
s"""MERGE $schemaEvolutionClause
3740+
|INTO $tableNameAsString t
3741+
|USING source s
3742+
|ON t.pk = s.pk
3743+
|WHEN MATCHED THEN
3744+
| UPDATE SET salary = s.bonus
3745+
|WHEN NOT MATCHED THEN
3746+
| INSERT (pk, salary, dep) VALUES (s.pk, s.bonus, 'newdep')
3747+
|""".stripMargin
3748+
3749+
sql(mergeStmt)
3750+
// bonus column should NOT be added to target schema
3751+
// Only salary is updated with bonus value
3752+
checkAnswer(
3753+
sql(s"SELECT * FROM $tableNameAsString"),
3754+
Seq(
3755+
Row(1, 100, "hr"),
3756+
Row(2, 50, "software"),
3757+
Row(3, 75, "newdep")))
3758+
3759+
sql(s"DROP TABLE $tableNameAsString")
3760+
}
3761+
}
3762+
}
3763+
37243764
test("Merge schema evolution should evolve struct if directly referencing new field " +
37253765
"in top level struct: insert") {
37263766
Seq(true, false).foreach { withSchemaEvolution =>
@@ -3991,6 +4031,50 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase
39914031
}
39924032
}
39934033

4034+
test("merge into with source missing fields in top-level struct") {
4035+
withTempView("source") {
4036+
// Target table has struct with 3 fields at top level
4037+
createAndInitTable(
4038+
s"""pk INT NOT NULL,
4039+
|s STRUCT<c1: INT, c2: STRING, c3: BOOLEAN>,
4040+
|dep STRING""".stripMargin,
4041+
"""{ "pk": 0, "s": { "c1": 1, "c2": "a", "c3": true }, "dep": "sales"}""")
4042+
4043+
// Source table has struct with only 2 fields (c1, c2) - missing c3
4044+
val sourceTableSchema = StructType(Seq(
4045+
StructField("pk", IntegerType, nullable = false),
4046+
StructField("s", StructType(Seq(
4047+
StructField("c1", IntegerType),
4048+
StructField("c2", StringType)))), // missing c3 field
4049+
StructField("dep", StringType)))
4050+
val data = Seq(
4051+
Row(1, Row(10, "b"), "hr"),
4052+
Row(2, Row(20, "c"), "engineering")
4053+
)
4054+
spark.createDataFrame(spark.sparkContext.parallelize(data), sourceTableSchema)
4055+
.createOrReplaceTempView("source")
4056+
4057+
sql(
4058+
s"""MERGE INTO $tableNameAsString t
4059+
|USING source src
4060+
|ON t.pk = src.pk
4061+
|WHEN MATCHED THEN
4062+
| UPDATE SET *
4063+
|WHEN NOT MATCHED THEN
4064+
| INSERT *
4065+
|""".stripMargin)
4066+
4067+
// Missing field c3 should be filled with NULL
4068+
checkAnswer(
4069+
sql(s"SELECT * FROM $tableNameAsString"),
4070+
Seq(
4071+
Row(0, Row(1, "a", true), "sales"),
4072+
Row(1, Row(10, "b", null), "hr"),
4073+
Row(2, Row(20, "c", null), "engineering")))
4074+
}
4075+
sql(s"DROP TABLE IF EXISTS $tableNameAsString")
4076+
}
4077+
39944078
test("merge into with source missing fields in struct nested in array") {
39954079
withTempView("source") {
39964080
// Target table has struct with 3 fields (c1, c2, c3) in array

0 commit comments

Comments
 (0)