Skip to content

Commit 6fd7e4f

Browse files
committed
test: Add tests for multi statement parsing
1 parent e8edcb1 commit 6fd7e4f

File tree

1 file changed

+240
-3
lines changed

1 file changed

+240
-3
lines changed

database/firebird/firebird_test.go

Lines changed: 240 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,17 @@ import (
55
"database/sql"
66
sqldriver "database/sql/driver"
77
"fmt"
8-
"log"
9-
10-
"github.com/golang-migrate/migrate/v4"
118
"io"
9+
"log"
10+
nurl "net/url"
11+
"strconv"
1212
"strings"
1313
"testing"
1414

1515
"github.com/dhui/dktest"
1616

17+
"github.com/golang-migrate/migrate/v4"
18+
"github.com/golang-migrate/migrate/v4/database/multistmt"
1719
dt "github.com/golang-migrate/migrate/v4/database/testing"
1820
"github.com/golang-migrate/migrate/v4/dktesting"
1921
_ "github.com/golang-migrate/migrate/v4/source/file"
@@ -126,6 +128,75 @@ func TestMigrate(t *testing.T) {
126128
})
127129
}
128130

131+
func TestMultipleStatements(t *testing.T) {
132+
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
133+
ip, port, err := c.FirstPort()
134+
if err != nil {
135+
t.Fatal(err)
136+
}
137+
138+
addr := fbConnectionString(ip, port)
139+
p := &Firebird{}
140+
d, err := p.Open(addr)
141+
if err != nil {
142+
t.Fatal(err)
143+
}
144+
defer func() {
145+
if err := d.Close(); err != nil {
146+
t.Error(err)
147+
}
148+
}()
149+
if err := d.Run(strings.NewReader("CREATE TABLE foo (foo VARCHAR(40)); CREATE TABLE bar (bar VARCHAR(40));")); err != nil {
150+
t.Fatalf("expected err to be nil, got %v", err)
151+
}
152+
153+
// make sure second table exists
154+
var exists bool
155+
query := "SELECT CASE WHEN EXISTS (SELECT 1 FROM RDB$RELATIONS WHERE RDB$RELATION_NAME = 'BAR') THEN 1 ELSE 0 END FROM RDB$DATABASE"
156+
if err := d.(*Firebird).conn.QueryRowContext(context.Background(), query).Scan(&exists); err != nil {
157+
t.Fatal(err)
158+
}
159+
if !exists {
160+
t.Fatalf("expected table bar to exist")
161+
}
162+
})
163+
}
164+
165+
func TestMultipleStatementsInMultiStatementMode(t *testing.T) {
166+
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
167+
ip, port, err := c.FirstPort()
168+
if err != nil {
169+
t.Fatal(err)
170+
}
171+
172+
addr := fbConnectionString(ip, port) + "?x-multi-statement=true"
173+
p := &Firebird{}
174+
d, err := p.Open(addr)
175+
if err != nil {
176+
t.Fatal(err)
177+
}
178+
defer func() {
179+
if err := d.Close(); err != nil {
180+
t.Error(err)
181+
}
182+
}()
183+
// Use CREATE INDEX instead of CONCURRENTLY (Firebird doesn't support CREATE INDEX CONCURRENTLY)
184+
if err := d.Run(strings.NewReader("CREATE TABLE foo (foo VARCHAR(40)); CREATE INDEX idx_foo ON foo (foo);")); err != nil {
185+
t.Fatalf("expected err to be nil, got %v", err)
186+
}
187+
188+
// make sure created index exists
189+
var exists bool
190+
query := "SELECT CASE WHEN EXISTS (SELECT 1 FROM RDB$INDICES WHERE RDB$INDEX_NAME = 'IDX_FOO') THEN 1 ELSE 0 END FROM RDB$DATABASE"
191+
if err := d.(*Firebird).conn.QueryRowContext(context.Background(), query).Scan(&exists); err != nil {
192+
t.Fatal(err)
193+
}
194+
if !exists {
195+
t.Fatalf("expected index idx_foo to exist")
196+
}
197+
})
198+
}
199+
129200
func TestErrorParsing(t *testing.T) {
130201
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
131202
ip, port, err := c.FirstPort()
@@ -225,3 +296,169 @@ func Test_Lock(t *testing.T) {
225296
}
226297
})
227298
}
299+
300+
func TestMultiStatementURLParsing(t *testing.T) {
301+
tests := []struct {
302+
name string
303+
url string
304+
expectedMultiStmt bool
305+
expectedMultiStmtSize int
306+
shouldError bool
307+
}{
308+
{
309+
name: "multi-statement enabled",
310+
url: "firebird://user:pass@localhost:3050//path/to/db.fdb?x-multi-statement=true",
311+
expectedMultiStmt: true,
312+
expectedMultiStmtSize: DefaultMultiStatementMaxSize,
313+
shouldError: false,
314+
},
315+
{
316+
name: "multi-statement disabled",
317+
url: "firebird://user:pass@localhost:3050//path/to/db.fdb?x-multi-statement=false",
318+
expectedMultiStmt: false,
319+
expectedMultiStmtSize: DefaultMultiStatementMaxSize,
320+
shouldError: false,
321+
},
322+
{
323+
name: "multi-statement with custom size",
324+
url: "firebird://user:pass@localhost:3050//path/to/db.fdb?x-multi-statement=true&x-multi-statement-max-size=5242880",
325+
expectedMultiStmt: true,
326+
expectedMultiStmtSize: 5242880,
327+
shouldError: false,
328+
},
329+
{
330+
name: "multi-statement with invalid size falls back to default",
331+
url: "firebird://user:pass@localhost:3050//path/to/db.fdb?x-multi-statement=true&x-multi-statement-max-size=0",
332+
expectedMultiStmt: true,
333+
expectedMultiStmtSize: DefaultMultiStatementMaxSize,
334+
shouldError: false,
335+
},
336+
{
337+
name: "invalid boolean value should error",
338+
url: "firebird://user:pass@localhost:3050//path/to/db.fdb?x-multi-statement=invalid",
339+
expectedMultiStmt: false,
340+
expectedMultiStmtSize: DefaultMultiStatementMaxSize,
341+
shouldError: true,
342+
},
343+
{
344+
name: "invalid size value should error",
345+
url: "firebird://user:pass@localhost:3050//path/to/db.fdb?x-multi-statement=true&x-multi-statement-max-size=invalid",
346+
expectedMultiStmt: true,
347+
expectedMultiStmtSize: DefaultMultiStatementMaxSize,
348+
shouldError: true,
349+
},
350+
}
351+
352+
for _, tt := range tests {
353+
t.Run(tt.name, func(t *testing.T) {
354+
// We can't actually open a database connection without Docker,
355+
// but we can test the URL parsing logic by examining how Open would behave
356+
purl, err := nurl.Parse(tt.url)
357+
if err != nil {
358+
if !tt.shouldError {
359+
t.Fatalf("parseURL failed: %v", err)
360+
}
361+
return
362+
}
363+
364+
// Test multi-statement parameter parsing
365+
multiStatementEnabled := false
366+
multiStatementMaxSize := DefaultMultiStatementMaxSize
367+
368+
if s := purl.Query().Get("x-multi-statement"); len(s) > 0 {
369+
multiStatementEnabled, err = strconv.ParseBool(s)
370+
if err != nil {
371+
if tt.shouldError {
372+
return // Expected error
373+
}
374+
t.Fatalf("unable to parse option x-multi-statement: %v", err)
375+
}
376+
}
377+
378+
if s := purl.Query().Get("x-multi-statement-max-size"); len(s) > 0 {
379+
multiStatementMaxSize, err = strconv.Atoi(s)
380+
if err != nil {
381+
if tt.shouldError {
382+
return // Expected error
383+
}
384+
t.Fatalf("unable to parse x-multi-statement-max-size: %v", err)
385+
}
386+
if multiStatementMaxSize <= 0 {
387+
multiStatementMaxSize = DefaultMultiStatementMaxSize
388+
}
389+
}
390+
391+
if tt.shouldError {
392+
t.Fatalf("expected error but got none")
393+
}
394+
395+
if multiStatementEnabled != tt.expectedMultiStmt {
396+
t.Errorf("expected MultiStatementEnabled to be %v, got %v", tt.expectedMultiStmt, multiStatementEnabled)
397+
}
398+
399+
if multiStatementMaxSize != tt.expectedMultiStmtSize {
400+
t.Errorf("expected MultiStatementMaxSize to be %d, got %d", tt.expectedMultiStmtSize, multiStatementMaxSize)
401+
}
402+
})
403+
}
404+
}
405+
406+
func TestMultiStatementParsing(t *testing.T) {
407+
tests := []struct {
408+
name string
409+
input string
410+
expected []string
411+
}{
412+
{
413+
name: "single statement",
414+
input: "CREATE TABLE test (id INTEGER);",
415+
expected: []string{"CREATE TABLE test (id INTEGER);"},
416+
},
417+
{
418+
name: "multiple statements",
419+
input: "CREATE TABLE foo (id INTEGER); CREATE TABLE bar (name VARCHAR(50));",
420+
expected: []string{"CREATE TABLE foo (id INTEGER);", "CREATE TABLE bar (name VARCHAR(50));"},
421+
},
422+
{
423+
name: "statements with whitespace",
424+
input: "CREATE TABLE foo (id INTEGER);\n\n CREATE TABLE bar (name VARCHAR(50)); \n",
425+
expected: []string{"CREATE TABLE foo (id INTEGER);", "CREATE TABLE bar (name VARCHAR(50));"},
426+
},
427+
{
428+
name: "empty statements ignored",
429+
input: "CREATE TABLE foo (id INTEGER);;CREATE TABLE bar (name VARCHAR(50));",
430+
expected: []string{"CREATE TABLE foo (id INTEGER);", "CREATE TABLE bar (name VARCHAR(50));"},
431+
},
432+
}
433+
434+
for _, tt := range tests {
435+
t.Run(tt.name, func(t *testing.T) {
436+
var statements []string
437+
reader := strings.NewReader(tt.input)
438+
439+
// Simulate what the Firebird driver does with multi-statement parsing
440+
err := multistmt.Parse(reader, multiStmtDelimiter, DefaultMultiStatementMaxSize, func(stmt []byte) bool {
441+
query := strings.TrimSpace(string(stmt))
442+
// Skip empty statements and standalone semicolons
443+
if len(query) > 0 && query != ";" {
444+
statements = append(statements, query)
445+
}
446+
return true // continue parsing
447+
})
448+
449+
if err != nil {
450+
t.Fatalf("parsing failed: %v", err)
451+
}
452+
453+
if len(statements) != len(tt.expected) {
454+
t.Fatalf("expected %d statements, got %d: %v", len(tt.expected), len(statements), statements)
455+
}
456+
457+
for i, expected := range tt.expected {
458+
if statements[i] != expected {
459+
t.Errorf("statement %d: expected %q, got %q", i, expected, statements[i])
460+
}
461+
}
462+
})
463+
}
464+
}

0 commit comments

Comments
 (0)