Skip to content

Commit

Permalink
Refactor providers and factories
Browse files Browse the repository at this point in the history
  • Loading branch information
paultyng committed Apr 9, 2021
1 parent 265dfc8 commit 5018620
Show file tree
Hide file tree
Showing 14 changed files with 225 additions and 114 deletions.
4 changes: 3 additions & 1 deletion .vscode/launch.json
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
"env": {
"TF_ACC": "1",
},
"args": [],
"args": [
"-short",
],
},
// You could pair this configuration with an exec configuration that runs Terraform as
// a compound launch configuration:
Expand Down
11 changes: 6 additions & 5 deletions go.mod
Original file line number Diff line number Diff line change
@@ -1,21 +1,22 @@
module github.com/paultyng/terraform-provider-sql

go 1.15
go 1.16

require (
github.com/Microsoft/go-winio v0.4.15 // indirect
github.com/containerd/continuity v0.0.0-20200928162600-f2cc35102c2a // indirect
github.com/denisenkom/go-mssqldb v0.9.0
github.com/go-sql-driver/mysql v1.6.0
github.com/google/go-cmp v0.5.4
github.com/hashicorp/go-argmapper v0.0.0-20200721221215-04ae500ede3b
github.com/google/go-cmp v0.5.5
github.com/hashicorp/errwrap v1.1.0 // indirect
github.com/hashicorp/go-argmapper v0.1.1
github.com/hashicorp/go-plugin v1.4.0
github.com/hashicorp/terraform-plugin-docs v0.3.1-0.20210107204619-bf524a84dc08
github.com/hashicorp/terraform-plugin-docs v0.4.0
github.com/hashicorp/terraform-plugin-go v0.2.1
github.com/hashicorp/terraform-plugin-sdk/v2 v2.5.0
github.com/jackc/pgx/v4 v4.11.0
github.com/ory/dockertest/v3 v3.6.3
gopkg.in/yaml.v2 v2.2.8 // indirect
)

// replace github.com/hashicorp/terraform-plugin-go => ../../hashicorp/terraform-plugin-go
// replace github.com/hashicorp/go-argmapper => ../../hashicorp/go-argmapper
60 changes: 34 additions & 26 deletions go.sum

Large diffs are not rendered by default.

10 changes: 7 additions & 3 deletions internal/migration/migration.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,11 @@ func Subtract(x, y []Migration) []Migration {
return result
}

func Up(ctx context.Context, db *sql.DB, all, applied []Migration) error {
type SQLExecer interface {
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
}

func Up(ctx context.Context, db SQLExecer, all, applied []Migration) error {
removedMigrations := Subtract(applied, all)
newMigrations := Subtract(all, applied)

Expand All @@ -45,11 +49,11 @@ func Up(ctx context.Context, db *sql.DB, all, applied []Migration) error {
return nil
}

func Down(ctx context.Context, db *sql.DB, all, applied []Migration) error {
func Down(ctx context.Context, db SQLExecer, all, applied []Migration) error {
return runMigrations(ctx, false, applied, execMigration(db))
}

func execMigration(db *sql.DB) func(context.Context, Migration, string) error {
func execMigration(db SQLExecer) func(context.Context, Migration, string) error {
return func(ctx context.Context, m Migration, query string) error {
_, err := db.ExecContext(ctx, query)
if err != nil {
Expand Down
20 changes: 13 additions & 7 deletions internal/provider/data_query.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,25 @@ import (

"github.com/hashicorp/terraform-plugin-go/tfprotov5"
"github.com/hashicorp/terraform-plugin-go/tfprotov5/tftypes"

"github.com/paultyng/terraform-provider-sql/internal/server"
)

type dataQuery struct {
p *provider
db dbQueryer
p *provider
}

func newDataQuery(p *provider) (*dataQuery, error) {
if p == nil {
return nil, fmt.Errorf("a provider is required")
var _ server.DataSource = (*dataQuery)(nil)

func newDataQuery(db dbQueryer, p *provider) (*dataQuery, error) {
if db == nil {
return nil, fmt.Errorf("a database is required")
}

return &dataQuery{
p: p,
db: db,
p: p,
}, nil
}

Expand Down Expand Up @@ -87,7 +93,7 @@ func (d *dataQuery) Read(ctx context.Context, config map[string]tftypes.Value) (
return nil, nil, err
}

rows, err := d.p.db.QueryContext(ctx, query)
rows, err := d.db.QueryContext(ctx, query)
if err != nil {
return nil, nil, err
}
Expand All @@ -96,7 +102,7 @@ func (d *dataQuery) Read(ctx context.Context, config map[string]tftypes.Value) (
var rowType tftypes.Type
rowSet := []tftypes.Value{}
for rows.Next() {
row, ty, err := d.p.db.valuesForRow(rows)
row, ty, err := d.p.ValuesForRow(rows)
if err != nil {
return nil, []*tfprotov5.Diagnostic{
{
Expand Down
44 changes: 20 additions & 24 deletions internal/provider/db.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package provider

import (
"context"
"database/sql"
"fmt"
"reflect"
Expand All @@ -18,52 +19,47 @@ import (
"github.com/hashicorp/terraform-plugin-go/tfprotov5/tftypes"
)

type db struct {
*sql.DB
type dbQueryer interface {
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
}

driver string
type dbExecer interface {
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
}

func newDB(dsn string, conf func(*sql.DB) error) (*db, error) {
func (p *provider) connect(dsn string) error {
var err error
n := &db{}

scheme, err := schemeFromURL(dsn)
if err != nil {
return nil, err
return err
}

switch scheme {
case "postgres", "postgresql":
n.driver = "pgx"
p.Driver = "pgx"
case "mysql":
n.driver = "mysql"
p.Driver = "mysql"
dsn = strings.TrimPrefix(dsn, "mysql://")
// TODO: multistatements? see go-migrate's implementation
// https://github.com/golang-migrate/migrate/blob/master/database/mysql/mysql.go

// TODO: also set parseTime=true https://github.com/go-sql-driver/mysql#parsetime
case "sqlserver":
n.driver = "sqlserver"
p.Driver = "sqlserver"
default:
return nil, fmt.Errorf("unexpected datasource name scheme: %q", scheme)
return fmt.Errorf("unexpected datasource name scheme: %q", scheme)
}

n.DB, err = sql.Open(n.driver, dsn)
p.DB, err = sql.Open(p.Driver, dsn)
if err != nil {
return nil, fmt.Errorf("unable to open database: %w", err)
return fmt.Errorf("unable to open database: %w", err)
}

// force this to zero, but let callers override config
n.DB.SetMaxIdleConns(0)
if conf != nil {
err = conf(n.DB)
if err != nil {
return nil, fmt.Errorf("unable to configure database: %w", err)
}
}
p.DB.SetMaxIdleConns(0)

return n, nil
return nil
}

func schemeFromURL(url string) (string, error) {
Expand All @@ -81,7 +77,7 @@ func schemeFromURL(url string) (string, error) {
return url[0:i], nil
}

func (db *db) valuesForRow(rows *sql.Rows) (map[string]tftypes.Value, map[string]tftypes.Type, error) {
func (p *provider) ValuesForRow(rows *sql.Rows) (map[string]tftypes.Value, map[string]tftypes.Type, error) {
colTypes, err := rows.ColumnTypes()
if err != nil {
return nil, nil, fmt.Errorf("unable to retrieve column type: %w", err)
Expand All @@ -100,7 +96,7 @@ func (db *db) valuesForRow(rows *sql.Rows) (map[string]tftypes.Value, map[string
name = fmt.Sprintf("column%d", i)
}

ty, rty, err := db.typeAndValueForColType(colType)
ty, rty, err := p.typeAndValueForColType(colType)
if err != nil {
return nil, nil, fmt.Errorf("unable to determine type for %q: %w", name, err)
}
Expand Down Expand Up @@ -176,11 +172,11 @@ func (db *db) valuesForRow(rows *sql.Rows) (map[string]tftypes.Value, map[string
return rowValues, rowTypes, nil
}

func (db *db) typeAndValueForColType(colType *sql.ColumnType) (tftypes.Type, reflect.Type, error) {
func (p *provider) typeAndValueForColType(colType *sql.ColumnType) (tftypes.Type, reflect.Type, error) {
scanType := colType.ScanType()
kind := scanType.Kind()

switch db.driver {
switch p.Driver {
case "sqlserver":
switch dbName := colType.DatabaseTypeName(); dbName {
case "UNIQUEIDENTIFIER":
Expand Down
19 changes: 12 additions & 7 deletions internal/provider/server_test.go → internal/provider/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,10 @@ func runTestMain(m *testing.M) int {

flag.Parse()

if testing.Short() {
return m.Run()
}

// remove unspecified test drivers
serverNames := strings.Split(*rawServerNames, ",")
for i := len(testServers) - 1; i >= 0; i-- {
Expand Down Expand Up @@ -198,13 +202,14 @@ func (td *testServer) Start() error {
}

td.resourceOnceErr = dockerPool.Retry(func() error {
db, err := newDB(td.url, nil)
p := &provider{}
err := p.connect(td.url)
if err != nil {
return err
}
defer db.Close()
defer p.DB.Close()

err = db.Ping()
err = p.DB.Ping()
if err != nil {
return err
}
Expand All @@ -216,14 +221,14 @@ func (td *testServer) Start() error {
}

if td.OnReady != nil {
var db *db
db, td.resourceOnceErr = newDB(td.url, nil)
p := &provider{}
td.resourceOnceErr = p.connect(td.url)
if td.resourceOnceErr != nil {
return
}
defer db.Close()
defer p.DB.Close()

td.resourceOnceErr = td.OnReady(db.DB)
td.resourceOnceErr = td.OnReady(p.DB)
if td.resourceOnceErr != nil {
return
}
Expand Down
39 changes: 20 additions & 19 deletions internal/provider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@ func New(version string) func() tfprotov5.ProviderServer {
}

type provider struct {
db *db
DB *sql.DB `argmapper:",typeOnly"`

Driver string
}

var _ server.Provider = (*provider)(nil)
Expand Down Expand Up @@ -76,9 +78,9 @@ func (p *provider) Validate(ctx context.Context, config map[string]tftypes.Value
}

func (p *provider) Configure(ctx context.Context, config map[string]tftypes.Value) ([]*tfprotov5.Diagnostic, error) {
if p.db != nil {
if p.DB != nil {
// if reconfiguring, close existing connection
_ = p.db.Close()
_ = p.DB.Close()
}

var err error
Expand Down Expand Up @@ -132,26 +134,25 @@ func (p *provider) Configure(ctx context.Context, config map[string]tftypes.Valu
}
}

p.db, err = newDB(url, func(db *sql.DB) error {
maxOpen, acc := maxOpenConns.Int64()
if acc != big.Exact {
return fmt.Errorf("ConfigureProvider - results for max_open_conns is not exact")
}

maxIdle, acc := maxIdleConns.Int64()
if acc != big.Exact {
return fmt.Errorf("ConfigureProvider - results for max_open_conns is not exact")
}

db.SetMaxOpenConns(int(maxOpen))
db.SetMaxIdleConns(int(maxIdle))
return nil
})
err = p.connect(url)
if err != nil {
return nil, fmt.Errorf("ConfigureProvider - unable to open database: %w", err)
}

err = p.db.PingContext(ctx)
maxOpen, acc := maxOpenConns.Int64()
if acc != big.Exact {
return nil, fmt.Errorf("ConfigureProvider - results for max_open_conns is not exact")
}

maxIdle, acc := maxIdleConns.Int64()
if acc != big.Exact {
return nil, fmt.Errorf("ConfigureProvider - results for max_open_conns is not exact")
}

p.DB.SetMaxOpenConns(int(maxOpen))
p.DB.SetMaxIdleConns(int(maxIdle))

err = p.DB.PingContext(ctx)
if err != nil {
return nil, fmt.Errorf("ConfigureProvider - unable to ping database: %w", err)
}
Expand Down
12 changes: 12 additions & 0 deletions internal/provider/provider_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package provider

import (
"testing"
// "github.com/paultyng/terraform-provider-sql/internal/server"
)

func TestServer(t *testing.T) {
_ = New("acctest")()

// s.Test(t)
}
7 changes: 5 additions & 2 deletions internal/provider/resource_migrate.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,13 @@ type resourceMigrate struct {
resourceMigrateCommon
}

func newResourceMigrate(p *provider) (*resourceMigrate, error) {
var _ server.Resource = (*resourceMigrate)(nil)
var _ server.ResourceUpdater = (*resourceMigrate)(nil)

func newResourceMigrate(db dbExecer) (*resourceMigrate, error) {
return &resourceMigrate{
resourceMigrateCommon: resourceMigrateCommon{
p: p,
db: db,
},
}, nil
}
Expand Down
Loading

0 comments on commit 5018620

Please sign in to comment.