diff --git a/runtime/drivers/druid/druidsqldriver/druid_api_sql_driver.go b/runtime/drivers/druid/druidsqldriver/druid_api_sql_driver.go index e8f70266d2d..767ca47c873 100644 --- a/runtime/drivers/druid/druidsqldriver/druid_api_sql_driver.go +++ b/runtime/drivers/druid/druidsqldriver/druid_api_sql_driver.go @@ -9,6 +9,7 @@ import ( "fmt" "io" "net/http" + "reflect" "time" "github.com/google/uuid" @@ -97,7 +98,44 @@ func (c *sqlConnection) QueryContext(ctx context.Context, query string, args []d transformers := make([]func(any) (any, error), len(columns)) for i, c := range types { transformers[i] = identityTransformer - if c == "TIMESTAMP" { + switch c { + case "TINYINT": + transformers[i] = func(v any) (any, error) { + switch v := v.(type) { + case float64: + return int8(v), nil + default: + return v, nil + } + } + case "SMALLINT": + transformers[i] = func(v any) (any, error) { + switch v := v.(type) { + case float64: + return int16(v), nil + default: + return v, nil + } + } + case "INTEGER": + transformers[i] = func(v any) (any, error) { + switch v := v.(type) { + case float64: + return int32(v), nil + default: + return v, nil + } + } + case "BIGINT": + transformers[i] = func(v any) (any, error) { + switch v := v.(type) { + case float64: + return int64(v), nil + default: + return v, nil + } + } + case "TIMESTAMP": transformers[i] = func(v any) (any, error) { t, err := time.Parse(time.RFC3339, v.(string)) if err != nil { @@ -106,7 +144,7 @@ func (c *sqlConnection) QueryContext(ctx context.Context, query string, args []d return t, nil } - } else if c == "ARRAY" { + case "ARRAY": transformers[i] = func(v any) (any, error) { var l []any err := json.Unmarshal([]byte(v.(string)), &l) @@ -115,7 +153,7 @@ func (c *sqlConnection) QueryContext(ctx context.Context, query string, args []d } return l, nil } - } else if c == "OTHER" { + case "OTHER": transformers[i] = func(v any) (any, error) { var l map[string]any err := json.Unmarshal([]byte(v.(string)), &l) @@ -182,6 +220,40 @@ func (dr *druidRows) Next(dest []driver.Value) error { return nil } +func (dr *druidRows) ColumnTypeScanType(index int) reflect.Type { + switch dr.types[index] { + case "BOOLEAN": + return reflect.TypeOf(true) + case "TINYINT": + return reflect.TypeOf(int8(0)) + case "SMALLINT": + return reflect.TypeOf(int16(0)) + case "INTEGER": + return reflect.TypeOf(int32(0)) + case "BIGINT": + return reflect.TypeOf(int64(0)) + case "FLOAT": + return reflect.TypeOf(float32(0)) + case "DOUBLE": + return reflect.TypeOf(float64(0)) + case "REAL": + return reflect.TypeOf(float64(0)) + case "DECIMAL": + return reflect.TypeOf(float64(0)) + case "CHAR": + return reflect.TypeOf("") + case "VARCHAR": + return reflect.TypeOf("") + case "TIMESTAMP": + return reflect.TypeOf(time.Time{}) + case "DATE": + return reflect.TypeOf(time.Time{}) + case "OTHER": + return reflect.TypeOf("") + } + return nil +} + func (dr *druidRows) ColumnTypeDatabaseTypeName(index int) string { return dr.types[index] } diff --git a/runtime/queries/metricsview_aggregation.go b/runtime/queries/metricsview_aggregation.go index 4255afdc5ff..945a57ad9b1 100644 --- a/runtime/queries/metricsview_aggregation.go +++ b/runtime/queries/metricsview_aggregation.go @@ -262,7 +262,7 @@ func (q *MetricsViewAggregation) pivotDruid(ctx context.Context, rows *drivers.R } err = appender.AppendRow(appendValues...) if err != nil { - return err + return fmt.Errorf("duckdb append failed: %w", err) } count++ if count > maxCount {