diff --git a/templates/template.go b/templates/template.go index fc92987..e0a194c 100644 --- a/templates/template.go +++ b/templates/template.go @@ -67,6 +67,7 @@ func buildBulkInsertQuery(originalQuery string, numArgs int, numParamsPerArg int // Use LastIndex to find the main clause onDuplicateUpperIndex := strings.LastIndex(strings.ToUpper(trimmedQuery), "ON DUPLICATE KEY UPDATE") onConflictUpperIndex := strings.LastIndex(strings.ToUpper(trimmedQuery), "ON CONFLICT") + returningUpperIndex := strings.LastIndex(strings.ToUpper(trimmedQuery), "RETURNING") // Find the earliest starting position of any suffix keyword suffixBoundary := len(trimmedQuery) @@ -76,6 +77,9 @@ func buildBulkInsertQuery(originalQuery string, numArgs int, numParamsPerArg int if onConflictUpperIndex != -1 && onConflictUpperIndex < suffixBoundary { suffixBoundary = onConflictUpperIndex } + if returningUpperIndex != -1 && returningUpperIndex < suffixBoundary { + suffixBoundary = returningUpperIndex + } if suffixBoundary < len(trimmedQuery) { // Suffix found diff --git a/templates/template_test.go b/templates/template_test.go index 2c20273..68597bf 100644 --- a/templates/template_test.go +++ b/templates/template_test.go @@ -161,6 +161,30 @@ func TestBuildBulkInsertQuery(t *testing.T) { } }, }, + "valid:upsert (ON DUPLICATE KEY UPDATE)": { + arrange: func(t *testing.T) (Args, Expected) { + return Args{ + originalQuery: "INSERT INTO users (id, name) VALUES (?, ?) ON DUPLICATE KEY UPDATE id = VALUES(id), name = VALUES(name);", + numArgs: 2, + numParamsPerArg: 2, + }, Expected{ + query: "INSERT INTO users (id, name) VALUES (?,?),(?,?) ON DUPLICATE KEY UPDATE id = VALUES(id), name = VALUES(name)", + err: nil, + } + }, + }, + "valid:upsert (ON DUPLICATE KEY UPDATE) case-insensitive": { + arrange: func(t *testing.T) (Args, Expected) { + return Args{ + originalQuery: "insert into users (id, name) values (?, ?) on duplicate key update id = values(id), name = values(name);", + numArgs: 2, + numParamsPerArg: 2, + }, Expected{ + query: "insert into users (id, name) VALUES (?,?),(?,?) on duplicate key update id = values(id), name = values(name)", + err: nil, + } + }, + }, "valid:upsert (ON CONFLICT)": { arrange: func(t *testing.T) (Args, Expected) { return Args{ @@ -185,6 +209,30 @@ func TestBuildBulkInsertQuery(t *testing.T) { } }, }, + "valid:RETURNING clause": { + arrange: func(t *testing.T) (Args, Expected) { + return Args{ + originalQuery: "INSERT INTO users (id, name) VALUES (?, ?) RETURNING id;", + numArgs: 2, + numParamsPerArg: 2, + }, Expected{ + query: "INSERT INTO users (id, name) VALUES (?,?),(?,?) RETURNING id", + err: nil, + } + }, + }, + "valid:RETURNING clause with ON CONFLICT": { + arrange: func(t *testing.T) (Args, Expected) { + return Args{ + originalQuery: "INSERT INTO users (id, name) VALUES (?, ?) ON CONFLICT (id) DO UPDATE SET name = EXCLUDED.name RETURNING id;", + numArgs: 2, + numParamsPerArg: 2, + }, Expected{ + query: "INSERT INTO users (id, name) VALUES (?,?),(?,?) ON CONFLICT (id) DO UPDATE SET name = EXCLUDED.name RETURNING id", + err: nil, + } + }, + }, "valid:Squeeze spaces": { arrange: func(t *testing.T) (Args, Expected) { return Args{