88 "github.com/sqlc-dev/sqlc/internal/sql/sqlerr"
99)
1010
11+ const excludedTable = "EXCLUDED"
12+
1113func InsertStmt (c * catalog.Catalog , fqn * ast.TableName , stmt * ast.InsertStmt ) error {
1214 sel , ok := stmt .SelectStmt .(* ast.SelectStmt )
1315 if ! ok {
@@ -38,6 +40,7 @@ func InsertStmt(c *catalog.Catalog, fqn *ast.TableName, stmt *ast.InsertStmt) er
3840 Message : "INSERT has more expressions than target columns" ,
3941 }
4042 }
43+
4144 return onConflictClause (c , fqn , stmt )
4245}
4346
@@ -47,7 +50,7 @@ func InsertStmt(c *catalog.Catalog, fqn *ast.TableName, stmt *ast.InsertStmt) er
4750// - DO UPDATE SET col = ... assignment target columns exist
4851// - EXCLUDED.col references exist
4952func onConflictClause (c * catalog.Catalog , fqn * ast.TableName , n * ast.InsertStmt ) error {
50- if n .OnConflictClause == nil || n .OnConflictClause .Action != ast .OnConflictActionUpdate {
53+ if fqn == nil || n .OnConflictClause == nil || n .OnConflictClause .Action != ast .OnConflictActionUpdate {
5154 return nil
5255 }
5356
@@ -62,13 +65,22 @@ func onConflictClause(c *catalog.Catalog, fqn *ast.TableName, n *ast.InsertStmt)
6265 colNames [col .Name ] = struct {}{}
6366 }
6467
68+ // DO UPDATE requires a conflict target: ON CONFLICT (col) or ON CONFLICT ON CONSTRAINT name.
69+ if n .OnConflictClause .Infer == nil {
70+ return & sqlerr.Error {
71+ Code : "42601" ,
72+ Message : "ON CONFLICT DO UPDATE requires inference specification or constraint name" ,
73+ }
74+ }
75+
6576 // Validate ON CONFLICT (col, ...) conflict target columns.
66- if n .OnConflictClause .Infer != nil && n . OnConflictClause . Infer .IndexElems != nil {
77+ if n .OnConflictClause .Infer .IndexElems != nil {
6778 for _ , item := range n .OnConflictClause .Infer .IndexElems .Items {
6879 elem , ok := item .(* ast.IndexElem )
6980 if ! ok || elem .Name == nil {
7081 continue
7182 }
83+
7284 if _ , exists := colNames [* elem .Name ]; ! exists {
7385 e := sqlerr .ColumnNotFound (table .Rel .Name , * elem .Name )
7486 e .Location = n .OnConflictClause .Infer .Location
@@ -81,26 +93,30 @@ func onConflictClause(c *catalog.Catalog, fqn *ast.TableName, n *ast.InsertStmt)
8193 if n .OnConflictClause .TargetList == nil {
8294 return nil
8395 }
96+
8497 for _ , item := range n .OnConflictClause .TargetList .Items {
8598 target , ok := item .(* ast.ResTarget )
8699 if ! ok || target .Name == nil {
87100 continue
88101 }
102+
89103 if _ , exists := colNames [* target .Name ]; ! exists {
90104 e := sqlerr .ColumnNotFound (table .Rel .Name , * target .Name )
91105 e .Location = target .Location
92106 return e
93107 }
108+
94109 if ref , ok := target .Val .(* ast.ColumnRef ); ok {
95110 if excludedCol , ok := excludedColumnRef (ref ); ok {
96111 if _ , exists := colNames [excludedCol ]; ! exists {
97- e := sqlerr .ColumnNotFound (table . Rel . Name , excludedCol )
112+ e := sqlerr .ColumnNotFound (excludedTable , excludedCol )
98113 e .Location = ref .Location
99114 return e
100115 }
101116 }
102117 }
103118 }
119+
104120 return nil
105121}
106122
@@ -110,13 +126,16 @@ func excludedColumnRef(ref *ast.ColumnRef) (string, bool) {
110126 if ref .Fields == nil || len (ref .Fields .Items ) != 2 {
111127 return "" , false
112128 }
129+
113130 first , ok := ref .Fields .Items [0 ].(* ast.String )
114- if ! ok || ! strings .EqualFold (first .Str , "excluded" ) {
131+ if ! ok || ! strings .EqualFold (first .Str , excludedTable ) {
115132 return "" , false
116133 }
134+
117135 second , ok := ref .Fields .Items [1 ].(* ast.String )
118136 if ! ok {
119137 return "" , false
120138 }
139+
121140 return second .Str , true
122141}
0 commit comments