Skip to content

Commit a8737a7

Browse files
committed
fix motherduck duckdb model executor (#6943)
1 parent fdddcb3 commit a8737a7

File tree

7 files changed

+33
-18
lines changed

7 files changed

+33
-18
lines changed

runtime/drivers/duckdb/model_executor_motherduck_self.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,8 @@ func (e *mdToSelfExecutor) Execute(ctx context.Context, opts *drivers.ModelExecu
7676

7777
clone := *opts
7878
m := &ModelInputProperties{
79-
SQL: inputProps.SQL,
80-
PreExec: fmt.Sprintf("INSTALL 'motherduck'; LOAD 'motherduck'; SET motherduck_token=%s; ATTACH %s;", safeSQLString(token), safeSQLString(inputProps.resolveDSN())),
79+
SQL: inputProps.SQL,
80+
InitQueries: fmt.Sprintf("INSTALL 'motherduck'; LOAD 'motherduck'; SET motherduck_token=%s; ATTACH %s;", safeSQLString(token), safeSQLString(inputProps.resolveDSN())),
8181
}
8282
var props map[string]any
8383
err = mapstructure.Decode(m, &props)

runtime/drivers/duckdb/model_executor_self.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,9 @@ func (e *selfToSelfExecutor) Execute(ctx context.Context, opts *drivers.ModelExe
135135
BeforeCreate: inputProps.PreExec,
136136
AfterCreate: inputProps.PostExec,
137137
}
138+
if inputProps.InitQueries != "" {
139+
createTableOpts.InitQueries = []string{inputProps.InitQueries}
140+
}
138141
res, err := olap.CreateTableAsSelect(ctx, stagingTableName, inputProps.SQL, createTableOpts)
139142
if err != nil {
140143
_ = olap.DropTable(ctx, stagingTableName)
@@ -160,6 +163,9 @@ func (e *selfToSelfExecutor) Execute(ctx context.Context, opts *drivers.ModelExe
160163
Strategy: outputProps.IncrementalStrategy,
161164
UniqueKey: outputProps.UniqueKey,
162165
}
166+
if inputProps.InitQueries != "" {
167+
insertTableOpts.InitQueries = []string{inputProps.InitQueries}
168+
}
163169
res, err := olap.InsertTableAsSelect(ctx, tableName, inputProps.SQL, insertTableOpts)
164170
if err != nil {
165171
return nil, fmt.Errorf("failed to incrementally insert into table: %w", err)

runtime/drivers/duckdb/model_manager.go

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,12 @@ import (
1010
)
1111

1212
type ModelInputProperties struct {
13-
SQL string `mapstructure:"sql"`
14-
Args []any `mapstructure:"args"`
15-
PreExec string `mapstructure:"pre_exec"`
16-
PostExec string `mapstructure:"post_exec"`
13+
SQL string `mapstructure:"sql"`
14+
Args []any `mapstructure:"args"`
15+
// InitQueries are queries that are run during initialisation of write handle before model is created of any pre_exec queries are run.
16+
InitQueries string `mapstructure:"init_queries"`
17+
PreExec string `mapstructure:"pre_exec"`
18+
PostExec string `mapstructure:"post_exec"`
1719
// Database is set if sql is to be run against an external database
1820
Database string `mapstructure:"db"`
1921
}

runtime/drivers/duckdb/olap.go

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ func (c *connection) AddTableColumn(ctx context.Context, tableName, columnName,
195195
_ = release()
196196
}()
197197

198-
_, err = db.MutateTable(ctx, tableName, func(ctx context.Context, conn *sqlx.Conn) error {
198+
_, err = db.MutateTable(ctx, tableName, nil, func(ctx context.Context, conn *sqlx.Conn) error {
199199
_, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s %s", safeSQLName(tableName), safeSQLName(columnName), typ))
200200
return err
201201
})
@@ -212,7 +212,7 @@ func (c *connection) AlterTableColumn(ctx context.Context, tableName, columnName
212212
_ = release()
213213
}()
214214

215-
_, err = db.MutateTable(ctx, tableName, func(ctx context.Context, conn *sqlx.Conn) error {
215+
_, err = db.MutateTable(ctx, tableName, nil, func(ctx context.Context, conn *sqlx.Conn) error {
216216
_, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s ALTER %s TYPE %s", safeSQLName(tableName), safeSQLName(columnName), newType))
217217
return err
218218
})
@@ -242,7 +242,12 @@ func (c *connection) CreateTableAsSelect(ctx context.Context, name, sql string,
242242
return err
243243
}
244244
}
245-
res, err := db.CreateTableAsSelect(ctx, name, sql, &rduckdb.CreateTableOptions{View: opts.View, BeforeCreateFn: beforeCreateFn, AfterCreateFn: afterCreateFn})
245+
res, err := db.CreateTableAsSelect(ctx, name, sql, &rduckdb.CreateTableOptions{
246+
View: opts.View,
247+
InitQueries: opts.InitQueries,
248+
BeforeCreateFn: beforeCreateFn,
249+
AfterCreateFn: afterCreateFn,
250+
})
246251
if err != nil {
247252
return nil, c.checkErr(err)
248253
}
@@ -266,7 +271,7 @@ func (c *connection) InsertTableAsSelect(ctx context.Context, name, sql string,
266271
}
267272

268273
if opts.Strategy == drivers.IncrementalStrategyAppend {
269-
res, err := db.MutateTable(ctx, name, func(ctx context.Context, conn *sqlx.Conn) error {
274+
res, err := db.MutateTable(ctx, name, opts.InitQueries, func(ctx context.Context, conn *sqlx.Conn) error {
270275
if opts.BeforeInsert != "" {
271276
_, err := conn.ExecContext(ctx, opts.BeforeInsert)
272277
if err != nil {
@@ -289,7 +294,7 @@ func (c *connection) InsertTableAsSelect(ctx context.Context, name, sql string,
289294
}
290295

291296
if opts.Strategy == drivers.IncrementalStrategyMerge {
292-
res, err := db.MutateTable(ctx, name, func(ctx context.Context, conn *sqlx.Conn) (mutate error) {
297+
res, err := db.MutateTable(ctx, name, opts.InitQueries, func(ctx context.Context, conn *sqlx.Conn) (mutate error) {
293298
// Execute the pre-init SQL first
294299
if opts.BeforeInsert != "" {
295300
_, err := conn.ExecContext(ctx, opts.BeforeInsert)

runtime/drivers/olap.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ type WithConnectionFunc func(wrappedCtx context.Context, ensuredCtx context.Cont
3333

3434
type CreateTableOptions struct {
3535
View bool
36+
InitQueries []string
3637
BeforeCreate string
3738
AfterCreate string
3839
TableOpts map[string]any
@@ -45,6 +46,7 @@ type TableWriteMetrics struct {
4546
}
4647

4748
type InsertTableOptions struct {
49+
InitQueries []string
4850
BeforeInsert string
4951
AfterInsert string
5052
ByName bool

runtime/pkg/rduckdb/db.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ type DB interface {
4949
CreateTableAsSelect(ctx context.Context, name string, sql string, opts *CreateTableOptions) (*TableWriteMetrics, error)
5050

5151
// MutateTable allows mutating a table in the database by calling the mutateFn.
52-
MutateTable(ctx context.Context, name string, mutateFn func(ctx context.Context, conn *sqlx.Conn) error) (*TableWriteMetrics, error)
52+
MutateTable(ctx context.Context, name string, initQueries []string, mutateFn func(ctx context.Context, conn *sqlx.Conn) error) (*TableWriteMetrics, error)
5353

5454
// DropTable removes a table from the database.
5555
DropTable(ctx context.Context, name string) error
@@ -455,7 +455,7 @@ func (d *db) CreateTableAsSelect(ctx context.Context, name, query string, opts *
455455
return &TableWriteMetrics{Duration: duration}, nil
456456
}
457457

458-
func (d *db) MutateTable(ctx context.Context, name string, mutateFn func(ctx context.Context, conn *sqlx.Conn) error) (*TableWriteMetrics, error) {
458+
func (d *db) MutateTable(ctx context.Context, name string, initQueries []string, mutateFn func(ctx context.Context, conn *sqlx.Conn) error) (*TableWriteMetrics, error) {
459459
d.logger.Debug("mutate table", zap.String("name", name), observability.ZapCtx(ctx))
460460
err := d.writeSem.Acquire(ctx, 1)
461461
if err != nil {
@@ -488,7 +488,7 @@ func (d *db) MutateTable(ctx context.Context, name string, mutateFn func(ctx con
488488

489489
// acquire write connection
490490
// need to ignore attaching table since it is already present in the db file
491-
conn, release, err := d.acquireWriteConn(ctx, d.localDBPath(name, newVersion), name, nil, false)
491+
conn, release, err := d.acquireWriteConn(ctx, d.localDBPath(name, newVersion), name, initQueries, false)
492492
if err != nil {
493493
_ = os.RemoveAll(newDir)
494494
return nil, err

runtime/pkg/rduckdb/db_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ func TestDB(t *testing.T) {
4242
require.Error(t, err)
4343

4444
// insert into table
45-
_, err = db.MutateTable(ctx, "test2", func(ctx context.Context, conn *sqlx.Conn) error {
45+
_, err = db.MutateTable(ctx, "test2", nil, func(ctx context.Context, conn *sqlx.Conn) error {
4646
_, err := conn.ExecContext(ctx, "INSERT INTO test2 (id, country) VALUES (2, 'USA')")
4747
return err
4848
})
@@ -58,7 +58,7 @@ func TestDB(t *testing.T) {
5858
require.NoError(t, release())
5959

6060
// Add column
61-
_, err = db.MutateTable(ctx, "test2", func(ctx context.Context, conn *sqlx.Conn) error {
61+
_, err = db.MutateTable(ctx, "test2", nil, func(ctx context.Context, conn *sqlx.Conn) error {
6262
_, err := conn.ExecContext(ctx, "ALTER TABLE test2 ADD COLUMN city TEXT")
6363
return err
6464
})
@@ -157,7 +157,7 @@ func TestMutateTable(t *testing.T) {
157157
require.NoError(t, err)
158158

159159
// insert into table
160-
_, err = db.MutateTable(ctx, "test", func(ctx context.Context, conn *sqlx.Conn) error {
160+
_, err = db.MutateTable(ctx, "test", nil, func(ctx context.Context, conn *sqlx.Conn) error {
161161
_, err := conn.ExecContext(ctx, "INSERT INTO test (id, city) VALUES (2, 'NY')")
162162
return err
163163
})
@@ -170,7 +170,7 @@ func TestMutateTable(t *testing.T) {
170170
testDone := make(chan struct{})
171171

172172
go func() {
173-
db.MutateTable(ctx, "test", func(ctx context.Context, conn *sqlx.Conn) error {
173+
db.MutateTable(ctx, "test", nil, func(ctx context.Context, conn *sqlx.Conn) error {
174174
_, err := conn.ExecContext(ctx, "ALTER TABLE test ADD COLUMN country TEXT")
175175
require.NoError(t, err)
176176
_, err = conn.ExecContext(ctx, "UPDATE test SET country = 'USA' WHERE id = 2")

0 commit comments

Comments
 (0)