Skip to content

Commit

Permalink
Fix some data typing issues
Browse files Browse the repository at this point in the history
  • Loading branch information
paultyng committed Nov 18, 2020
1 parent 9b5aec0 commit 87c1c74
Show file tree
Hide file tree
Showing 5 changed files with 167 additions and 113 deletions.
62 changes: 36 additions & 26 deletions internal/provider/data_query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,19 @@ import (
"strings"
"testing"

"github.com/hashicorp/terraform-plugin-go/tfprotov5"
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/resource"
helperresource "github.com/hashicorp/terraform-plugin-sdk/v2/helper/resource"
"github.com/hashicorp/terraform-plugin-sdk/v2/terraform"
)

func TestQuery_driverTypes(t *testing.T) {
const testColName = "testcol"

func TestDataQuery_driverTypes(t *testing.T) {
if testing.Short() {
t.Skip("skipping long test")
}

const testColName = "testcol"

for k, url := range testURLs {
// TODO: check nulls for all these
t.Run(k, func(t *testing.T) {
scheme, err := schemeFromURL(url)
if err != nil {
Expand All @@ -37,20 +37,33 @@ func TestQuery_driverTypes(t *testing.T) {
// https://dev.mysql.com/doc/refman/8.0/en/cast-functions.html#function_convert

// TODO: "binary"
"char": {"cast('foo' as char)", "foo"},
"date": {"cast('2020-11-16' as date)", "2020-11-16T00:00:00Z"},
"datetime": {"cast('2020-11-16 19:00:01' as datetime)", "2020-11-16T19:00:01Z"},
"decimal": {"cast(1.2 as decimal(4,3))", ""},
"double": {"cast(1.2 as double)", "1.2"},
"float": {"cast(.125 as float(5))", "0.125"},
"char": {"cast('foo' as char)", "foo"},
"char null": {"cast(null as char)", ""},
"date": {"cast('2020-11-16' as date)", "2020-11-16T00:00:00Z"},
"date null": {"cast(null as date)", ""},
"datetime": {"cast('2020-11-16 19:00:01' as datetime)", "2020-11-16T19:00:01Z"},
"datetime null": {"cast(null as datetime)", ""},
"decimal": {"cast(1.2 as decimal(4,3))", ""},
"decimal null": {"cast(null as decimal)", ""},
"double": {"cast(1.2 as double)", "1.2"},
"double null": {"cast(null as double)", ""},
"float": {"cast(.125 as float(5))", "0.125"},
"float null": {"cast(null as float)", ""},
// TODO: parse to HCL types
"json": {"JSON_TYPE('[1, 2, 3]')", ""},
"nchar": {"cast('foo' as nchar)", "foo"},
"real": {"cast(.125 as real)", "0.125"},
"signed": {"cast(-7 as signed)", "-7"},
"time": {"cast('04:05:06' as time)", "04:05:06"},
"unsigned": {"cast(1 as unsigned)", "1"},
"year": {"cast(2020 as year)", "2020"},
"json": {"JSON_TYPE('[1, 2, 3]')", ""},
"json null": {"cast(null as json)", ""},
"nchar": {"cast('foo' as nchar)", "foo"},
"nchar null": {"cast(null as nchar)", ""},
"real": {"cast(.125 as real)", "0.125"},
"real null": {"cast(null as real)", ""},
"signed": {"cast(-7 as signed)", "-7"},
"signed null": {"cast(null as signed)", ""},
"time": {"cast('04:05:06' as time)", "04:05:06"},
"time null": {"cast(null as time)", ""},
"unsigned": {"cast(1 as unsigned)", "1"},
"unsigned null": {"cast(null as unsigned)", ""},
"year": {"cast(2020 as year)", "2020"},
"year null": {"cast(null as year)", ""},
}
case "postgres":
literals = map[string]struct {
Expand Down Expand Up @@ -112,6 +125,7 @@ func TestQuery_driverTypes(t *testing.T) {
delete(literals, "cidr")
delete(literals, "macaddr")
delete(literals, "macaddr8")
delete(literals, "money")
delete(literals, "time with time zone")
delete(literals, "timestamp with time zone")
delete(literals, "xml")
Expand Down Expand Up @@ -180,13 +194,9 @@ func TestQuery_driverTypes(t *testing.T) {
// fix slash escaping
col := strings.ReplaceAll(lit.sql, `\`, `\\`)
query := fmt.Sprintf("select %s as %s", col, testColName)
resource.UnitTest(t, resource.TestCase{
ProtoV5ProviderFactories: map[string]func() (tfprotov5.ProviderServer, error){
"sql": func() (tfprotov5.ProviderServer, error) {
return New("acctest")(), nil
},
},
Steps: []resource.TestStep{
helperresource.UnitTest(t, helperresource.TestCase{
ProtoV5ProviderFactories: protoV5ProviderFactories,
Steps: []helperresource.TestStep{
{

Config: fmt.Sprintf(`
Expand All @@ -204,7 +214,7 @@ output "query" {
value = data.sql_query.test.result
}
`, url, query),
Check: resource.ComposeTestCheckFunc(
Check: helperresource.ComposeTestCheckFunc(
func(s *terraform.State) error {
rs := s.RootModule().Resources["data.sql_query.test"]
att := rs.Primary.Attributes["result.0."+testColName]
Expand Down
75 changes: 67 additions & 8 deletions internal/provider/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"reflect"
"strings"
"time"

// database drivers
_ "github.com/denisenkom/go-mssqldb"
Expand Down Expand Up @@ -122,9 +123,52 @@ func (db *db) valuesForRow(rows *sql.Rows) (map[string]tftypes.Value, map[string
rowValues := map[string]tftypes.Value{}
rowTypes := map[string]tftypes.Type{}
for k, v := range row {
val := v.val

// unwrap sql types
switch tv := val.(type) {
case *sql.NullInt64:
if !tv.Valid {
val = nil
} else {
val = &tv.Int64
}
case *sql.NullInt32:
if !tv.Valid {
val = nil
} else {
val = &tv.Int32
}
case *sql.NullFloat64:
if !tv.Valid {
val = nil
} else {
val = &tv.Float64
}
case *sql.NullBool:
if !tv.Valid {
val = nil
} else {
val = &tv.Bool
}
case *sql.NullString:
if !tv.Valid {
val = nil
} else {
val = &tv.String
}
case *sql.NullTime:
if !tv.Valid {
val = nil
} else {
s := tv.Time.Format(time.RFC3339)
val = &s
}
}

rowValues[k] = tftypes.NewValue(
v.ty,
v.val,
val,
)
rowTypes[k] = v.ty
}
Expand All @@ -142,27 +186,42 @@ func (db *db) typeAndValueForColType(colType *sql.ColumnType) (tftypes.Type, ref
case "UNIQUEIDENTIFIER":
return tftypes.String, reflect.TypeOf((*sqlServerUniqueIdentifier)(nil)).Elem(), nil
case "DECIMAL", "MONEY", "SMALLMONEY":
return tftypes.String, reflect.TypeOf((*string)(nil)).Elem(), nil
// TODO: add diags about converting to numeric?
return tftypes.String, reflect.TypeOf((*sql.NullString)(nil)).Elem(), nil
}
case "mysql":
switch dbName := colType.DatabaseTypeName(); dbName {
case "YEAR":
return tftypes.Number, reflect.TypeOf((*int)(nil)).Elem(), nil
case "VARCHAR", "DECIMAL", "TIME":
return tftypes.String, reflect.TypeOf((*string)(nil)).Elem(), nil
return tftypes.Number, reflect.TypeOf((*sql.NullInt32)(nil)).Elem(), nil
case "VARCHAR", "DECIMAL", "TIME", "JSON":
return tftypes.String, reflect.TypeOf((*sql.NullString)(nil)).Elem(), nil
case "DATE", "DATETIME":
return tftypes.String, tfTimeType, nil
return tftypes.String, reflect.TypeOf((*sql.NullTime)(nil)).Elem(), nil
}
case "pgx":
switch dbName := colType.DatabaseTypeName(); dbName {
// 790 is the oid of money
case "MONEY", "790":
return nil, nil, fmt.Errorf("money is not supported for column %q, please convert to numeric", colType.Name())
// TODO: add diags about converting to numeric?
return tftypes.String, reflect.TypeOf((*sql.NullString)(nil)).Elem(), nil
case "TIMESTAMPTZ", "TIMESTAMP", "DATE":
return tftypes.String, tfTimeType, nil
return tftypes.String, reflect.TypeOf((*sql.NullTime)(nil)).Elem(), nil
}
}

switch scanType {
case reflect.TypeOf((*sql.NullInt64)(nil)).Elem(),
reflect.TypeOf((*sql.NullInt32)(nil)).Elem(),
reflect.TypeOf((*sql.NullFloat64)(nil)).Elem():
return tftypes.Number, scanType, nil
case reflect.TypeOf((*sql.NullString)(nil)).Elem():
return tftypes.String, scanType, nil
case reflect.TypeOf((*sql.NullBool)(nil)).Elem():
return tftypes.Bool, scanType, nil
case reflect.TypeOf((*sql.NullTime)(nil)).Elem():
return tftypes.String, scanType, nil
}

switch kind {
case reflect.String:
return tftypes.String, scanType, nil
Expand Down
7 changes: 3 additions & 4 deletions internal/provider/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,6 @@ func (p *provider) GetProviderSchema(ctx context.Context, req *tfprotov5.GetProv
for _, typeName := range []string{"sql_query"} {
ds, err := p.NewDataSource(typeName)
if err != nil {
// TODO: diags?
return nil, err
}
resp.DataSourceSchemas[typeName] = ds.Schema(ctx)
Expand Down Expand Up @@ -237,7 +236,6 @@ func (p *provider) ImportResourceState(ctx context.Context, req *tfprotov5.Impor
func (p *provider) ValidateDataSourceConfig(ctx context.Context, req *tfprotov5.ValidateDataSourceConfigRequest) (*tfprotov5.ValidateDataSourceConfigResponse, error) {
ds, err := p.NewDataSource(req.TypeName)
if err != nil {
// TODO: diags?
return nil, err
}

Expand All @@ -258,6 +256,7 @@ func (p *provider) ValidateDataSourceConfig(ctx context.Context, req *tfprotov5.
if err != nil {
return nil, err
}

return &tfprotov5.ValidateDataSourceConfigResponse{
Diagnostics: diags,
}, nil
Expand Down Expand Up @@ -303,13 +302,13 @@ func (p *provider) ReadDataSource(ctx context.Context, req *tfprotov5.ReadDataSo
}

// TODO: should NewDynamicValue return a pointer?
stateObject, err := tfprotov5.NewDynamicValue(schemaObjectType, tftypes.NewValue(schemaObjectType, state))
stateValue, err := tfprotov5.NewDynamicValue(schemaObjectType, tftypes.NewValue(schemaObjectType, state))
if err != nil {
return nil, fmt.Errorf("error NewDynamicValue: %w", err)
}

return &tfprotov5.ReadDataSourceResponse{
State: &stateObject,
State: &stateValue,
Diagnostics: diags,
}, nil
}
Loading

0 comments on commit 87c1c74

Please sign in to comment.