@@ -5,15 +5,17 @@ import (
5
5
"database/sql"
6
6
sqldriver "database/sql/driver"
7
7
"fmt"
8
- "log"
9
-
10
- "github.com/golang-migrate/migrate/v4"
11
8
"io"
9
+ "log"
10
+ nurl "net/url"
11
+ "strconv"
12
12
"strings"
13
13
"testing"
14
14
15
15
"github.com/dhui/dktest"
16
16
17
+ "github.com/golang-migrate/migrate/v4"
18
+ "github.com/golang-migrate/migrate/v4/database/multistmt"
17
19
dt "github.com/golang-migrate/migrate/v4/database/testing"
18
20
"github.com/golang-migrate/migrate/v4/dktesting"
19
21
_ "github.com/golang-migrate/migrate/v4/source/file"
@@ -126,6 +128,75 @@ func TestMigrate(t *testing.T) {
126
128
})
127
129
}
128
130
131
+ func TestMultipleStatements (t * testing.T ) {
132
+ dktesting .ParallelTest (t , specs , func (t * testing.T , c dktest.ContainerInfo ) {
133
+ ip , port , err := c .FirstPort ()
134
+ if err != nil {
135
+ t .Fatal (err )
136
+ }
137
+
138
+ addr := fbConnectionString (ip , port )
139
+ p := & Firebird {}
140
+ d , err := p .Open (addr )
141
+ if err != nil {
142
+ t .Fatal (err )
143
+ }
144
+ defer func () {
145
+ if err := d .Close (); err != nil {
146
+ t .Error (err )
147
+ }
148
+ }()
149
+ if err := d .Run (strings .NewReader ("CREATE TABLE foo (foo VARCHAR(40)); CREATE TABLE bar (bar VARCHAR(40));" )); err != nil {
150
+ t .Fatalf ("expected err to be nil, got %v" , err )
151
+ }
152
+
153
+ // make sure second table exists
154
+ var exists bool
155
+ query := "SELECT CASE WHEN EXISTS (SELECT 1 FROM RDB$RELATIONS WHERE RDB$RELATION_NAME = 'BAR') THEN 1 ELSE 0 END FROM RDB$DATABASE"
156
+ if err := d .(* Firebird ).conn .QueryRowContext (context .Background (), query ).Scan (& exists ); err != nil {
157
+ t .Fatal (err )
158
+ }
159
+ if ! exists {
160
+ t .Fatalf ("expected table bar to exist" )
161
+ }
162
+ })
163
+ }
164
+
165
+ func TestMultipleStatementsInMultiStatementMode (t * testing.T ) {
166
+ dktesting .ParallelTest (t , specs , func (t * testing.T , c dktest.ContainerInfo ) {
167
+ ip , port , err := c .FirstPort ()
168
+ if err != nil {
169
+ t .Fatal (err )
170
+ }
171
+
172
+ addr := fbConnectionString (ip , port ) + "?x-multi-statement=true"
173
+ p := & Firebird {}
174
+ d , err := p .Open (addr )
175
+ if err != nil {
176
+ t .Fatal (err )
177
+ }
178
+ defer func () {
179
+ if err := d .Close (); err != nil {
180
+ t .Error (err )
181
+ }
182
+ }()
183
+ // Use CREATE INDEX instead of CONCURRENTLY (Firebird doesn't support CREATE INDEX CONCURRENTLY)
184
+ if err := d .Run (strings .NewReader ("CREATE TABLE foo (foo VARCHAR(40)); CREATE INDEX idx_foo ON foo (foo);" )); err != nil {
185
+ t .Fatalf ("expected err to be nil, got %v" , err )
186
+ }
187
+
188
+ // make sure created index exists
189
+ var exists bool
190
+ query := "SELECT CASE WHEN EXISTS (SELECT 1 FROM RDB$INDICES WHERE RDB$INDEX_NAME = 'IDX_FOO') THEN 1 ELSE 0 END FROM RDB$DATABASE"
191
+ if err := d .(* Firebird ).conn .QueryRowContext (context .Background (), query ).Scan (& exists ); err != nil {
192
+ t .Fatal (err )
193
+ }
194
+ if ! exists {
195
+ t .Fatalf ("expected index idx_foo to exist" )
196
+ }
197
+ })
198
+ }
199
+
129
200
func TestErrorParsing (t * testing.T ) {
130
201
dktesting .ParallelTest (t , specs , func (t * testing.T , c dktest.ContainerInfo ) {
131
202
ip , port , err := c .FirstPort ()
@@ -225,3 +296,169 @@ func Test_Lock(t *testing.T) {
225
296
}
226
297
})
227
298
}
299
+
300
+ func TestMultiStatementURLParsing (t * testing.T ) {
301
+ tests := []struct {
302
+ name string
303
+ url string
304
+ expectedMultiStmt bool
305
+ expectedMultiStmtSize int
306
+ shouldError bool
307
+ }{
308
+ {
309
+ name : "multi-statement enabled" ,
310
+ url : "firebird://user:pass@localhost:3050//path/to/db.fdb?x-multi-statement=true" ,
311
+ expectedMultiStmt : true ,
312
+ expectedMultiStmtSize : DefaultMultiStatementMaxSize ,
313
+ shouldError : false ,
314
+ },
315
+ {
316
+ name : "multi-statement disabled" ,
317
+ url : "firebird://user:pass@localhost:3050//path/to/db.fdb?x-multi-statement=false" ,
318
+ expectedMultiStmt : false ,
319
+ expectedMultiStmtSize : DefaultMultiStatementMaxSize ,
320
+ shouldError : false ,
321
+ },
322
+ {
323
+ name : "multi-statement with custom size" ,
324
+ url : "firebird://user:pass@localhost:3050//path/to/db.fdb?x-multi-statement=true&x-multi-statement-max-size=5242880" ,
325
+ expectedMultiStmt : true ,
326
+ expectedMultiStmtSize : 5242880 ,
327
+ shouldError : false ,
328
+ },
329
+ {
330
+ name : "multi-statement with invalid size falls back to default" ,
331
+ url : "firebird://user:pass@localhost:3050//path/to/db.fdb?x-multi-statement=true&x-multi-statement-max-size=0" ,
332
+ expectedMultiStmt : true ,
333
+ expectedMultiStmtSize : DefaultMultiStatementMaxSize ,
334
+ shouldError : false ,
335
+ },
336
+ {
337
+ name : "invalid boolean value should error" ,
338
+ url : "firebird://user:pass@localhost:3050//path/to/db.fdb?x-multi-statement=invalid" ,
339
+ expectedMultiStmt : false ,
340
+ expectedMultiStmtSize : DefaultMultiStatementMaxSize ,
341
+ shouldError : true ,
342
+ },
343
+ {
344
+ name : "invalid size value should error" ,
345
+ url : "firebird://user:pass@localhost:3050//path/to/db.fdb?x-multi-statement=true&x-multi-statement-max-size=invalid" ,
346
+ expectedMultiStmt : true ,
347
+ expectedMultiStmtSize : DefaultMultiStatementMaxSize ,
348
+ shouldError : true ,
349
+ },
350
+ }
351
+
352
+ for _ , tt := range tests {
353
+ t .Run (tt .name , func (t * testing.T ) {
354
+ // We can't actually open a database connection without Docker,
355
+ // but we can test the URL parsing logic by examining how Open would behave
356
+ purl , err := nurl .Parse (tt .url )
357
+ if err != nil {
358
+ if ! tt .shouldError {
359
+ t .Fatalf ("parseURL failed: %v" , err )
360
+ }
361
+ return
362
+ }
363
+
364
+ // Test multi-statement parameter parsing
365
+ multiStatementEnabled := false
366
+ multiStatementMaxSize := DefaultMultiStatementMaxSize
367
+
368
+ if s := purl .Query ().Get ("x-multi-statement" ); len (s ) > 0 {
369
+ multiStatementEnabled , err = strconv .ParseBool (s )
370
+ if err != nil {
371
+ if tt .shouldError {
372
+ return // Expected error
373
+ }
374
+ t .Fatalf ("unable to parse option x-multi-statement: %v" , err )
375
+ }
376
+ }
377
+
378
+ if s := purl .Query ().Get ("x-multi-statement-max-size" ); len (s ) > 0 {
379
+ multiStatementMaxSize , err = strconv .Atoi (s )
380
+ if err != nil {
381
+ if tt .shouldError {
382
+ return // Expected error
383
+ }
384
+ t .Fatalf ("unable to parse x-multi-statement-max-size: %v" , err )
385
+ }
386
+ if multiStatementMaxSize <= 0 {
387
+ multiStatementMaxSize = DefaultMultiStatementMaxSize
388
+ }
389
+ }
390
+
391
+ if tt .shouldError {
392
+ t .Fatalf ("expected error but got none" )
393
+ }
394
+
395
+ if multiStatementEnabled != tt .expectedMultiStmt {
396
+ t .Errorf ("expected MultiStatementEnabled to be %v, got %v" , tt .expectedMultiStmt , multiStatementEnabled )
397
+ }
398
+
399
+ if multiStatementMaxSize != tt .expectedMultiStmtSize {
400
+ t .Errorf ("expected MultiStatementMaxSize to be %d, got %d" , tt .expectedMultiStmtSize , multiStatementMaxSize )
401
+ }
402
+ })
403
+ }
404
+ }
405
+
406
+ func TestMultiStatementParsing (t * testing.T ) {
407
+ tests := []struct {
408
+ name string
409
+ input string
410
+ expected []string
411
+ }{
412
+ {
413
+ name : "single statement" ,
414
+ input : "CREATE TABLE test (id INTEGER);" ,
415
+ expected : []string {"CREATE TABLE test (id INTEGER);" },
416
+ },
417
+ {
418
+ name : "multiple statements" ,
419
+ input : "CREATE TABLE foo (id INTEGER); CREATE TABLE bar (name VARCHAR(50));" ,
420
+ expected : []string {"CREATE TABLE foo (id INTEGER);" , "CREATE TABLE bar (name VARCHAR(50));" },
421
+ },
422
+ {
423
+ name : "statements with whitespace" ,
424
+ input : "CREATE TABLE foo (id INTEGER);\n \n CREATE TABLE bar (name VARCHAR(50)); \n " ,
425
+ expected : []string {"CREATE TABLE foo (id INTEGER);" , "CREATE TABLE bar (name VARCHAR(50));" },
426
+ },
427
+ {
428
+ name : "empty statements ignored" ,
429
+ input : "CREATE TABLE foo (id INTEGER);;CREATE TABLE bar (name VARCHAR(50));" ,
430
+ expected : []string {"CREATE TABLE foo (id INTEGER);" , "CREATE TABLE bar (name VARCHAR(50));" },
431
+ },
432
+ }
433
+
434
+ for _ , tt := range tests {
435
+ t .Run (tt .name , func (t * testing.T ) {
436
+ var statements []string
437
+ reader := strings .NewReader (tt .input )
438
+
439
+ // Simulate what the Firebird driver does with multi-statement parsing
440
+ err := multistmt .Parse (reader , multiStmtDelimiter , DefaultMultiStatementMaxSize , func (stmt []byte ) bool {
441
+ query := strings .TrimSpace (string (stmt ))
442
+ // Skip empty statements and standalone semicolons
443
+ if len (query ) > 0 && query != ";" {
444
+ statements = append (statements , query )
445
+ }
446
+ return true // continue parsing
447
+ })
448
+
449
+ if err != nil {
450
+ t .Fatalf ("parsing failed: %v" , err )
451
+ }
452
+
453
+ if len (statements ) != len (tt .expected ) {
454
+ t .Fatalf ("expected %d statements, got %d: %v" , len (tt .expected ), len (statements ), statements )
455
+ }
456
+
457
+ for i , expected := range tt .expected {
458
+ if statements [i ] != expected {
459
+ t .Errorf ("statement %d: expected %q, got %q" , i , expected , statements [i ])
460
+ }
461
+ }
462
+ })
463
+ }
464
+ }
0 commit comments