Runtime: Fix temp_dir usage for local (#5160)
* temp_dir fix for local

* fix mapstructure decode

* do not enable external storage when separate file passed

* Apply suggestions from code review

Co-authored-by: Benjamin Egelund-Müller <[email protected]>

---------

Co-authored-by: Benjamin Egelund-Müller <[email protected]>
k-anshul and begelundmuller committed Jun 28, 2024
1 parent 349403e commit 2f3b78a
Showing 13 changed files with 80 additions and 78 deletions.
47 changes: 13 additions & 34 deletions cli/pkg/local/app.go
@@ -7,12 +7,10 @@ import (
"fmt"
"net/http"
"os"
"path"
"path/filepath"
"strconv"
"time"

"github.com/bmatcuk/doublestar/v4"
"github.com/c2h5oh/datasize"
"github.com/rilldata/rill/cli/pkg/browser"
"github.com/rilldata/rill/cli/pkg/cmdutil"
@@ -203,26 +201,22 @@ func NewApp(ctx context.Context, opts *AppOptions) (*App, error) {

// If the OLAP is the default OLAP (DuckDB in stage.db), we make it relative to the project directory (not the working directory)
defaultOLAP := false
olapDSN := opts.OlapDSN
olapCfg := make(map[string]string)
if opts.OlapDriver == DefaultOLAPDriver && olapDSN == DefaultOLAPDSN {
if opts.OlapDriver == DefaultOLAPDriver && opts.OlapDSN == DefaultOLAPDSN {
defaultOLAP = true
olapDSN = path.Join(dbDirPath, olapDSN)
// Set path which overrides the duckdb's default behaviour to store duckdb data in data_dir/<instance_id>/<connector> directory which is not backward compatible
olapCfg["path"] = olapDSN
val, err := isExternalStorageEnabled(dbDirPath, vars)
val, err := isExternalStorageEnabled(vars)
if err != nil {
return nil, err
}

olapCfg["external_table_storage"] = strconv.FormatBool(val)
}

// Set default DuckDB pool size to 4
olapCfg["dsn"] = olapDSN
if opts.OlapDriver == "duckdb" {
// Set default DuckDB pool size to 4
olapCfg["pool_size"] = "4"
if !defaultOLAP {
// dsn is automatically computed by duckdb driver so we set only when non default dsn is passed
olapCfg["dsn"] = opts.OlapDSN
olapCfg["error_on_incompatible_version"] = "true"
}
}
@@ -621,27 +615,12 @@ func (s skipFieldZapEncoder) AddString(key, val string) {
}

// isExternalStorageEnabled determines if external storage can be enabled.
// we can't always enable `external_table_storage` if the project dir already has a db file
// it could have been created with older logic where every source was a table in the main db
func isExternalStorageEnabled(dbPath string, variables map[string]string) (bool, error) {
_, err := os.Stat(filepath.Join(dbPath, DefaultOLAPDSN))
if err != nil {
// fresh project
// check if flag explicitly passed
val, ok := variables["connector.duckdb.external_table_storage"]
if !ok {
// mark enabled by default
return true, nil
}
return strconv.ParseBool(val)
}

fsRoot := os.DirFS(dbPath)
glob := path.Clean(path.Join("./", filepath.Join("*", "version.txt")))

matches, err := doublestar.Glob(fsRoot, glob)
if err != nil {
return false, err
}
return len(matches) > 0, nil
func isExternalStorageEnabled(variables map[string]string) (bool, error) {
// check if flag explicitly passed
val, ok := variables["connector.duckdb.external_table_storage"]
if !ok {
// mark enabled by default
return true, nil
}
return strconv.ParseBool(val)
}
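With the filesystem probing removed, isExternalStorageEnabled is driven entirely by the connector.duckdb.external_table_storage project variable. A minimal in-package test sketch of the resulting behaviour (illustrative only, not part of this commit):

```go
package local

import "testing"

// Sketch: external table storage now defaults to enabled and is only
// turned off when the variable is explicitly set to a false value.
func TestIsExternalStorageEnabled(t *testing.T) {
	cases := []struct {
		vars map[string]string
		want bool
	}{
		{map[string]string{}, true},                                                    // flag absent: enabled by default
		{map[string]string{"connector.duckdb.external_table_storage": "false"}, false}, // explicit opt-out
		{map[string]string{"connector.duckdb.external_table_storage": "true"}, true},   // explicit opt-in
	}
	for _, c := range cases {
		got, err := isExternalStorageEnabled(c.vars)
		if err != nil {
			t.Fatal(err)
		}
		if got != c.want {
			t.Errorf("vars=%v: got %v, want %v", c.vars, got, c.want)
		}
	}
}
```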
2 changes: 1 addition & 1 deletion runtime/drivers/athena/athena.go
@@ -121,7 +121,7 @@ func (c *Connection) Driver() string {
// Config implements drivers.Connection.
func (c *Connection) Config() map[string]any {
m := make(map[string]any, 0)
_ = mapstructure.Decode(c.config, m)
_ = mapstructure.Decode(c.config, &m)
return m
}

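The one-line change above (repeated verbatim in the azure, bigquery, clickhouse, druid, file, gcs and redshift drivers below) fixes a silently failing decode: mapstructure requires a pointer destination, so passing the map by value made Decode return an error that was discarded, leaving the returned config empty. A standalone sketch of the difference (the cfg struct and values are invented for illustration):

```go
package main

import (
	"fmt"

	"github.com/mitchellh/mapstructure"
)

// cfg is an invented stand-in for a driver's config struct.
type cfg struct {
	DSN     string `mapstructure:"dsn"`
	TempDir string `mapstructure:"temp_dir"`
}

func main() {
	src := cfg{DSN: "example-dsn", TempDir: "/tmp/rill"}

	// Non-pointer destination: Decode rejects it and the map stays empty.
	bad := make(map[string]any)
	err := mapstructure.Decode(src, bad)
	fmt.Println(err, bad) // error ("result must be a pointer"), map[]

	// Pointer destination, as in the fix: the map gets populated.
	good := make(map[string]any)
	err = mapstructure.Decode(src, &good)
	fmt.Println(err, good) // <nil> map[dsn:example-dsn temp_dir:/tmp/rill]
}
```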
2 changes: 1 addition & 1 deletion runtime/drivers/azure/azure.go
@@ -149,7 +149,7 @@ func (c *Connection) Driver() string {
// Config implements drivers.Connection.
func (c *Connection) Config() map[string]any {
m := make(map[string]any, 0)
_ = mapstructure.Decode(c.config, m)
_ = mapstructure.Decode(c.config, &m)
return m
}

3 changes: 2 additions & 1 deletion runtime/drivers/bigquery/bigquery.go
@@ -66,6 +66,7 @@ type driver struct{}
type configProperties struct {
SecretJSON string `mapstructure:"google_application_credentials"`
AllowHostAccess bool `mapstructure:"allow_host_access"`
TempDir string `mapstructure:"temp_dir"`
}

func (d driver) Open(instanceID string, config map[string]any, client *activity.Client, logger *zap.Logger) (drivers.Handle, error) {
@@ -116,7 +117,7 @@ func (c *Connection) Driver() string {
// Config implements drivers.Connection.
func (c *Connection) Config() map[string]any {
m := make(map[string]any, 0)
_ = mapstructure.Decode(c.config, m)
_ = mapstructure.Decode(c.config, &m)
return m
}

42 changes: 24 additions & 18 deletions runtime/drivers/bigquery/sql_store.go
@@ -149,6 +149,11 @@ func (c *Connection) QueryAsFiles(ctx context.Context, props map[string]any, opt
c.logger.Debug("query took", zap.Duration("duration", time.Since(now)), observability.ZapCtx(ctx))
}

tempDir, err := os.MkdirTemp(c.config.TempDir, "bigquery")
if err != nil {
return nil, err
}

p.Target(int64(it.TotalRows), drivers.ProgressUnitRecord)
return &fileIterator{
client: client,
@@ -158,6 +163,7 @@ func (c *Connection) QueryAsFiles(ctx context.Context, props map[string]any, opt
progress: p,
totalRecords: int64(it.TotalRows),
ctx: ctx,
tempDir: tempDir,
}, nil
}

@@ -178,17 +184,17 @@ type fileIterator struct {
logger *zap.Logger
limitInBytes int64
progress drivers.Progress
tempDir string

totalRecords int64
tempFilePath string
downloaded bool

ctx context.Context // TODO :: refactor NextBatch to take context on NextBatch
}

// Close implements drivers.FileIterator.
func (f *fileIterator) Close() error {
return os.Remove(f.tempFilePath)
return os.RemoveAll(f.tempDir)
}

// Next implements drivers.FileIterator.
@@ -200,10 +206,11 @@ func (f *fileIterator) Next() ([]string, error) {
// storage API not available so can't read as arrow records. Read results row by row and dump in a json file.
if !f.bqIter.IsAccelerated() {
f.logger.Debug("downloading results in json file", observability.ZapCtx(f.ctx))
if err := f.downloadAsJSONFile(); err != nil {
file, err := f.downloadAsJSONFile()
if err != nil {
return nil, err
}
return []string{f.tempFilePath}, nil
return []string{file}, nil
}
f.logger.Debug("downloading results in parquet file", observability.ZapCtx(f.ctx))

@@ -213,7 +220,6 @@ func (f *fileIterator) Next() ([]string, error) {
return nil, err
}
defer fw.Close()
f.tempFilePath = fw.Name()
f.downloaded = true

rdr, err := f.AsArrowRecordReader()
@@ -296,19 +302,18 @@ func (f *fileIterator) Format() string {
return ""
}

func (f *fileIterator) downloadAsJSONFile() error {
func (f *fileIterator) downloadAsJSONFile() (string, error) {
tf := time.Now()
defer func() {
f.logger.Debug("time taken to write row in json file", zap.Duration("duration", time.Since(tf)), observability.ZapCtx(f.ctx))
}()

// create a temp file
fw, err := os.CreateTemp("", "temp*.ndjson")
fw, err := os.CreateTemp(f.tempDir, "temp*.ndjson")
if err != nil {
return err
return "", err
}
defer fw.Close()
f.tempFilePath = fw.Name()
f.downloaded = true

init := false
@@ -320,13 +325,14 @@ func (f *fileIterator) downloadAsJSONFile() error {
row := make(map[string]bigquery.Value)
err := f.bqIter.Next(&row)
if err != nil {
if errors.Is(err, iterator.Done) {
if !init {
return drivers.ErrNoRows
}
return nil
if !errors.Is(err, iterator.Done) {
return "", err
}
if !init {
return "", drivers.ErrNoRows
}
return err
// all rows written successfully
return fw.Name(), nil
}

// schema and total rows is available after first call to next only
@@ -356,7 +362,7 @@

err = enc.Encode(row)
if err != nil {
return fmt.Errorf("conversion of row to json failed with error: %w", err)
return "", fmt.Errorf("conversion of row to json failed with error: %w", err)
}

// If we don't have storage API access, BigQuery may return massive JSON results. (But even with storage API access, it may return JSON for small results.)
Expand All @@ -365,10 +371,10 @@ func (f *fileIterator) downloadAsJSONFile() error {
if rows != 0 && rows%10000 == 0 { // Check file size every 10k rows
fileInfo, err := os.Stat(fw.Name())
if err != nil {
return fmt.Errorf("bigquery: failed to poll json file size: %w", err)
return "", fmt.Errorf("bigquery: failed to poll json file size: %w", err)
}
if fileInfo.Size() >= _jsonDownloadLimitBytes {
return fmt.Errorf("bigquery: json download exceeded limit of %d bytes (enable and provide access to the BigQuery Storage Read API to read larger results)", _jsonDownloadLimitBytes)
return "", fmt.Errorf("bigquery: json download exceeded limit of %d bytes (enable and provide access to the BigQuery Storage Read API to read larger results)", _jsonDownloadLimitBytes)
}
}
}
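The iterator now owns a per-query temp directory rooted at the connector's configurable temp_dir instead of tracking a single temp file path: downloaded files are created inside it, and one RemoveAll in Close cleans everything up. A minimal lifecycle sketch (RILL_TEMP_DIR is an assumed stand-in for the temp_dir config value, not an actual environment variable the driver reads):

```go
package main

import (
	"fmt"
	"os"
)

func main() {
	// Assumed stand-in for c.config.TempDir; an empty string falls back to the OS temp dir.
	root := os.Getenv("RILL_TEMP_DIR")

	// One directory per query, as in QueryAsFiles.
	tempDir, err := os.MkdirTemp(root, "bigquery")
	if err != nil {
		panic(err)
	}
	defer os.RemoveAll(tempDir) // mirrors fileIterator.Close: removes every downloaded file at once

	// Downloaded results land inside the directory, as in downloadAsJSONFile.
	fw, err := os.CreateTemp(tempDir, "temp*.ndjson")
	if err != nil {
		panic(err)
	}
	defer fw.Close()

	fmt.Println("writing results to", fw.Name())
}
```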
2 changes: 1 addition & 1 deletion runtime/drivers/clickhouse/clickhouse.go
@@ -262,7 +262,7 @@ func (c *connection) Driver() string {
// Config used to open the Connection
func (c *connection) Config() map[string]any {
m := make(map[string]any, 0)
_ = mapstructure.Decode(c.config, m)
_ = mapstructure.Decode(c.config, &m)
return m
}

2 changes: 1 addition & 1 deletion runtime/drivers/druid/druid.go
@@ -195,7 +195,7 @@ func (c *connection) Driver() string {
// Config used to open the Connection
func (c *connection) Config() map[string]any {
m := make(map[string]any, 0)
_ = mapstructure.Decode(c.config, m)
_ = mapstructure.Decode(c.config, &m)
return m
}

2 changes: 1 addition & 1 deletion runtime/drivers/duckdb/config.go
@@ -81,7 +81,7 @@ func newConfig(cfgMap map[string]any) (*config, error) {
// Override DSN.Path with config.Path
if cfg.Path != "" { // backward compatibility, cfg.Path takes precedence over cfg.DataDir
uri.Path = cfg.Path
} else if cfg.DataDir != "" {
} else if cfg.DataDir != "" && uri.Path == "" { // if some path is set in DSN, honour that path and ignore DataDir
uri.Path = filepath.Join(cfg.DataDir, "main.db")
}

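The added `uri.Path == ""` guard makes data_dir a fallback only: an explicit `path` setting still wins for backward compatibility, a path already embedded in the DSN is honoured next, and only then does the driver default to `<data_dir>/main.db`. A condensed illustration of that precedence (resolveDBPath is an invented helper, not the driver's API):

```go
package main

import (
	"fmt"
	"path/filepath"
)

// resolveDBPath is an illustrative reduction of the precedence in newConfig.
func resolveDBPath(cfgPath, dsnPath, dataDir string) string {
	if cfgPath != "" { // backward compatibility: explicit `path` overrides everything
		return cfgPath
	}
	if dsnPath != "" { // a path already present in the DSN is honoured
		return dsnPath
	}
	if dataDir != "" { // otherwise fall back to data_dir
		return filepath.Join(dataDir, "main.db")
	}
	return ""
}

func main() {
	fmt.Println(resolveDBPath("stage.db", "", "/data")) // stage.db
	fmt.Println(resolveDBPath("", "local.db", "/data")) // local.db: DSN path wins over data_dir
	fmt.Println(resolveDBPath("", "", "/data"))         // /data/main.db
}
```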
2 changes: 1 addition & 1 deletion runtime/drivers/file/file.go
@@ -148,7 +148,7 @@ type connection struct {
// Config implements drivers.Connection.
func (c *connection) Config() map[string]any {
m := make(map[string]any, 0)
_ = mapstructure.Decode(c.driverConfig, m)
_ = mapstructure.Decode(c.driverConfig, &m)
return m
}

2 changes: 1 addition & 1 deletion runtime/drivers/gcs/gcs.go
@@ -174,7 +174,7 @@ func (c *Connection) Driver() string {
// Config implements drivers.Connection.
func (c *Connection) Config() map[string]any {
m := make(map[string]any, 0)
_ = mapstructure.Decode(c.config, m)
_ = mapstructure.Decode(c.config, &m)
return m
}

2 changes: 1 addition & 1 deletion runtime/drivers/redshift/redshift.go
@@ -146,7 +146,7 @@ func (c *Connection) Driver() string {
// Config implements drivers.Connection.
func (c *Connection) Config() map[string]any {
m := make(map[string]any, 0)
_ = mapstructure.Decode(c.config, m)
_ = mapstructure.Decode(c.config, &m)
return m
}

26 changes: 21 additions & 5 deletions runtime/drivers/snowflake/snowflake.go
@@ -4,6 +4,7 @@ import (
"context"
"errors"

"github.com/mitchellh/mapstructure"
"github.com/rilldata/rill/runtime/drivers"
"github.com/rilldata/rill/runtime/pkg/activity"
"go.uber.org/zap"
@@ -48,14 +49,27 @@ var spec = drivers.Spec{

type driver struct{}

type configProperties struct {
DSN string `mapstructure:"dsn"`
ParallelFetchLimit int `mapstructure:"parallel_fetch_limit"`
TempDir string `mapstructure:"temp_dir"`
}

func (d driver) Open(instanceID string, config map[string]any, client *activity.Client, logger *zap.Logger) (drivers.Handle, error) {
if instanceID == "" {
return nil, errors.New("snowflake driver can't be shared")
}

conf := &configProperties{}
err := mapstructure.WeakDecode(config, conf)
if err != nil {
return nil, err
}

// actual db connection is opened during query
return &connection{
config: config,
logger: logger,
configProperties: conf,
logger: logger,
}, nil
}

@@ -72,8 +86,8 @@ func (d driver) TertiarySourceConnectors(ctx context.Context, src map[string]any
}

type connection struct {
config map[string]any
logger *zap.Logger
configProperties *configProperties
logger *zap.Logger
}

// Migrate implements drivers.Connection.
@@ -93,7 +107,9 @@ func (c *connection) Driver() string {

// Config implements drivers.Connection.
func (c *connection) Config() map[string]any {
return c.config
m := make(map[string]any, 0)
_ = mapstructure.Decode(c.configProperties, &m)
return m
}

// Close implements drivers.Connection.
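The snowflake driver now mirrors the other connectors: the raw config map is decoded once into a typed configProperties via WeakDecode (so string values such as a numeric parallel_fetch_limit are still coerced), and Config() rebuilds the map from the struct, again through a pointer destination. A round-trip sketch with invented values:

```go
package main

import (
	"fmt"

	"github.com/mitchellh/mapstructure"
)

type configProperties struct {
	DSN                string `mapstructure:"dsn"`
	ParallelFetchLimit int    `mapstructure:"parallel_fetch_limit"`
	TempDir            string `mapstructure:"temp_dir"`
}

func main() {
	// Raw connector config as the runtime might pass it; values are illustrative.
	raw := map[string]any{
		"dsn":                  "user:pass@account/db",
		"parallel_fetch_limit": "8", // WeakDecode coerces the string to an int
		"temp_dir":             "/tmp/rill",
	}

	conf := &configProperties{}
	if err := mapstructure.WeakDecode(raw, conf); err != nil {
		panic(err)
	}
	fmt.Printf("%+v\n", conf) // &{DSN:user:pass@account/db ParallelFetchLimit:8 TempDir:/tmp/rill}

	// Config() goes the other way; note the pointer destination, as in the decode fix above.
	m := make(map[string]any)
	_ = mapstructure.Decode(conf, &m)
	fmt.Println(m)
}
```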