Skip to content

Commit c3d9922

Browse files
committed
ON CONSTRAINT
1 parent dc37712 commit c3d9922

File tree

3 files changed

+34
-5
lines changed

3 files changed

+34
-5
lines changed

internal/endtoend/testdata/update_set_on_conflict/postgresql/pgx/query.sql

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,11 @@ INSERT INTO servers(code, name) VALUES ($1, $2)
1313
ON CONFLICT (code)
1414
DO UPDATE SET name = EXCLUDED.name_typo;
1515

16+
-- name: UpsertServerMissingConflictTarget :exec
17+
INSERT INTO servers(code, name) VALUES ($1, $2)
18+
ON CONFLICT DO UPDATE SET name = EXCLUDED.name;
19+
20+
-- name: UpsertServerOnConstraintExcludedTypo :exec
21+
INSERT INTO servers(code, name) VALUES ($1, $2)
22+
ON CONFLICT ON CONSTRAINT servers_pkey DO UPDATE SET name = EXCLUDED.name_typo;
23+
Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
# package querytest
22
query.sql:4:15: column "name_typo" of relation "servers" does not exist
33
query.sql:8:13: column "code_typo" of relation "servers" does not exist
4-
query.sql:14:22: column "name_typo" of relation "servers" does not exist
4+
query.sql:14:22: column "name_typo" of relation "EXCLUDED" does not exist
5+
query.sql:17:1: ON CONFLICT DO UPDATE requires inference specification or constraint name
6+
query.sql:22:61: column "name_typo" of relation "EXCLUDED" does not exist

internal/sql/validate/insert_stmt.go

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ import (
88
"github.com/sqlc-dev/sqlc/internal/sql/sqlerr"
99
)
1010

11+
const excludedTable = "EXCLUDED"
12+
1113
func InsertStmt(c *catalog.Catalog, fqn *ast.TableName, stmt *ast.InsertStmt) error {
1214
sel, ok := stmt.SelectStmt.(*ast.SelectStmt)
1315
if !ok {
@@ -38,6 +40,7 @@ func InsertStmt(c *catalog.Catalog, fqn *ast.TableName, stmt *ast.InsertStmt) er
3840
Message: "INSERT has more expressions than target columns",
3941
}
4042
}
43+
4144
return onConflictClause(c, fqn, stmt)
4245
}
4346

@@ -47,7 +50,7 @@ func InsertStmt(c *catalog.Catalog, fqn *ast.TableName, stmt *ast.InsertStmt) er
4750
// - DO UPDATE SET col = ... assignment target columns exist
4851
// - EXCLUDED.col references exist
4952
func onConflictClause(c *catalog.Catalog, fqn *ast.TableName, n *ast.InsertStmt) error {
50-
if n.OnConflictClause == nil || n.OnConflictClause.Action != ast.OnConflictActionUpdate {
53+
if fqn == nil || n.OnConflictClause == nil || n.OnConflictClause.Action != ast.OnConflictActionUpdate {
5154
return nil
5255
}
5356

@@ -62,13 +65,22 @@ func onConflictClause(c *catalog.Catalog, fqn *ast.TableName, n *ast.InsertStmt)
6265
colNames[col.Name] = struct{}{}
6366
}
6467

68+
// DO UPDATE requires a conflict target: ON CONFLICT (col) or ON CONFLICT ON CONSTRAINT name.
69+
if n.OnConflictClause.Infer == nil {
70+
return &sqlerr.Error{
71+
Code: "42601",
72+
Message: "ON CONFLICT DO UPDATE requires inference specification or constraint name",
73+
}
74+
}
75+
6576
// Validate ON CONFLICT (col, ...) conflict target columns.
66-
if n.OnConflictClause.Infer != nil && n.OnConflictClause.Infer.IndexElems != nil {
77+
if n.OnConflictClause.Infer.IndexElems != nil {
6778
for _, item := range n.OnConflictClause.Infer.IndexElems.Items {
6879
elem, ok := item.(*ast.IndexElem)
6980
if !ok || elem.Name == nil {
7081
continue
7182
}
83+
7284
if _, exists := colNames[*elem.Name]; !exists {
7385
e := sqlerr.ColumnNotFound(table.Rel.Name, *elem.Name)
7486
e.Location = n.OnConflictClause.Infer.Location
@@ -81,26 +93,30 @@ func onConflictClause(c *catalog.Catalog, fqn *ast.TableName, n *ast.InsertStmt)
8193
if n.OnConflictClause.TargetList == nil {
8294
return nil
8395
}
96+
8497
for _, item := range n.OnConflictClause.TargetList.Items {
8598
target, ok := item.(*ast.ResTarget)
8699
if !ok || target.Name == nil {
87100
continue
88101
}
102+
89103
if _, exists := colNames[*target.Name]; !exists {
90104
e := sqlerr.ColumnNotFound(table.Rel.Name, *target.Name)
91105
e.Location = target.Location
92106
return e
93107
}
108+
94109
if ref, ok := target.Val.(*ast.ColumnRef); ok {
95110
if excludedCol, ok := excludedColumnRef(ref); ok {
96111
if _, exists := colNames[excludedCol]; !exists {
97-
e := sqlerr.ColumnNotFound(table.Rel.Name, excludedCol)
112+
e := sqlerr.ColumnNotFound(excludedTable, excludedCol)
98113
e.Location = ref.Location
99114
return e
100115
}
101116
}
102117
}
103118
}
119+
104120
return nil
105121
}
106122

@@ -110,13 +126,16 @@ func excludedColumnRef(ref *ast.ColumnRef) (string, bool) {
110126
if ref.Fields == nil || len(ref.Fields.Items) != 2 {
111127
return "", false
112128
}
129+
113130
first, ok := ref.Fields.Items[0].(*ast.String)
114-
if !ok || !strings.EqualFold(first.Str, "excluded") {
131+
if !ok || !strings.EqualFold(first.Str, excludedTable) {
115132
return "", false
116133
}
134+
117135
second, ok := ref.Fields.Items[1].(*ast.String)
118136
if !ok {
119137
return "", false
120138
}
139+
121140
return second.Str, true
122141
}

0 commit comments

Comments
 (0)