Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix motherduck duckdb model executor #6943

Merged
merged 1 commit into from
Mar 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions runtime/drivers/duckdb/model_executor_motherduck_self.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,8 @@ func (e *mdToSelfExecutor) Execute(ctx context.Context, opts *drivers.ModelExecu

clone := *opts
m := &ModelInputProperties{
SQL: inputProps.SQL,
PreExec: fmt.Sprintf("INSTALL 'motherduck'; LOAD 'motherduck'; SET motherduck_token=%s; ATTACH %s;", safeSQLString(token), safeSQLString(inputProps.resolveDSN())),
SQL: inputProps.SQL,
InitQueries: fmt.Sprintf("INSTALL 'motherduck'; LOAD 'motherduck'; SET motherduck_token=%s; ATTACH %s;", safeSQLString(token), safeSQLString(inputProps.resolveDSN())),
}
var props map[string]any
err = mapstructure.Decode(m, &props)
Expand Down
6 changes: 6 additions & 0 deletions runtime/drivers/duckdb/model_executor_self.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,9 @@ func (e *selfToSelfExecutor) Execute(ctx context.Context, opts *drivers.ModelExe
BeforeCreate: inputProps.PreExec,
AfterCreate: inputProps.PostExec,
}
if inputProps.InitQueries != "" {
createTableOpts.InitQueries = []string{inputProps.InitQueries}
}
res, err := olap.CreateTableAsSelect(ctx, stagingTableName, inputProps.SQL, createTableOpts)
if err != nil {
_ = olap.DropTable(ctx, stagingTableName)
Expand All @@ -160,6 +163,9 @@ func (e *selfToSelfExecutor) Execute(ctx context.Context, opts *drivers.ModelExe
Strategy: outputProps.IncrementalStrategy,
UniqueKey: outputProps.UniqueKey,
}
if inputProps.InitQueries != "" {
insertTableOpts.InitQueries = []string{inputProps.InitQueries}
}
res, err := olap.InsertTableAsSelect(ctx, tableName, inputProps.SQL, insertTableOpts)
if err != nil {
return nil, fmt.Errorf("failed to incrementally insert into table: %w", err)
Expand Down
10 changes: 6 additions & 4 deletions runtime/drivers/duckdb/model_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,12 @@ import (
)

type ModelInputProperties struct {
SQL string `mapstructure:"sql"`
Args []any `mapstructure:"args"`
PreExec string `mapstructure:"pre_exec"`
PostExec string `mapstructure:"post_exec"`
SQL string `mapstructure:"sql"`
Args []any `mapstructure:"args"`
// InitQueries are queries that are run during initialisation of write handle before model is created of any pre_exec queries are run.
InitQueries string `mapstructure:"init_queries"`
PreExec string `mapstructure:"pre_exec"`
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so is pre_exec run before run every query?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No only before the model query is run.

PostExec string `mapstructure:"post_exec"`
// Database is set if sql is to be run against an external database
Database string `mapstructure:"db"`
}
Expand Down
15 changes: 10 additions & 5 deletions runtime/drivers/duckdb/olap.go
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ func (c *connection) AddTableColumn(ctx context.Context, tableName, columnName,
_ = release()
}()

_, err = db.MutateTable(ctx, tableName, func(ctx context.Context, conn *sqlx.Conn) error {
_, err = db.MutateTable(ctx, tableName, nil, func(ctx context.Context, conn *sqlx.Conn) error {
_, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s %s", safeSQLName(tableName), safeSQLName(columnName), typ))
return err
})
Expand All @@ -212,7 +212,7 @@ func (c *connection) AlterTableColumn(ctx context.Context, tableName, columnName
_ = release()
}()

_, err = db.MutateTable(ctx, tableName, func(ctx context.Context, conn *sqlx.Conn) error {
_, err = db.MutateTable(ctx, tableName, nil, func(ctx context.Context, conn *sqlx.Conn) error {
_, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s ALTER %s TYPE %s", safeSQLName(tableName), safeSQLName(columnName), newType))
return err
})
Expand Down Expand Up @@ -242,7 +242,12 @@ func (c *connection) CreateTableAsSelect(ctx context.Context, name, sql string,
return err
}
}
res, err := db.CreateTableAsSelect(ctx, name, sql, &rduckdb.CreateTableOptions{View: opts.View, BeforeCreateFn: beforeCreateFn, AfterCreateFn: afterCreateFn})
res, err := db.CreateTableAsSelect(ctx, name, sql, &rduckdb.CreateTableOptions{
View: opts.View,
InitQueries: opts.InitQueries,
BeforeCreateFn: beforeCreateFn,
AfterCreateFn: afterCreateFn,
})
if err != nil {
return nil, c.checkErr(err)
}
Expand All @@ -266,7 +271,7 @@ func (c *connection) InsertTableAsSelect(ctx context.Context, name, sql string,
}

if opts.Strategy == drivers.IncrementalStrategyAppend {
res, err := db.MutateTable(ctx, name, func(ctx context.Context, conn *sqlx.Conn) error {
res, err := db.MutateTable(ctx, name, opts.InitQueries, func(ctx context.Context, conn *sqlx.Conn) error {
if opts.BeforeInsert != "" {
_, err := conn.ExecContext(ctx, opts.BeforeInsert)
if err != nil {
Expand All @@ -289,7 +294,7 @@ func (c *connection) InsertTableAsSelect(ctx context.Context, name, sql string,
}

if opts.Strategy == drivers.IncrementalStrategyMerge {
res, err := db.MutateTable(ctx, name, func(ctx context.Context, conn *sqlx.Conn) (mutate error) {
res, err := db.MutateTable(ctx, name, opts.InitQueries, func(ctx context.Context, conn *sqlx.Conn) (mutate error) {
// Execute the pre-init SQL first
if opts.BeforeInsert != "" {
_, err := conn.ExecContext(ctx, opts.BeforeInsert)
Expand Down
2 changes: 2 additions & 0 deletions runtime/drivers/olap.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ type WithConnectionFunc func(wrappedCtx context.Context, ensuredCtx context.Cont

type CreateTableOptions struct {
View bool
InitQueries []string
BeforeCreate string
AfterCreate string
TableOpts map[string]any
Expand All @@ -45,6 +46,7 @@ type TableWriteMetrics struct {
}

type InsertTableOptions struct {
InitQueries []string
BeforeInsert string
AfterInsert string
ByName bool
Expand Down
6 changes: 3 additions & 3 deletions runtime/pkg/rduckdb/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ type DB interface {
CreateTableAsSelect(ctx context.Context, name string, sql string, opts *CreateTableOptions) (*TableWriteMetrics, error)

// MutateTable allows mutating a table in the database by calling the mutateFn.
MutateTable(ctx context.Context, name string, mutateFn func(ctx context.Context, conn *sqlx.Conn) error) (*TableWriteMetrics, error)
MutateTable(ctx context.Context, name string, initQueries []string, mutateFn func(ctx context.Context, conn *sqlx.Conn) error) (*TableWriteMetrics, error)

// DropTable removes a table from the database.
DropTable(ctx context.Context, name string) error
Expand Down Expand Up @@ -455,7 +455,7 @@ func (d *db) CreateTableAsSelect(ctx context.Context, name, query string, opts *
return &TableWriteMetrics{Duration: duration}, nil
}

func (d *db) MutateTable(ctx context.Context, name string, mutateFn func(ctx context.Context, conn *sqlx.Conn) error) (*TableWriteMetrics, error) {
func (d *db) MutateTable(ctx context.Context, name string, initQueries []string, mutateFn func(ctx context.Context, conn *sqlx.Conn) error) (*TableWriteMetrics, error) {
d.logger.Debug("mutate table", zap.String("name", name), observability.ZapCtx(ctx))
err := d.writeSem.Acquire(ctx, 1)
if err != nil {
Expand Down Expand Up @@ -488,7 +488,7 @@ func (d *db) MutateTable(ctx context.Context, name string, mutateFn func(ctx con

// acquire write connection
// need to ignore attaching table since it is already present in the db file
conn, release, err := d.acquireWriteConn(ctx, d.localDBPath(name, newVersion), name, nil, false)
conn, release, err := d.acquireWriteConn(ctx, d.localDBPath(name, newVersion), name, initQueries, false)
if err != nil {
_ = os.RemoveAll(newDir)
return nil, err
Expand Down
8 changes: 4 additions & 4 deletions runtime/pkg/rduckdb/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func TestDB(t *testing.T) {
require.Error(t, err)

// insert into table
_, err = db.MutateTable(ctx, "test2", func(ctx context.Context, conn *sqlx.Conn) error {
_, err = db.MutateTable(ctx, "test2", nil, func(ctx context.Context, conn *sqlx.Conn) error {
_, err := conn.ExecContext(ctx, "INSERT INTO test2 (id, country) VALUES (2, 'USA')")
return err
})
Expand All @@ -58,7 +58,7 @@ func TestDB(t *testing.T) {
require.NoError(t, release())

// Add column
_, err = db.MutateTable(ctx, "test2", func(ctx context.Context, conn *sqlx.Conn) error {
_, err = db.MutateTable(ctx, "test2", nil, func(ctx context.Context, conn *sqlx.Conn) error {
_, err := conn.ExecContext(ctx, "ALTER TABLE test2 ADD COLUMN city TEXT")
return err
})
Expand Down Expand Up @@ -157,7 +157,7 @@ func TestMutateTable(t *testing.T) {
require.NoError(t, err)

// insert into table
_, err = db.MutateTable(ctx, "test", func(ctx context.Context, conn *sqlx.Conn) error {
_, err = db.MutateTable(ctx, "test", nil, func(ctx context.Context, conn *sqlx.Conn) error {
_, err := conn.ExecContext(ctx, "INSERT INTO test (id, city) VALUES (2, 'NY')")
return err
})
Expand All @@ -170,7 +170,7 @@ func TestMutateTable(t *testing.T) {
testDone := make(chan struct{})

go func() {
db.MutateTable(ctx, "test", func(ctx context.Context, conn *sqlx.Conn) error {
db.MutateTable(ctx, "test", nil, func(ctx context.Context, conn *sqlx.Conn) error {
_, err := conn.ExecContext(ctx, "ALTER TABLE test ADD COLUMN country TEXT")
require.NoError(t, err)
_, err = conn.ExecContext(ctx, "UPDATE test SET country = 'USA' WHERE id = 2")
Expand Down
Loading