From 31781fc077e8e444efdbb22d54abb692f0caef0c Mon Sep 17 00:00:00 2001 From: Michael Chinn Date: Mon, 21 Aug 2023 22:36:50 -0400 Subject: [PATCH 1/2] fix: use native ScanType from driver and enhance RowBuffer to understand more types --- dump.go | 129 +++++++++----- dump_test.go | 10 +- mysqldump.go | 2 +- mysqldump_test.go | 423 ++++++++++++++++++++++++++++++++++++++++++---- 4 files changed, 488 insertions(+), 76 deletions(-) diff --git a/dump.go b/dump.go index be4abfb..5109f50 100644 --- a/dump.go +++ b/dump.go @@ -16,11 +16,11 @@ import ( /* Data struct to configure dump behavior - Out: Stream to wite to - Connection: Database connection to dump - IgnoreTables: Mark sensitive tables to ignore - MaxAllowedPacket: Sets the largest packet size to use in backups - LockTables: Lock all tables for the duration of the dump + Out: Stream to wite to + Connection: Database connection to dump + IgnoreTables: Mark sensitive tables to ignore + MaxAllowedPacket: Sets the largest packet size to use in backups + LockTables: Lock all tables for the duration of the dump */ type Data struct { Out io.Writer @@ -68,7 +68,7 @@ const headerTmpl = `-- Go SQL Dump {{ .DumpVersion }} /*!40101 SET @OLD_CHARACTER_SET_CLIENT=@@CHARACTER_SET_CLIENT */; /*!40101 SET @OLD_CHARACTER_SET_RESULTS=@@CHARACTER_SET_RESULTS */; /*!40101 SET @OLD_COLLATION_CONNECTION=@@COLLATION_CONNECTION */; - SET NAMES utf8mb4 ; +/*!50503 SET NAMES UTF8 */; /*!40103 SET @OLD_TIME_ZONE=@@TIME_ZONE */; /*!40103 SET TIME_ZONE='+00:00' */; /*!40014 SET @OLD_UNIQUE_CHECKS=@@UNIQUE_CHECKS, UNIQUE_CHECKS=0 */; @@ -99,7 +99,7 @@ const tableTmpl = ` DROP TABLE IF EXISTS {{ .NameEsc }}; /*!40101 SET @saved_cs_client = @@character_set_client */; - SET character_set_client = utf8mb4 ; +/*!50503 SET character_set_client = utf8mb4 */; {{ .CreateSQL }}; /*!40101 SET character_set_client = @saved_cs_client */; @@ -296,7 +296,7 @@ func (table *table) CreateSQL() (string, error) { } if tableReturn.String != table.Name { - return "", errors.New("Returned table is not the same as requested table") + return "", errors.New("returned table is not the same as requested table") } return tableSQL.String, nil @@ -383,38 +383,11 @@ func (table *table) Init() error { table.values = make([]interface{}, len(tt)) for i, tp := range tt { - table.values[i] = reflect.New(reflectColumnType(tp)).Interface() + table.values[i] = reflect.New(tp.ScanType()).Interface() } return nil } -func reflectColumnType(tp *sql.ColumnType) reflect.Type { - // reflect for scanable - switch tp.ScanType().Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - return reflect.TypeOf(sql.NullInt64{}) - case reflect.Float32, reflect.Float64: - return reflect.TypeOf(sql.NullFloat64{}) - case reflect.String: - return reflect.TypeOf(sql.NullString{}) - } - - // determine by name - switch tp.DatabaseTypeName() { - case "BLOB", "BINARY": - return reflect.TypeOf(sql.RawBytes{}) - case "VARCHAR", "TEXT", "DECIMAL", "JSON": - return reflect.TypeOf(sql.NullString{}) - case "BIGINT", "TINYINT", "INT": - return reflect.TypeOf(sql.NullInt64{}) - case "DOUBLE": - return reflect.TypeOf(sql.NullFloat64{}) - } - - // unknown datatype - return tp.ScanType() -} - func (table *table) Next() bool { if table.rows == nil { if err := table.Init(); err != nil { @@ -443,6 +416,30 @@ func (table *table) RowValues() string { return table.RowBuffer().String() } +func writeString(b *bytes.Buffer, s string) { + fmt.Fprintf(b, "'%s'", sanitize(s)) +} + +func writeBool(b *bytes.Buffer, s bool) { + if s { + fmt.Fprintf(b, "1") + } else { + fmt.Fprintf(b, "0") + } +} + +func writeBinary(b *bytes.Buffer, s []byte) { + if len(s) == 0 { + b.WriteString(nullType) + } else { + fmt.Fprintf(b, "_binary '%s'", sanitize(string(s))) + } +} + +func writeTime(b *bytes.Buffer, s time.Time) { + fmt.Fprintf(b, "'%s'", sanitize(s.UTC().Format(time.DateTime))) +} + func (table *table) RowBuffer() *bytes.Buffer { var b bytes.Buffer b.WriteString("(") @@ -454,9 +451,51 @@ func (table *table) RowBuffer() *bytes.Buffer { switch s := value.(type) { case nil: b.WriteString(nullType) + case *string: + writeString(&b, *s) case *sql.NullString: if s.Valid { - fmt.Fprintf(&b, "'%s'", sanitize(s.String)) + writeString(&b, s.String) + } else { + b.WriteString(nullType) + } + case *bool: + writeBool(&b, *s) + case *sql.NullBool: + if s.Valid { + writeBool(&b, s.Bool) + } else { + b.WriteString(nullType) + } + case *uint: + fmt.Fprintf(&b, "%d", *s) + case *uint8: + fmt.Fprintf(&b, "%d", *s) + case *uint16: + fmt.Fprintf(&b, "%d", *s) + case *uint32: + fmt.Fprintf(&b, "%d", *s) + case *uint64: + fmt.Fprintf(&b, "%d", *s) + case *int: + fmt.Fprintf(&b, "%d", *s) + case *int8: + fmt.Fprintf(&b, "%d", *s) + case *int16: + fmt.Fprintf(&b, "%d", *s) + case *int32: + fmt.Fprintf(&b, "%d", *s) + case *int64: + fmt.Fprintf(&b, "%d", *s) + case *sql.NullInt16: + if s.Valid { + fmt.Fprintf(&b, "%d", s.Int16) + } else { + b.WriteString(nullType) + } + case *sql.NullInt32: + if s.Valid { + fmt.Fprintf(&b, "%d", s.Int32) } else { b.WriteString(nullType) } @@ -466,17 +505,27 @@ func (table *table) RowBuffer() *bytes.Buffer { } else { b.WriteString(nullType) } + case *float32: + fmt.Fprintf(&b, "%f", *s) + case *float64: + fmt.Fprintf(&b, "%f", *s) case *sql.NullFloat64: if s.Valid { fmt.Fprintf(&b, "%f", s.Float64) } else { b.WriteString(nullType) } + case *[]byte: + writeBinary(&b, *s) case *sql.RawBytes: - if len(*s) == 0 { - b.WriteString(nullType) + writeBinary(&b, *s) + case *time.Time: + writeTime(&b, *s) + case *sql.NullTime: + if s.Valid { + writeTime(&b, s.Time) } else { - fmt.Fprintf(&b, "_binary '%s'", sanitize(string(*s))) + b.WriteString(nullType) } default: fmt.Fprintf(&b, "'%s'", value) diff --git a/dump_test.go b/dump_test.go index 486901b..b9abb5b 100644 --- a/dump_test.go +++ b/dump_test.go @@ -228,7 +228,7 @@ func TestCreateTableAllValuesWithNil(t *testing.T) { AddRow("email", ""). AddRow("name", "") - rows := sqlmock.NewRowsWithColumnDefinition(c("id", 0), c("email", ""), c("name", "")). + rows := sqlmock.NewRowsWithColumnDefinition(c("id", 0), c("email", sql.NullString{}), c("name", "")). AddRow(1, nil, "Test Name 1"). AddRow(2, "test2@test.de", "Test Name 2"). AddRow(3, "", "Test Name 3") @@ -266,7 +266,7 @@ func TestCreateTableOk(t *testing.T) { AddRow("email", ""). AddRow("name", "") - createTableValueRows := sqlmock.NewRowsWithColumnDefinition(c("id", 0), c("email", ""), c("name", "")). + createTableValueRows := sqlmock.NewRowsWithColumnDefinition(c("id", 0), c("email", sql.NullString{}), c("name", "")). AddRow(1, nil, "Test Name 1"). AddRow(2, "test2@test.de", "Test Name 2") @@ -294,7 +294,7 @@ func TestCreateTableOk(t *testing.T) { DROP TABLE IF EXISTS ~Test_Table~; /*!40101 SET @saved_cs_client = @@character_set_client */; - SET character_set_client = utf8mb4 ; +/*!50503 SET character_set_client = utf8mb4 */; CREATE TABLE 'Test_Table' (~id~ int(11) NOT NULL AUTO_INCREMENT,~s~ char(60) DEFAULT NULL, PRIMARY KEY (~id~))ENGINE=InnoDB DEFAULT CHARSET=latin1; /*!40101 SET character_set_client = @saved_cs_client */; @@ -325,7 +325,7 @@ func TestCreateTableOkSmallPackets(t *testing.T) { AddRow("email", ""). AddRow("name", "") - createTableValueRows := sqlmock.NewRowsWithColumnDefinition(c("id", 0), c("email", ""), c("name", "")). + createTableValueRows := sqlmock.NewRowsWithColumnDefinition(c("id", 0), c("email", sql.NullString{}), c("name", "")). AddRow(1, nil, "Test Name 1"). AddRow(2, "test2@test.de", "Test Name 2") @@ -353,7 +353,7 @@ func TestCreateTableOkSmallPackets(t *testing.T) { DROP TABLE IF EXISTS ~Test_Table~; /*!40101 SET @saved_cs_client = @@character_set_client */; - SET character_set_client = utf8mb4 ; +/*!50503 SET character_set_client = utf8mb4 */; CREATE TABLE 'Test_Table' (~id~ int(11) NOT NULL AUTO_INCREMENT,~s~ char(60) DEFAULT NULL, PRIMARY KEY (~id~))ENGINE=InnoDB DEFAULT CHARSET=latin1; /*!40101 SET character_set_client = @saved_cs_client */; diff --git a/mysqldump.go b/mysqldump.go index 258da8c..6e9f58a 100644 --- a/mysqldump.go +++ b/mysqldump.go @@ -18,7 +18,7 @@ Register a new dumper. */ func Register(db *sql.DB, dir, format string) (*Data, error) { if !isDir(dir) { - return nil, errors.New("Invalid directory") + return nil, errors.New("invalid directory") } name := time.Now().Format(format) diff --git a/mysqldump_test.go b/mysqldump_test.go index 1250707..ad62f59 100644 --- a/mysqldump_test.go +++ b/mysqldump_test.go @@ -2,14 +2,17 @@ package mysqldump_test import ( "bytes" - "io/ioutil" - "reflect" + "database/sql" + "fmt" + "io" "strings" "testing" + "time" sqlmock "github.com/DATA-DOG/go-sqlmock" - "github.com/jamf/go-mysqldump" "github.com/stretchr/testify/assert" + + "github.com/jamf/go-mysqldump" ) const expected = `-- Go SQL Dump ` + mysqldump.Version + ` @@ -20,7 +23,7 @@ const expected = `-- Go SQL Dump ` + mysqldump.Version + ` /*!40101 SET @OLD_CHARACTER_SET_CLIENT=@@CHARACTER_SET_CLIENT */; /*!40101 SET @OLD_CHARACTER_SET_RESULTS=@@CHARACTER_SET_RESULTS */; /*!40101 SET @OLD_COLLATION_CONNECTION=@@COLLATION_CONNECTION */; - SET NAMES utf8mb4 ; +/*!50503 SET NAMES UTF8 */; /*!40103 SET @OLD_TIME_ZONE=@@TIME_ZONE */; /*!40103 SET TIME_ZONE='+00:00' */; /*!40014 SET @OLD_UNIQUE_CHECKS=@@UNIQUE_CHECKS, UNIQUE_CHECKS=0 */; @@ -34,8 +37,40 @@ const expected = `-- Go SQL Dump ` + mysqldump.Version + ` DROP TABLE IF EXISTS ~Test_Table~; /*!40101 SET @saved_cs_client = @@character_set_client */; - SET character_set_client = utf8mb4 ; -CREATE TABLE 'Test_Table' (~id~ int(11) NOT NULL AUTO_INCREMENT,~email~ char(60) DEFAULT NULL, ~name~ char(60), PRIMARY KEY (~id~))ENGINE=InnoDB DEFAULT CHARSET=latin1; +/*!50503 SET character_set_client = utf8mb4 */; +CREATE TABLE 'Test_Table' ( + ~id~ int(11) NOT NULL AUTO_INCREMENT, + ~email~ varchar(255) NOT NULL, + ~given_name~ varchar(127) NOT NULL DEFAULT '', + ~surname~ varchar(127) NOT NULL DEFAULT '', + ~name~ varchar(255) GENERATED ALWAYS AS (CONCAT(given_name,' ',surname)), + ~int8~ TINYINT NOT NULL, + ~NullInt8~ TINYINT, + ~uint8~ TINYINT UNSIGNED NOT NULL, + ~NullUint8~ TINYINT UNSIGNED, + ~int16~ SMALLINT NOT NULL, + ~NullInt16~ SMALLINT, + ~uint16~ SMALLINT UNSIGNED NOT NULL, + ~NullUint16~ SMALLINT UNSIGNED, + ~int32~ INT(11) NOT NULL, + ~NullInt32~ INT(11), + ~uint32~ INT(11) UNSIGNED NOT NULL, + ~NullUint32~ INT(11) UNSIGNED, + ~int64~ BIGINT NOT NULL, + ~NullInt64~ BIGINT, + ~uint64~ BIGINT UNSIGNED NOT NULL, + ~float32~ FLOAT NOT NULL, + ~NullFloat32~ FLOAT, + ~float64~ DOUBLE NOT NULL, + ~NullFloat64~ DOUBLE, + ~bool~ TINYINT(1) NOT NULL, + ~NullBool~ TINYINT(1), + ~time~ TIME NOT NULL, + ~NullTime~ TIME, + ~varbinary~ VARBINARY, + ~rawbytes~ BLOB, + PRIMARY KEY (~id~) +)ENGINE=InnoDB DEFAULT CHARSET=latin1; /*!40101 SET character_set_client = @saved_cs_client */; -- @@ -44,7 +79,7 @@ CREATE TABLE 'Test_Table' (~id~ int(11) NOT NULL AUTO_INCREMENT,~email~ char(60) LOCK TABLES ~Test_Table~ WRITE; /*!40000 ALTER TABLE ~Test_Table~ DISABLE KEYS */; -INSERT INTO ~Test_Table~ (~id~, ~email~, ~name~) VALUES (1,NULL,'Test Name 1'),(2,'test2@test.de','Test Name 2'); +INSERT INTO ~Test_Table~ (~id~, ~email~, ~given_name~, ~surname~, ~int8~, ~NullInt8~, ~uint8~, ~NullUint8~, ~int16~, ~NullInt16~, ~uint16~, ~NullUint16~, ~int32~, ~NullInt32~, ~uint32~, ~NullUint32~, ~int64~, ~NullInt64~, ~uint64~, ~float32~, ~NullFloat32~, ~float64~, ~NullFloat64~, ~bool~, ~NullBool~, ~time~, ~NullTime~, ~varbinary~, ~rawbytes~) VALUES (1,'test1@test.de','Test','Name 1',1,NULL,1,NULL,1,NULL,1,NULL,1,NULL,1,NULL,1,NULL,1,1.000000,NULL,1.000000,NULL,1,NULL,'1970-01-01 00:00:00',NULL,NULL,NULL),(2,'test2@test.de',NULL,'Test Name 2',2,NULL,2,NULL,2,NULL,2,NULL,2,NULL,2,NULL,2,NULL,2,2.000000,NULL,2.000000,NULL,1,NULL,'1970-01-01 00:00:00',NULL,NULL,NULL); /*!40000 ALTER TABLE ~Test_Table~ ENABLE KEYS */; UNLOCK TABLES; /*!40103 SET TIME_ZONE=@OLD_TIME_ZONE */; @@ -60,31 +95,110 @@ UNLOCK TABLES; ` func mockColumnRows() *sqlmock.Rows { - var enum struct{} col1 := sqlmock.NewColumn("Field").OfType("VARCHAR", "").Nullable(true) col2 := sqlmock.NewColumn("Type").OfType("TEXT", "").Nullable(true) col3 := sqlmock.NewColumn("Null").OfType("VARCHAR", "").Nullable(true) - col4 := sqlmock.NewColumn("Key").OfType("ENUM", &enum).Nullable(true) + col4 := sqlmock.NewColumn("Key").OfType("ENUM", "").Nullable(true) col5 := sqlmock.NewColumn("Default").OfType("TEXT", "").Nullable(true) col6 := sqlmock.NewColumn("Extra").OfType("VARCHAR", "").Nullable(true) return sqlmock.NewRowsWithColumnDefinition(col1, col2, col3, col4, col5, col6). - AddRow("id", "int(11)", false, nil, 0, ""). - AddRow("email", "varchar(255)", true, nil, nil, ""). - AddRow("name", "varchar(255)", true, nil, nil, ""). - AddRow("hash", "varchar(255)", true, nil, nil, "VIRTUAL GENERATED") + AddRow("id", "int(11)", "NO", "PRI", nil, "auto_increment"). + AddRow("email", "varchar(255)", "NO", "", nil, ""). + AddRow("given_name", "varchar(127)", "NO", "", "", ""). + AddRow("surname", "varchar(127)", "NO", "", "", ""). + AddRow("name", "varchar(255)", "YES", "", nil, "VIRTUAL GENERATED"). + AddRow("int8", "TINYINT", "NO", "", nil, ""). + AddRow("NullInt8", "TINYINT", "YES", "", nil, ""). + AddRow("uint8", "TINYINT UNSIGNED", "NO", "", nil, ""). + AddRow("NullUint8", "TINYINT UNSIGNED", "YES", "", nil, ""). + AddRow("int16", "SMALLINT", "NO", "", nil, ""). + AddRow("NullInt16", "SMALLINT", "YES", "", nil, ""). + AddRow("uint16", "SMALLINT UNSIGNED", "NO", "", nil, ""). + AddRow("NullUint16", "SMALLINT UNSIGNED", "YES", "", nil, ""). + AddRow("int32", "INT(11)", "NO", "", nil, ""). + AddRow("NullInt32", "INT(11)", "YES", "", nil, ""). + AddRow("uint32", "INT(11) UNSIGNED", "NO", "", nil, ""). + AddRow("NullUint32", "INT(11) UNSIGNED", "YES", "", nil, ""). + AddRow("int64", "BIGINT", "NO", "", nil, ""). + AddRow("NullInt64", "BIGINT", "YES", "", nil, ""). + AddRow("uint64", "BIGINT UNSIGNED", "NO", "", nil, ""). + AddRow("float32", "FLOAT", "NO", "", nil, ""). + AddRow("NullFloat32", "FLOAT", "YES", "", nil, ""). + AddRow("float64", "DOUBLE", "NO", "", nil, ""). + AddRow("NullFloat64", "DOUBLE", "YES", "", nil, ""). + AddRow("bool", "BOOL", "NO", "", nil, ""). + AddRow("NullBool", "BOOL", "YES", "", nil, ""). + AddRow("time", "TIME", "NO", "", nil, ""). + AddRow("NullTime", "TIME", "YES", "", nil, ""). + AddRow("varbinary", "VARBINARY", "YES", "", nil, ""). + AddRow("rawbytes", "BLOB", "YES", "", nil, "") } func c(name string, v interface{}) *sqlmock.Column { var t string - switch reflect.ValueOf(v).Kind() { - case reflect.String: + var nullable bool + switch v.(type) { + case string: t = "VARCHAR" - case reflect.Int: - t = "INT" - case reflect.Bool: + case sql.NullString: + nullable = true + t = "VARCHAR" + case int8: + t = "TINYINT" + case int16: + t = "SMALLINT" + case sql.NullInt16: + nullable = true + t = "SMALLINT" + case int32: + t = "INT(11)" + case sql.NullInt32: + nullable = true + t = "INT(11)" + case int64: + t = "BIGINT" + case sql.NullInt64: + nullable = true + t = "BIGINT" + case int: + t = "BIGINT" + case uint8: + t = "TINYINT UNSIGNED" + case uint16: + t = "SMALLINT UNSIGNED" + case uint32: + t = "INT UNSIGNED" + case uint64: + t = "BIGINT UNSIGNED" + case uint: + t = "BIGINT UNSIGNED" + case float32: + t = "FLOAT" + case float64: + t = "DOUBLE" + case sql.NullFloat64: + nullable = true + t = "DOUBLE" + case bool: + t = "BOOL" + case sql.NullBool: + nullable = true t = "BOOL" + case time.Time: + t = "TIME" + case sql.NullTime: + nullable = true + t = "TIME" + case []byte: + nullable = true + t = "VARBINARY" + case sql.RawBytes: + nullable = true + t = "BLOB" + default: + panic(fmt.Errorf("unknown value type: %T", v)) } - return sqlmock.NewColumn(name).OfType(t, v).Nullable(true) + return sqlmock.NewColumn(name).OfType(t, v).Nullable(nullable) } func RunDump(t testing.TB, data *mysqldump.Data) { @@ -102,11 +216,133 @@ func RunDump(t testing.TB, data *mysqldump.Data) { AddRow("test_version") createTableRows := sqlmock.NewRowsWithColumnDefinition(c("Table", ""), c("Create Table", "")). - AddRow("Test_Table", "CREATE TABLE 'Test_Table' (`id` int(11) NOT NULL AUTO_INCREMENT,`email` char(60) DEFAULT NULL, `name` char(60), PRIMARY KEY (`id`))ENGINE=InnoDB DEFAULT CHARSET=latin1") - - createTableValueRows := sqlmock.NewRowsWithColumnDefinition(c("id", 0), c("email", ""), c("name", "")). - AddRow(1, nil, "Test Name 1"). - AddRow(2, "test2@test.de", "Test Name 2") + AddRow("Test_Table", strings.ReplaceAll(`CREATE TABLE 'Test_Table' ( + ~id~ int(11) NOT NULL AUTO_INCREMENT, + ~email~ varchar(255) NOT NULL, + ~given_name~ varchar(127) NOT NULL DEFAULT '', + ~surname~ varchar(127) NOT NULL DEFAULT '', + ~name~ varchar(255) GENERATED ALWAYS AS (CONCAT(given_name,' ',surname)), + ~int8~ TINYINT NOT NULL, + ~NullInt8~ TINYINT, + ~uint8~ TINYINT UNSIGNED NOT NULL, + ~NullUint8~ TINYINT UNSIGNED, + ~int16~ SMALLINT NOT NULL, + ~NullInt16~ SMALLINT, + ~uint16~ SMALLINT UNSIGNED NOT NULL, + ~NullUint16~ SMALLINT UNSIGNED, + ~int32~ INT(11) NOT NULL, + ~NullInt32~ INT(11), + ~uint32~ INT(11) UNSIGNED NOT NULL, + ~NullUint32~ INT(11) UNSIGNED, + ~int64~ BIGINT NOT NULL, + ~NullInt64~ BIGINT, + ~uint64~ BIGINT UNSIGNED NOT NULL, + ~float32~ FLOAT NOT NULL, + ~NullFloat32~ FLOAT, + ~float64~ DOUBLE NOT NULL, + ~NullFloat64~ DOUBLE, + ~bool~ TINYINT(1) NOT NULL, + ~NullBool~ TINYINT(1), + ~time~ TIME NOT NULL, + ~NullTime~ TIME, + ~varbinary~ VARBINARY, + ~rawbytes~ BLOB, + PRIMARY KEY (~id~) +)ENGINE=InnoDB DEFAULT CHARSET=latin1`, "~", "`")) + + createTableValueRows := sqlmock.NewRowsWithColumnDefinition( + c("id", int32(0)), + c("email", ""), + c("given_name", sql.NullString{}), + c("surname", sql.NullString{}), + c("int8", int8(0)), + c("NullInt8", sql.NullInt16{}), + c("uint8", uint8(0)), + c("NullUint8", sql.NullInt16{}), + c("int16", int16(0)), + c("NullInt16", sql.NullInt16{}), + c("uint16", uint16(0)), + c("NullUint16", sql.NullInt32{}), + c("int32", int32(0)), + c("NullInt32", sql.NullInt32{}), + c("uint32", uint32(0)), + c("NullUint32", sql.NullInt64{}), + c("int64", int64(0)), + c("NullInt64", sql.NullInt64{}), + c("uint64", uint64(0)), + c("float32", float32(0)), + c("NullFloat32", sql.NullFloat64{}), + c("float64", float64(0)), + c("NullFloat64", sql.NullFloat64{}), + c("bool", false), + c("NullBool", sql.NullBool{}), + c("time", time.Time{}), + c("NullTime", sql.NullTime{}), + c("varbinary", []byte{}), + c("rawbytes", sql.RawBytes{}), + ). + AddRow( + int32(1), + "test1@test.de", + "Test", + "Name 1", + int8(1), + sql.NullInt16{}, + uint8(1), + sql.NullInt16{}, + int16(1), + sql.NullInt16{}, + uint16(1), + sql.NullInt32{}, + int32(1), + sql.NullInt32{}, + uint32(1), + sql.NullInt64{}, + int64(1), + sql.NullInt64{}, + uint64(1), + float32(1), + sql.NullFloat64{}, + float64(1), + sql.NullFloat64{}, + true, + sql.NullBool{}, + time.Unix(0, 0), + sql.NullTime{}, + []byte{}, + sql.RawBytes{}, + ). + AddRow( + int32(2), + "test2@test.de", + nil, + "Test Name 2", + int8(2), + sql.NullInt16{}, + uint8(2), + sql.NullInt16{}, + int16(2), + sql.NullInt16{}, + uint16(2), + sql.NullInt32{}, + int32(2), + sql.NullInt32{}, + uint32(2), + sql.NullInt64{}, + int64(2), + sql.NullInt64{}, + uint64(2), + float32(2), + sql.NullFloat64{}, + float64(2), + sql.NullFloat64{}, + true, + sql.NullBool{}, + time.Unix(0, 0), + sql.NullTime{}, + []byte{}, + sql.RawBytes{}, + ) mock.ExpectBegin() mock.ExpectQuery(`^SELECT version\(\)$`).WillReturnRows(serverVersionRows) @@ -115,9 +351,12 @@ func RunDump(t testing.TB, data *mysqldump.Data) { mock.ExpectQuery("^SHOW CREATE TABLE `Test_Table`$").WillReturnRows(createTableRows) mock.ExpectQuery("^SHOW COLUMNS FROM `Test_Table`$").WillReturnRows(showColumnsRows) mock.ExpectQuery("^SELECT (.+) FROM `Test_Table`$").WillReturnRows(createTableValueRows) + mock.ExpectExec("UNLOCK TABLES") mock.ExpectRollback() assert.NoError(t, data.Dump(), "an error was not expected when dumping a stub database connection") + + assert.NoError(t, mock.ExpectationsWereMet(), "there were unfulfilled expections") } func TestDumpOk(t *testing.T) { @@ -155,11 +394,133 @@ func TestNoLockOk(t *testing.T) { AddRow("test_version") createTableRows := sqlmock.NewRowsWithColumnDefinition(c("Table", ""), c("Create Table", "")). - AddRow("Test_Table", "CREATE TABLE 'Test_Table' (`id` int(11) NOT NULL AUTO_INCREMENT,`email` char(60) DEFAULT NULL, `name` char(60), PRIMARY KEY (`id`))ENGINE=InnoDB DEFAULT CHARSET=latin1") - - createTableValueRows := sqlmock.NewRowsWithColumnDefinition(c("id", 0), c("email", ""), c("name", "")). - AddRow(1, nil, "Test Name 1"). - AddRow(2, "test2@test.de", "Test Name 2") + AddRow("Test_Table", strings.ReplaceAll(`CREATE TABLE 'Test_Table' ( + ~id~ int(11) NOT NULL AUTO_INCREMENT, + ~email~ varchar(255) NOT NULL, + ~given_name~ varchar(127) NOT NULL DEFAULT '', + ~surname~ varchar(127) NOT NULL DEFAULT '', + ~name~ varchar(255) GENERATED ALWAYS AS (CONCAT(given_name,' ',surname)), + ~int8~ TINYINT NOT NULL, + ~NullInt8~ TINYINT, + ~uint8~ TINYINT UNSIGNED NOT NULL, + ~NullUint8~ TINYINT UNSIGNED, + ~int16~ SMALLINT NOT NULL, + ~NullInt16~ SMALLINT, + ~uint16~ SMALLINT UNSIGNED NOT NULL, + ~NullUint16~ SMALLINT UNSIGNED, + ~int32~ INT(11) NOT NULL, + ~NullInt32~ INT(11), + ~uint32~ INT(11) UNSIGNED NOT NULL, + ~NullUint32~ INT(11) UNSIGNED, + ~int64~ BIGINT NOT NULL, + ~NullInt64~ BIGINT, + ~uint64~ BIGINT UNSIGNED NOT NULL, + ~float32~ FLOAT NOT NULL, + ~NullFloat32~ FLOAT, + ~float64~ DOUBLE NOT NULL, + ~NullFloat64~ DOUBLE, + ~bool~ TINYINT(1) NOT NULL, + ~NullBool~ TINYINT(1), + ~time~ TIME NOT NULL, + ~NullTime~ TIME, + ~varbinary~ VARBINARY, + ~rawbytes~ BLOB, + PRIMARY KEY (~id~) +)ENGINE=InnoDB DEFAULT CHARSET=latin1`, "~", "`")) + + createTableValueRows := sqlmock.NewRowsWithColumnDefinition( + c("id", int32(0)), + c("email", ""), + c("given_name", sql.NullString{}), + c("surname", sql.NullString{}), + c("int8", int8(0)), + c("NullInt8", sql.NullInt16{}), + c("uint8", uint8(0)), + c("NullUint8", sql.NullInt16{}), + c("int16", int16(0)), + c("NullInt16", sql.NullInt16{}), + c("uint16", uint16(0)), + c("NullUint16", sql.NullInt32{}), + c("int32", int32(0)), + c("NullInt32", sql.NullInt32{}), + c("uint32", uint32(0)), + c("NullUint32", sql.NullInt64{}), + c("int64", int64(0)), + c("NullInt64", sql.NullInt64{}), + c("uint64", uint64(0)), + c("float32", float32(0)), + c("NullFloat32", sql.NullFloat64{}), + c("float64", float64(0)), + c("NullFloat64", sql.NullFloat64{}), + c("bool", false), + c("NullBool", sql.NullBool{}), + c("time", time.Time{}), + c("NullTime", sql.NullTime{}), + c("varbinary", []byte{}), + c("rawbytes", sql.RawBytes{}), + ). + AddRow( + int32(1), + "test1@test.de", + "Test", + "Name 1", + int8(1), + sql.NullInt16{}, + uint8(1), + sql.NullInt16{}, + int16(1), + sql.NullInt16{}, + uint16(1), + sql.NullInt32{}, + int32(1), + sql.NullInt32{}, + uint32(1), + sql.NullInt64{}, + int64(1), + sql.NullInt64{}, + uint64(1), + float32(1), + sql.NullFloat64{}, + float64(1), + sql.NullFloat64{}, + true, + sql.NullBool{}, + time.Unix(0, 0), + sql.NullTime{}, + []byte{}, + sql.RawBytes{}, + ). + AddRow( + int32(2), + "test2@test.de", + nil, + "Test Name 2", + int8(2), + sql.NullInt16{}, + uint8(2), + sql.NullInt16{}, + int16(2), + sql.NullInt16{}, + uint16(2), + sql.NullInt32{}, + int32(2), + sql.NullInt32{}, + uint32(2), + sql.NullInt64{}, + int64(2), + sql.NullInt64{}, + uint64(2), + float32(2), + sql.NullFloat64{}, + float64(2), + sql.NullFloat64{}, + true, + sql.NullBool{}, + time.Unix(0, 0), + sql.NullTime{}, + []byte{}, + sql.RawBytes{}, + ) mock.ExpectBegin() mock.ExpectQuery(`^SELECT version\(\)$`).WillReturnRows(serverVersionRows) @@ -174,11 +535,13 @@ func TestNoLockOk(t *testing.T) { result := strings.Replace(strings.Split(buf.String(), "-- Dump completed")[0], "`", "~", -1) assert.Equal(t, expected, result) + + assert.NoError(t, mock.ExpectationsWereMet(), "there were unfulfilled expections") } func BenchmarkDump(b *testing.B) { data := &mysqldump.Data{ - Out: ioutil.Discard, + Out: io.Discard, LockTables: true, } for i := 0; i < b.N; i++ { From 2301e09df4db963cafa45f5cd68fc5d0a422afcf Mon Sep 17 00:00:00 2001 From: Michael Chinn Date: Tue, 29 Aug 2023 20:10:39 -0400 Subject: [PATCH 2/2] workaround github.com/go-sql-driver/mysql/pull/1424 not released --- dump.go | 19 ++++++++++++++++++- mysqldump_test.go | 20 ++++++++++---------- 2 files changed, 28 insertions(+), 11 deletions(-) diff --git a/dump.go b/dump.go index 5109f50..c0e1122 100644 --- a/dump.go +++ b/dump.go @@ -383,11 +383,28 @@ func (table *table) Init() error { table.values = make([]interface{}, len(tt)) for i, tp := range tt { - table.values[i] = reflect.New(tp.ScanType()).Interface() + table.values[i] = reflect.New(reflectColumnType(tp)).Interface() } return nil } +func reflectColumnType(tp *sql.ColumnType) reflect.Type { + // workaround https://github.com/go-sql-driver/mysql/pull/1424 till it's released + nullable, _ := tp.Nullable() + switch tp.DatabaseTypeName() { + case "TINYBLOB", "MEDIUMBLOB", "LONGBLOB", "BLOB", + "VARBINARY", "BINARY", "BIT", "GEOMETRY": + return reflect.TypeOf([]byte{}) + case "TINYTEXT", "MEDIUMTEXT", "LONGTEXT", "TEXT", + "VARCHAR", "CHAR", "DECIMAL", "ENUM", "SET", "JSON", "TIME": + if nullable { + return reflect.TypeOf(sql.NullString{}) + } + return reflect.TypeOf("") + } + return tp.ScanType() +} + func (table *table) Next() bool { if table.rows == nil { if err := table.Init(); err != nil { diff --git a/mysqldump_test.go b/mysqldump_test.go index ad62f59..4d6fce7 100644 --- a/mysqldump_test.go +++ b/mysqldump_test.go @@ -65,8 +65,8 @@ CREATE TABLE 'Test_Table' ( ~NullFloat64~ DOUBLE, ~bool~ TINYINT(1) NOT NULL, ~NullBool~ TINYINT(1), - ~time~ TIME NOT NULL, - ~NullTime~ TIME, + ~time~ DATETIME NOT NULL, + ~NullTime~ DATETIME, ~varbinary~ VARBINARY, ~rawbytes~ BLOB, PRIMARY KEY (~id~) @@ -128,8 +128,8 @@ func mockColumnRows() *sqlmock.Rows { AddRow("NullFloat64", "DOUBLE", "YES", "", nil, ""). AddRow("bool", "BOOL", "NO", "", nil, ""). AddRow("NullBool", "BOOL", "YES", "", nil, ""). - AddRow("time", "TIME", "NO", "", nil, ""). - AddRow("NullTime", "TIME", "YES", "", nil, ""). + AddRow("time", "DATETIME", "NO", "", nil, ""). + AddRow("NullTime", "DATETIME", "YES", "", nil, ""). AddRow("varbinary", "VARBINARY", "YES", "", nil, ""). AddRow("rawbytes", "BLOB", "YES", "", nil, "") } @@ -185,10 +185,10 @@ func c(name string, v interface{}) *sqlmock.Column { nullable = true t = "BOOL" case time.Time: - t = "TIME" + t = "DATETIME" case sql.NullTime: nullable = true - t = "TIME" + t = "DATETIME" case []byte: nullable = true t = "VARBINARY" @@ -243,8 +243,8 @@ func RunDump(t testing.TB, data *mysqldump.Data) { ~NullFloat64~ DOUBLE, ~bool~ TINYINT(1) NOT NULL, ~NullBool~ TINYINT(1), - ~time~ TIME NOT NULL, - ~NullTime~ TIME, + ~time~ DATETIME NOT NULL, + ~NullTime~ DATETIME, ~varbinary~ VARBINARY, ~rawbytes~ BLOB, PRIMARY KEY (~id~) @@ -421,8 +421,8 @@ func TestNoLockOk(t *testing.T) { ~NullFloat64~ DOUBLE, ~bool~ TINYINT(1) NOT NULL, ~NullBool~ TINYINT(1), - ~time~ TIME NOT NULL, - ~NullTime~ TIME, + ~time~ DATETIME NOT NULL, + ~NullTime~ DATETIME, ~varbinary~ VARBINARY, ~rawbytes~ BLOB, PRIMARY KEY (~id~)