diff --git a/go.mod b/go.mod index 14bef5708..8eca24582 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,7 @@ require ( github.com/stretchr/testify v1.9.0 github.com/testcontainers/testcontainers-go v0.34.0 golang.org/x/net v0.24.0 + golang.org/x/sync v0.8.0 golang.org/x/term v0.19.0 golang.org/x/text v0.14.0 ) diff --git a/go.sum b/go.sum index f447925ba..c800ed549 100644 --- a/go.sum +++ b/go.sum @@ -199,6 +199,8 @@ golang.org/x/net v0.24.0/go.mod h1:2Q7sJY5mzlzWjKtYUEXSlBWCdyaioyXzRB2RtU8KVE8= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= +golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= diff --git a/go/base/context.go b/go/base/context.go index 59227ea2d..2518ecf4e 100644 --- a/go/base/context.go +++ b/go/base/context.go @@ -103,6 +103,11 @@ type MigrationContext struct { AzureMySQL bool AttemptInstantDDL bool + // SkipPortValidation allows skipping the port validation in `ValidateConnection` + // This is useful when connecting to a MySQL instance where the external port + // may not match the internal port. + SkipPortValidation bool + config ContextConfig configMutex *sync.Mutex ConfigFile string diff --git a/go/base/utils.go b/go/base/utils.go index 725bb2279..89f6d315f 100644 --- a/go/base/utils.go +++ b/go/base/utils.go @@ -63,18 +63,27 @@ func StringContainsAll(s string, substrings ...string) bool { func ValidateConnection(db *gosql.DB, connectionConfig *mysql.ConnectionConfig, migrationContext *MigrationContext, name string) (string, error) { versionQuery := `select @@global.version` - var port, extraPort int + var version string if err := db.QueryRow(versionQuery).Scan(&version); err != nil { return "", err } + + if migrationContext.SkipPortValidation { + return version, nil + } + + var extraPort int + extraPortQuery := `select @@global.extra_port` if err := db.QueryRow(extraPortQuery).Scan(&extraPort); err != nil { //nolint:staticcheck // swallow this error. not all servers support extra_port } + // AliyunRDS set users port to "NULL", replace it by gh-ost param // GCP set users port to "NULL", replace it by gh-ost param // Azure MySQL set users port to a different value by design, replace it by gh-ost para + var port int if migrationContext.AliyunRDS || migrationContext.GoogleCloudPlatform || migrationContext.AzureMySQL { port = connectionConfig.Key.Port } else { diff --git a/go/logic/applier_test.go b/go/logic/applier_test.go index 9888f3176..f53e65ffb 100644 --- a/go/logic/applier_test.go +++ b/go/logic/applier_test.go @@ -10,7 +10,6 @@ import ( gosql "database/sql" "strings" "testing" - "time" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" @@ -20,7 +19,6 @@ import ( "github.com/github/gh-ost/go/base" "github.com/github/gh-ost/go/binlog" - "github.com/github/gh-ost/go/mysql" "github.com/github/gh-ost/go/sql" ) @@ -186,6 +184,7 @@ func TestApplierBuildDMLEventQuery(t *testing.T) { func TestApplierInstantDDL(t *testing.T) { migrationContext := base.NewMigrationContext() migrationContext.DatabaseName = "test" + migrationContext.SkipPortValidation = true migrationContext.OriginalTableName = "mytable" migrationContext.AlterStatementOptions = "ADD INDEX (foo)" applier := NewApplier(migrationContext) @@ -200,38 +199,16 @@ type ApplierTestSuite struct { suite.Suite mysqlContainer testcontainers.Container -} - -func (suite *ApplierTestSuite) getConnectionConfig(ctx context.Context) (*mysql.ConnectionConfig, error) { - host, err := suite.mysqlContainer.ContainerIP(ctx) - if err != nil { - return nil, err - } - - config := mysql.NewConnectionConfig() - config.Key.Hostname = host - config.Key.Port = 3306 - config.User = "root" - config.Password = "root-password" - - return config, nil -} - -func (suite *ApplierTestSuite) getDb(ctx context.Context) (*gosql.DB, error) { - host, err := suite.mysqlContainer.ContainerIP(ctx) - if err != nil { - return nil, err - } - - return gosql.Open("mysql", "root:root-password@tcp("+host+":3306)/test") + db *gosql.DB } func (suite *ApplierTestSuite) SetupSuite() { ctx := context.Background() req := testcontainers.ContainerRequest{ - Image: "mysql:8.0", - Env: map[string]string{"MYSQL_ROOT_PASSWORD": "root-password"}, - WaitingFor: wait.ForLog("port: 3306 MySQL Community Server - GPL"), + Image: "mysql:8.0.40", + Env: map[string]string{"MYSQL_ROOT_PASSWORD": "root-password"}, + ExposedPorts: []string{"3306/tcp"}, + WaitingFor: wait.ForListeningPort("3306/tcp"), } mysqlContainer, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{ @@ -241,43 +218,52 @@ func (suite *ApplierTestSuite) SetupSuite() { suite.Require().NoError(err) suite.mysqlContainer = mysqlContainer + + dsn, err := GetDSN(ctx, mysqlContainer) + suite.Require().NoError(err) + + db, err := gosql.Open("mysql", dsn) + suite.Require().NoError(err) + + suite.db = db } func (suite *ApplierTestSuite) TeardownSuite() { ctx := context.Background() - suite.Require().NoError(suite.mysqlContainer.Terminate(ctx)) + suite.Assert().NoError(suite.db.Close()) + suite.Assert().NoError(suite.mysqlContainer.Terminate(ctx)) } func (suite *ApplierTestSuite) SetupTest() { ctx := context.Background() - rc, _, err := suite.mysqlContainer.Exec(ctx, []string{"mysql", "-uroot", "-proot-password", "-e", "CREATE DATABASE test;"}) - suite.Require().NoError(err) - suite.Require().Equalf(0, rc, "failed to created database: expected exit code 0, got %d", rc) - - rc, _, err = suite.mysqlContainer.Exec(ctx, []string{"mysql", "-uroot", "-proot-password", "-e", "CREATE TABLE test.testing (id INT, item_id INT, PRIMARY KEY (id));"}) + _, err := suite.db.ExecContext(ctx, "CREATE DATABASE test") suite.Require().NoError(err) - suite.Require().Equalf(0, rc, "failed to created table: expected exit code 0, got %d", rc) } func (suite *ApplierTestSuite) TearDownTest() { ctx := context.Background() - rc, _, err := suite.mysqlContainer.Exec(ctx, []string{"mysql", "-uroot", "-proot-password", "-e", "DROP DATABASE test;"}) + _, err := suite.db.ExecContext(ctx, "DROP DATABASE test") suite.Require().NoError(err) - suite.Require().Equalf(0, rc, "failed to created database: expected exit code 0, got %d", rc) } func (suite *ApplierTestSuite) TestInitDBConnections() { ctx := context.Background() - connectionConfig, err := suite.getConnectionConfig(ctx) + var err error + + _, err = suite.db.ExecContext(ctx, "CREATE TABLE test.testing (id INT, item_id INT);") + suite.Require().NoError(err) + + connectionConfig, err := GetConnectionConfig(ctx, suite.mysqlContainer) suite.Require().NoError(err) migrationContext := base.NewMigrationContext() migrationContext.ApplierConnectionConfig = connectionConfig migrationContext.DatabaseName = "test" + migrationContext.SkipPortValidation = true migrationContext.OriginalTableName = "testing" migrationContext.SetConnectionConfig("innodb") @@ -297,12 +283,21 @@ func (suite *ApplierTestSuite) TestInitDBConnections() { func (suite *ApplierTestSuite) TestApplyDMLEventQueries() { ctx := context.Background() - connectionConfig, err := suite.getConnectionConfig(ctx) + var err error + + _, err = suite.db.ExecContext(ctx, "CREATE TABLE test.testing (id INT, item_id INT);") + suite.Require().NoError(err) + + _, err = suite.db.ExecContext(ctx, "CREATE TABLE test._testing_gho (id INT, item_id INT);") + suite.Require().NoError(err) + + connectionConfig, err := GetConnectionConfig(ctx, suite.mysqlContainer) suite.Require().NoError(err) migrationContext := base.NewMigrationContext() migrationContext.ApplierConnectionConfig = connectionConfig migrationContext.DatabaseName = "test" + migrationContext.SkipPortValidation = true migrationContext.OriginalTableName = "testing" migrationContext.SetConnectionConfig("innodb") @@ -321,10 +316,6 @@ func (suite *ApplierTestSuite) TestApplyDMLEventQueries() { err = applier.InitDBConnections() suite.Require().NoError(err) - rc, _, err := suite.mysqlContainer.Exec(ctx, []string{"mysql", "-uroot", "-proot-password", "-e", "CREATE TABLE test._testing_gho (id INT, item_id INT);"}) - suite.Require().NoError(err) - suite.Require().Equalf(0, rc, "failed to created table: expected exit code 0, got %d", rc) - dmlEvents := []*binlog.BinlogDMLEvent{ { DatabaseName: "test", @@ -337,11 +328,7 @@ func (suite *ApplierTestSuite) TestApplyDMLEventQueries() { suite.Require().NoError(err) // Check that the row was inserted - db, err := suite.getDb(ctx) - suite.Require().NoError(err) - defer db.Close() - - rows, err := db.Query("SELECT * FROM test._testing_gho") + rows, err := suite.db.Query("SELECT * FROM test._testing_gho") suite.Require().NoError(err) defer rows.Close() @@ -364,12 +351,18 @@ func (suite *ApplierTestSuite) TestApplyDMLEventQueries() { func (suite *ApplierTestSuite) TestValidateOrDropExistingTables() { ctx := context.Background() - connectionConfig, err := suite.getConnectionConfig(ctx) + var err error + + _, err = suite.db.ExecContext(ctx, "CREATE TABLE test.testing (id INT, item_id INT);") + suite.Require().NoError(err) + + connectionConfig, err := GetConnectionConfig(ctx, suite.mysqlContainer) suite.Require().NoError(err) migrationContext := base.NewMigrationContext() migrationContext.ApplierConnectionConfig = connectionConfig migrationContext.DatabaseName = "test" + migrationContext.SkipPortValidation = true migrationContext.OriginalTableName = "testing" migrationContext.SetConnectionConfig("innodb") @@ -387,41 +380,30 @@ func (suite *ApplierTestSuite) TestValidateOrDropExistingTables() { suite.Require().NoError(err) } -func (suite *ApplierTestSuite) TestApplyIterationInsertQuery() { +func (suite *ApplierTestSuite) TestValidateOrDropExistingTablesWithGhostTableExisting() { ctx := context.Background() - connectionConfig, err := suite.getConnectionConfig(ctx) + var err error + + _, err = suite.db.ExecContext(ctx, "CREATE TABLE test.testing (id INT, item_id INT);") + suite.Require().NoError(err) + + _, err = suite.db.ExecContext(ctx, "CREATE TABLE test._testing_gho (id INT, item_id INT);") + suite.Require().NoError(err) + + connectionConfig, err := GetConnectionConfig(ctx, suite.mysqlContainer) suite.Require().NoError(err) migrationContext := base.NewMigrationContext() migrationContext.ApplierConnectionConfig = connectionConfig migrationContext.DatabaseName = "test" + migrationContext.SkipPortValidation = true migrationContext.OriginalTableName = "testing" - migrationContext.ChunkSize = 10 migrationContext.SetConnectionConfig("innodb") - db, err := suite.getDb(ctx) - suite.Require().NoError(err) - defer db.Close() - - _, err = db.Exec("CREATE TABLE test._testing_gho (id INT, item_id INT, PRIMARY KEY (id))") - suite.Require().NoError(err) - - // Insert some test values - for i := 1; i <= 10; i++ { - _, err = db.Exec("INSERT INTO test.testing (id, item_id) VALUES (?, ?)", i, i) - suite.Require().NoError(err) - } - + migrationContext.OriginalTableColumns = sql.NewColumnList([]string{"id", "item_id"}) migrationContext.SharedColumns = sql.NewColumnList([]string{"id", "item_id"}) migrationContext.MappedSharedColumns = sql.NewColumnList([]string{"id", "item_id"}) - migrationContext.UniqueKey = &sql.UniqueKey{ - Name: "PRIMARY", - Columns: *sql.NewColumnList([]string{"id"}), - } - - migrationContext.MigrationIterationRangeMinValues = sql.ToColumnValues([]interface{}{1}) - migrationContext.MigrationIterationRangeMaxValues = sql.ToColumnValues([]interface{}{10}) applier := NewApplier(migrationContext) defer applier.Teardown() @@ -429,65 +411,74 @@ func (suite *ApplierTestSuite) TestApplyIterationInsertQuery() { err = applier.InitDBConnections() suite.Require().NoError(err) - chunkSize, rowsAffected, duration, err := applier.ApplyIterationInsertQuery() - suite.Require().NoError(err) - - suite.Require().Equal(migrationContext.ChunkSize, chunkSize) - suite.Require().Equal(int64(10), rowsAffected) - suite.Require().Greater(duration, time.Duration(0)) + err = applier.ValidateOrDropExistingTables() + suite.Require().Error(err) + suite.Require().EqualError(err, "Table `_testing_gho` already exists. Panicking. Use --initially-drop-ghost-table to force dropping it, though I really prefer that you drop it or rename it away") +} - // Check that the rows were inserted - rows, err := db.Query("SELECT * FROM test._testing_gho") - suite.Require().NoError(err) - defer rows.Close() +func (suite *ApplierTestSuite) TestValidateOrDropExistingTablesWithGhostTableExistingAndInitiallyDropGhostTableSet() { + ctx := context.Background() - var count, id, item_id int - for rows.Next() { - err = rows.Scan(&id, &item_id) - suite.Require().NoError(err) - count += 1 - } - suite.Require().NoError(rows.Err()) + var err error - suite.Require().Equal(10, count) -} + _, err = suite.db.ExecContext(ctx, "CREATE TABLE test.testing (id INT, item_id INT);") + suite.Require().NoError(err) -func (suite *ApplierTestSuite) TestApplyIterationInsertQueryFailsFastWhenSelectingLockedRows() { - ctx := context.Background() + _, err = suite.db.ExecContext(ctx, "CREATE TABLE test._testing_gho (id INT, item_id INT);") + suite.Require().NoError(err) - connectionConfig, err := suite.getConnectionConfig(ctx) + connectionConfig, err := GetConnectionConfig(ctx, suite.mysqlContainer) suite.Require().NoError(err) migrationContext := base.NewMigrationContext() migrationContext.ApplierConnectionConfig = connectionConfig migrationContext.DatabaseName = "test" + migrationContext.SkipPortValidation = true migrationContext.OriginalTableName = "testing" - migrationContext.ChunkSize = 10 - migrationContext.TableEngine = "innodb" migrationContext.SetConnectionConfig("innodb") - db, err := suite.getDb(ctx) + migrationContext.InitiallyDropGhostTable = true + + applier := NewApplier(migrationContext) + defer applier.Teardown() + + err = applier.InitDBConnections() suite.Require().NoError(err) - defer db.Close() - _, err = db.Exec("CREATE TABLE test._testing_gho (id INT, item_id INT, PRIMARY KEY (id))") + err = applier.ValidateOrDropExistingTables() suite.Require().NoError(err) - // Insert some test values - for i := 1; i <= 10; i++ { - _, err = db.Exec("INSERT INTO test.testing (id, item_id) VALUES (?, ?)", i, i) - suite.Require().NoError(err) - } + // Check that the ghost table was dropped + var tableName string + //nolint:execinquery + err = suite.db.QueryRow("SHOW TABLES IN test LIKE '_testing_gho'").Scan(&tableName) + suite.Require().Error(err) + suite.Require().Equal(gosql.ErrNoRows, err) +} + +func (suite *ApplierTestSuite) TestCreateGhostTable() { + ctx := context.Background() + + var err error + + _, err = suite.db.ExecContext(ctx, "CREATE TABLE test.testing (id INT, item_id INT);") + suite.Require().NoError(err) + + connectionConfig, err := GetConnectionConfig(ctx, suite.mysqlContainer) + suite.Require().NoError(err) + + migrationContext := base.NewMigrationContext() + migrationContext.ApplierConnectionConfig = connectionConfig + migrationContext.DatabaseName = "test" + migrationContext.SkipPortValidation = true + migrationContext.OriginalTableName = "testing" + migrationContext.SetConnectionConfig("innodb") + migrationContext.OriginalTableColumns = sql.NewColumnList([]string{"id", "item_id"}) migrationContext.SharedColumns = sql.NewColumnList([]string{"id", "item_id"}) migrationContext.MappedSharedColumns = sql.NewColumnList([]string{"id", "item_id"}) - migrationContext.UniqueKey = &sql.UniqueKey{ - Name: "PRIMARY", - Columns: *sql.NewColumnList([]string{"id"}), - } - migrationContext.MigrationIterationRangeMinValues = sql.ToColumnValues([]interface{}{1}) - migrationContext.MigrationIterationRangeMaxValues = sql.ToColumnValues([]interface{}{10}) + migrationContext.InitiallyDropGhostTable = true applier := NewApplier(migrationContext) defer applier.Teardown() @@ -495,30 +486,22 @@ func (suite *ApplierTestSuite) TestApplyIterationInsertQueryFailsFastWhenSelecti err = applier.InitDBConnections() suite.Require().NoError(err) - // Lock one of the rows - tx, err := db.Begin() + err = applier.CreateGhostTable() suite.Require().NoError(err) - defer func() { - suite.Require().NoError(tx.Rollback()) - }() - _, err = tx.Exec("SELECT * FROM test.testing WHERE id = 5 FOR UPDATE") + // Check that the ghost table was created + var tableName string + //nolint:execinquery + err = suite.db.QueryRow("SHOW TABLES IN test LIKE '_testing_gho'").Scan(&tableName) suite.Require().NoError(err) + suite.Require().Equal("_testing_gho", tableName) - chunkSize, rowsAffected, duration, err := applier.ApplyIterationInsertQuery() - suite.Require().Error(err) - suite.Require().EqualError(err, "Error 3572 (HY000): Statement aborted because lock(s) could not be acquired immediately and NOWAIT is set.") - - suite.Require().Equal(migrationContext.ChunkSize, chunkSize) - suite.Require().Equal(int64(0), rowsAffected) - suite.Require().Equal(time.Duration(0), duration) - - // Check that the no rows were inserted - var count int - err = db.QueryRow("SELECT COUNT(*) FROM test._testing_gho").Scan(&count) + // Check that the ghost table has the same columns as the original table + var createDDL string + //nolint:execinquery + err = suite.db.QueryRow("SHOW CREATE TABLE test._testing_gho").Scan(&tableName, &createDDL) suite.Require().NoError(err) - - suite.Require().Equal(0, count) + suite.Require().Equal("CREATE TABLE `_testing_gho` (\n `id` int DEFAULT NULL,\n `item_id` int DEFAULT NULL\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci", createDDL) } func TestApplier(t *testing.T) { diff --git a/go/logic/migrator_test.go b/go/logic/migrator_test.go index 9193de05d..a2a096e69 100644 --- a/go/logic/migrator_test.go +++ b/go/logic/migrator_test.go @@ -6,9 +6,12 @@ package logic import ( + "context" + gosql "database/sql" "errors" "os" "path/filepath" + "runtime" "strings" "sync" "sync/atomic" @@ -16,6 +19,9 @@ import ( "time" "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "github.com/testcontainers/testcontainers-go" + "github.com/testcontainers/testcontainers-go/wait" "github.com/github/gh-ost/go/base" "github.com/github/gh-ost/go/binlog" @@ -254,3 +260,115 @@ func TestMigratorShouldPrintStatus(t *testing.T) { require.False(t, migrator.shouldPrintStatus(HeuristicPrintStatusRule, 12345, 86400*time.Second)) // test 'else' require.True(t, migrator.shouldPrintStatus(HeuristicPrintStatusRule, 30030, 86400*time.Second)) // test 'else' again } + +type MigratorTestSuite struct { + suite.Suite + + mysqlContainer testcontainers.Container + db *gosql.DB +} + +func (suite *MigratorTestSuite) SetupSuite() { + ctx := context.Background() + req := testcontainers.ContainerRequest{ + Image: "mysql:8.0.40", + Env: map[string]string{"MYSQL_ROOT_PASSWORD": "root-password"}, + ExposedPorts: []string{"3306/tcp"}, + WaitingFor: wait.ForListeningPort("3306/tcp"), + } + + mysqlContainer, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{ + ContainerRequest: req, + Started: true, + }) + suite.Require().NoError(err) + + suite.mysqlContainer = mysqlContainer + + dsn, err := GetDSN(ctx, mysqlContainer) + suite.Require().NoError(err) + + db, err := gosql.Open("mysql", dsn) + suite.Require().NoError(err) + + suite.db = db +} + +func (suite *MigratorTestSuite) TeardownSuite() { + ctx := context.Background() + + suite.Assert().NoError(suite.db.Close()) + suite.Assert().NoError(suite.mysqlContainer.Terminate(ctx)) +} + +func (suite *MigratorTestSuite) SetupTest() { + ctx := context.Background() + + _, err := suite.db.ExecContext(ctx, "CREATE DATABASE test") + suite.Require().NoError(err) +} + +func (suite *MigratorTestSuite) TearDownTest() { + ctx := context.Background() + + _, err := suite.db.ExecContext(ctx, "DROP DATABASE test") + suite.Require().NoError(err) +} + +func (suite *MigratorTestSuite) TestFoo() { + ctx := context.Background() + + _, err := suite.db.ExecContext(ctx, "CREATE TABLE test.testing (id INT PRIMARY KEY, name VARCHAR(64))") + suite.Require().NoError(err) + + connectionConfig, err := GetConnectionConfig(ctx, suite.mysqlContainer) + suite.Require().NoError(err) + + migrationContext := base.NewMigrationContext() + migrationContext.AllowedRunningOnMaster = true + migrationContext.ApplierConnectionConfig = connectionConfig + migrationContext.InspectorConnectionConfig = connectionConfig + migrationContext.DatabaseName = "test" + migrationContext.SkipPortValidation = true + migrationContext.OriginalTableName = "testing" + migrationContext.SetConnectionConfig("innodb") + migrationContext.AlterStatementOptions = "ADD COLUMN foobar varchar(255), ENGINE=InnoDB" + migrationContext.ReplicaServerId = 99999 + migrationContext.HeartbeatIntervalMilliseconds = 100 + migrationContext.ThrottleHTTPIntervalMillis = 100 + migrationContext.ThrottleHTTPTimeoutMillis = 1000 + + //nolint:dogsled + _, filename, _, _ := runtime.Caller(0) + migrationContext.ServeSocketFile = filepath.Join(filepath.Dir(filename), "../../tmp/gh-ost.sock") + + migrator := NewMigrator(migrationContext, "0.0.0") + + err = migrator.Migrate() + suite.Require().NoError(err) + + // Verify the new column was added + var tableName, createTableSQL string + //nolint:execinquery + err = suite.db.QueryRow("SHOW CREATE TABLE test.testing").Scan(&tableName, &createTableSQL) + suite.Require().NoError(err) + + suite.Require().Equal("testing", tableName) + suite.Require().Equal("CREATE TABLE `testing` (\n `id` int NOT NULL,\n `name` varchar(64) DEFAULT NULL,\n `foobar` varchar(255) DEFAULT NULL,\n PRIMARY KEY (`id`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci", createTableSQL) + + // Verify the changelog table was claned up + //nolint:execinquery + err = suite.db.QueryRow("SHOW TABLES IN test LIKE '_testing_ghc'").Scan(&tableName) + suite.Require().Error(err) + suite.Require().Equal(gosql.ErrNoRows, err) + + // Verify the old table was renamed + //nolint:execinquery + err = suite.db.QueryRow("SHOW TABLES IN test LIKE '_testing_del'").Scan(&tableName) + suite.Require().NoError(err) + suite.Require().Equal("_testing_del", tableName) +} + +func TestMigrator(t *testing.T) { + suite.Run(t, new(MigratorTestSuite)) +} diff --git a/go/logic/streamer_test.go b/go/logic/streamer_test.go new file mode 100644 index 000000000..301074e3f --- /dev/null +++ b/go/logic/streamer_test.go @@ -0,0 +1,282 @@ +package logic + +import ( + "context" + "database/sql" + gosql "database/sql" + "fmt" + "testing" + "time" + + "github.com/github/gh-ost/go/base" + "github.com/github/gh-ost/go/binlog" + "github.com/stretchr/testify/suite" + "github.com/testcontainers/testcontainers-go" + "github.com/testcontainers/testcontainers-go/wait" + + "golang.org/x/sync/errgroup" +) + +type EventsStreamerTestSuite struct { + suite.Suite + + mysqlContainer testcontainers.Container + db *gosql.DB +} + +func (suite *EventsStreamerTestSuite) SetupSuite() { + ctx := context.Background() + req := testcontainers.ContainerRequest{ + Image: "mysql:8.0.40", + Env: map[string]string{"MYSQL_ROOT_PASSWORD": "root-password"}, + ExposedPorts: []string{"3306/tcp"}, + WaitingFor: wait.ForListeningPort("3306/tcp"), + } + + mysqlContainer, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{ + ContainerRequest: req, + Started: true, + }) + suite.Require().NoError(err) + + suite.mysqlContainer = mysqlContainer + + dsn, err := GetDSN(ctx, mysqlContainer) + suite.Require().NoError(err) + + db, err := gosql.Open("mysql", dsn) + suite.Require().NoError(err) + + suite.db = db +} + +func (suite *EventsStreamerTestSuite) TeardownSuite() { + ctx := context.Background() + + suite.Assert().NoError(suite.db.Close()) + suite.Assert().NoError(suite.mysqlContainer.Terminate(ctx)) +} + +func (suite *EventsStreamerTestSuite) SetupTest() { + ctx := context.Background() + + _, err := suite.db.ExecContext(ctx, "CREATE DATABASE test") + suite.Require().NoError(err) +} + +func (suite *EventsStreamerTestSuite) TearDownTest() { + ctx := context.Background() + + _, err := suite.db.ExecContext(ctx, "DROP DATABASE test") + suite.Require().NoError(err) +} + +func (suite *EventsStreamerTestSuite) TestStreamEvents() { + ctx := context.Background() + + _, err := suite.db.ExecContext(ctx, "CREATE TABLE test.testing (id INT PRIMARY KEY, name VARCHAR(255))") + suite.Require().NoError(err) + + connectionConfig, err := GetConnectionConfig(ctx, suite.mysqlContainer) + suite.Require().NoError(err) + + migrationContext := base.NewMigrationContext() + migrationContext.ApplierConnectionConfig = connectionConfig + migrationContext.InspectorConnectionConfig = connectionConfig + migrationContext.DatabaseName = "test" + migrationContext.SkipPortValidation = true + migrationContext.ReplicaServerId = 99999 + + migrationContext.SetConnectionConfig("innodb") + + streamer := NewEventsStreamer(migrationContext) + + err = streamer.InitDBConnections() + suite.Require().NoError(err) + defer streamer.Close() + defer streamer.Teardown() + + streamCtx, cancel := context.WithCancel(context.Background()) + + dmlEvents := make([]*binlog.BinlogDMLEvent, 0) + err = streamer.AddListener(false, "test", "testing", func(event *binlog.BinlogDMLEvent) error { + dmlEvents = append(dmlEvents, event) + + // Stop once we've collected three events + if len(dmlEvents) == 3 { + cancel() + } + + return nil + }) + suite.Require().NoError(err) + + group := errgroup.Group{} + group.Go(func() error { + //nolint:contextcheck + return streamer.StreamEvents(func() bool { + return streamCtx.Err() != nil + }) + }) + + group.Go(func() error { + var err error + + _, err = suite.db.ExecContext(ctx, "INSERT INTO test.testing (id, name) VALUES (1, 'foo')") + if err != nil { + return err + } + + _, err = suite.db.ExecContext(ctx, "INSERT INTO test.testing (id, name) VALUES (2, 'bar')") + if err != nil { + return err + } + + _, err = suite.db.ExecContext(ctx, "INSERT INTO test.testing (id, name) VALUES (3, 'baz')") + if err != nil { + return err + } + + // Bug: Need to write fourth event to hit the canStopStreaming function again + _, err = suite.db.ExecContext(ctx, "INSERT INTO test.testing (id, name) VALUES (4, 'qux')") + if err != nil { + return err + } + + return nil + }) + + err = group.Wait() + suite.Require().NoError(err) + + suite.Require().Len(dmlEvents, 3) +} + +func (suite *EventsStreamerTestSuite) TestStreamEventsAutomaticallyReconnects() { + ctx := context.Background() + + _, err := suite.db.ExecContext(ctx, "CREATE TABLE test.testing (id INT PRIMARY KEY, name VARCHAR(255))") + suite.Require().NoError(err) + + connectionConfig, err := GetConnectionConfig(ctx, suite.mysqlContainer) + suite.Require().NoError(err) + + migrationContext := base.NewMigrationContext() + migrationContext.ApplierConnectionConfig = connectionConfig + migrationContext.InspectorConnectionConfig = connectionConfig + migrationContext.DatabaseName = "test" + migrationContext.SkipPortValidation = true + migrationContext.ReplicaServerId = 99999 + + migrationContext.SetConnectionConfig("innodb") + + streamer := NewEventsStreamer(migrationContext) + + err = streamer.InitDBConnections() + suite.Require().NoError(err) + defer streamer.Close() + defer streamer.Teardown() + + streamCtx, cancel := context.WithCancel(context.Background()) + + dmlEvents := make([]*binlog.BinlogDMLEvent, 0) + err = streamer.AddListener(false, "test", "testing", func(event *binlog.BinlogDMLEvent) error { + dmlEvents = append(dmlEvents, event) + + // Stop once we've collected three events + if len(dmlEvents) == 3 { + cancel() + } + + return nil + }) + suite.Require().NoError(err) + + group := errgroup.Group{} + group.Go(func() error { + //nolint:contextcheck + return streamer.StreamEvents(func() bool { + return streamCtx.Err() != nil + }) + }) + + group.Go(func() error { + var err error + + _, err = suite.db.ExecContext(ctx, "INSERT INTO test.testing (id, name) VALUES (1, 'foo')") + if err != nil { + return err + } + + _, err = suite.db.ExecContext(ctx, "INSERT INTO test.testing (id, name) VALUES (2, 'bar')") + if err != nil { + return err + } + + var currentConnectionId int + err = suite.db.QueryRowContext(ctx, "SELECT CONNECTION_ID()").Scan(¤tConnectionId) + if err != nil { + return err + } + + //nolint:execinquery + rows, err := suite.db.Query("SHOW FULL PROCESSLIST") + if err != nil { + return err + } + defer rows.Close() + + connectionIdsToKill := make([]int, 0) + + var id, stateTime int + var user, host, dbName, command, state, info sql.NullString + for rows.Next() { + err = rows.Scan(&id, &user, &host, &dbName, &command, &stateTime, &state, &info) + if err != nil { + return err + } + + fmt.Printf("id: %d, user: %s, host: %s, dbName: %s, command: %s, time: %d, state: %s, info: %s\n", id, user.String, host.String, dbName.String, command.String, stateTime, state.String, info.String) + + if id != currentConnectionId && user.String == "root" { + connectionIdsToKill = append(connectionIdsToKill, id) + } + } + + if err := rows.Err(); err != nil { + return err + } + + for _, connectionIdToKill := range connectionIdsToKill { + _, err = suite.db.ExecContext(ctx, "KILL ?", connectionIdToKill) + if err != nil { + return err + } + } + + // Bug: We need to wait here for the streamer to reconnect + time.Sleep(time.Second * 2) + + _, err = suite.db.ExecContext(ctx, "INSERT INTO test.testing (id, name) VALUES (3, 'baz')") + if err != nil { + return err + } + + // Bug: Need to write fourth event to hit the canStopStreaming function again + _, err = suite.db.ExecContext(ctx, "INSERT INTO test.testing (id, name) VALUES (4, 'qux')") + if err != nil { + return err + } + + return nil + }) + + err = group.Wait() + suite.Require().NoError(err) + + suite.Require().Len(dmlEvents, 3) +} + +func TestEventsStreamer(t *testing.T) { + suite.Run(t, new(EventsStreamerTestSuite)) +} diff --git a/go/logic/test_utils.go b/go/logic/test_utils.go new file mode 100644 index 000000000..3cf94483a --- /dev/null +++ b/go/logic/test_utils.go @@ -0,0 +1,43 @@ +package logic + +import ( + "context" + "fmt" + + "github.com/github/gh-ost/go/mysql" + "github.com/testcontainers/testcontainers-go" +) + +func GetConnectionConfig(ctx context.Context, container testcontainers.Container) (*mysql.ConnectionConfig, error) { + host, err := container.Host(ctx) + if err != nil { + return nil, err + } + + port, err := container.MappedPort(ctx, "3306") + if err != nil { + return nil, err + } + + connectionConfig := mysql.NewConnectionConfig() + connectionConfig.Key.Hostname = host + connectionConfig.Key.Port = port.Int() + connectionConfig.User = "root" + connectionConfig.Password = "root-password" + + return connectionConfig, nil +} + +func GetDSN(ctx context.Context, container testcontainers.Container) (string, error) { + host, err := container.Host(ctx) + if err != nil { + return "", err + } + + port, err := container.MappedPort(ctx, "3306") + if err != nil { + return "", err + } + + return fmt.Sprintf("root:root-password@tcp(%s:%s)/", host, port.Port()), nil +} diff --git a/tmp/.gitkeep b/tmp/.gitkeep new file mode 100644 index 000000000..e69de29bb diff --git a/vendor/golang.org/x/sync/LICENSE b/vendor/golang.org/x/sync/LICENSE new file mode 100644 index 000000000..2a7cf70da --- /dev/null +++ b/vendor/golang.org/x/sync/LICENSE @@ -0,0 +1,27 @@ +Copyright 2009 The Go Authors. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google LLC nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/vendor/golang.org/x/sync/PATENTS b/vendor/golang.org/x/sync/PATENTS new file mode 100644 index 000000000..733099041 --- /dev/null +++ b/vendor/golang.org/x/sync/PATENTS @@ -0,0 +1,22 @@ +Additional IP Rights Grant (Patents) + +"This implementation" means the copyrightable works distributed by +Google as part of the Go project. + +Google hereby grants to You a perpetual, worldwide, non-exclusive, +no-charge, royalty-free, irrevocable (except as stated in this section) +patent license to make, have made, use, offer to sell, sell, import, +transfer and otherwise run, modify and propagate the contents of this +implementation of Go, where such license applies only to those patent +claims, both currently owned or controlled by Google and acquired in +the future, licensable by Google that are necessarily infringed by this +implementation of Go. This grant does not include claims that would be +infringed only as a consequence of further modification of this +implementation. If you or your agent or exclusive licensee institute or +order or agree to the institution of patent litigation against any +entity (including a cross-claim or counterclaim in a lawsuit) alleging +that this implementation of Go or any code incorporated within this +implementation of Go constitutes direct or contributory patent +infringement, or inducement of patent infringement, then any patent +rights granted to you under this License for this implementation of Go +shall terminate as of the date such litigation is filed. diff --git a/vendor/golang.org/x/sync/errgroup/errgroup.go b/vendor/golang.org/x/sync/errgroup/errgroup.go new file mode 100644 index 000000000..948a3ee63 --- /dev/null +++ b/vendor/golang.org/x/sync/errgroup/errgroup.go @@ -0,0 +1,135 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package errgroup provides synchronization, error propagation, and Context +// cancelation for groups of goroutines working on subtasks of a common task. +// +// [errgroup.Group] is related to [sync.WaitGroup] but adds handling of tasks +// returning errors. +package errgroup + +import ( + "context" + "fmt" + "sync" +) + +type token struct{} + +// A Group is a collection of goroutines working on subtasks that are part of +// the same overall task. +// +// A zero Group is valid, has no limit on the number of active goroutines, +// and does not cancel on error. +type Group struct { + cancel func(error) + + wg sync.WaitGroup + + sem chan token + + errOnce sync.Once + err error +} + +func (g *Group) done() { + if g.sem != nil { + <-g.sem + } + g.wg.Done() +} + +// WithContext returns a new Group and an associated Context derived from ctx. +// +// The derived Context is canceled the first time a function passed to Go +// returns a non-nil error or the first time Wait returns, whichever occurs +// first. +func WithContext(ctx context.Context) (*Group, context.Context) { + ctx, cancel := withCancelCause(ctx) + return &Group{cancel: cancel}, ctx +} + +// Wait blocks until all function calls from the Go method have returned, then +// returns the first non-nil error (if any) from them. +func (g *Group) Wait() error { + g.wg.Wait() + if g.cancel != nil { + g.cancel(g.err) + } + return g.err +} + +// Go calls the given function in a new goroutine. +// It blocks until the new goroutine can be added without the number of +// active goroutines in the group exceeding the configured limit. +// +// The first call to return a non-nil error cancels the group's context, if the +// group was created by calling WithContext. The error will be returned by Wait. +func (g *Group) Go(f func() error) { + if g.sem != nil { + g.sem <- token{} + } + + g.wg.Add(1) + go func() { + defer g.done() + + if err := f(); err != nil { + g.errOnce.Do(func() { + g.err = err + if g.cancel != nil { + g.cancel(g.err) + } + }) + } + }() +} + +// TryGo calls the given function in a new goroutine only if the number of +// active goroutines in the group is currently below the configured limit. +// +// The return value reports whether the goroutine was started. +func (g *Group) TryGo(f func() error) bool { + if g.sem != nil { + select { + case g.sem <- token{}: + // Note: this allows barging iff channels in general allow barging. + default: + return false + } + } + + g.wg.Add(1) + go func() { + defer g.done() + + if err := f(); err != nil { + g.errOnce.Do(func() { + g.err = err + if g.cancel != nil { + g.cancel(g.err) + } + }) + } + }() + return true +} + +// SetLimit limits the number of active goroutines in this group to at most n. +// A negative value indicates no limit. +// +// Any subsequent call to the Go method will block until it can add an active +// goroutine without exceeding the configured limit. +// +// The limit must not be modified while any goroutines in the group are active. +func (g *Group) SetLimit(n int) { + if n < 0 { + g.sem = nil + return + } + if len(g.sem) != 0 { + panic(fmt.Errorf("errgroup: modify limit while %v goroutines in the group are still active", len(g.sem))) + } + g.sem = make(chan token, n) +} diff --git a/vendor/golang.org/x/sync/errgroup/go120.go b/vendor/golang.org/x/sync/errgroup/go120.go new file mode 100644 index 000000000..f93c740b6 --- /dev/null +++ b/vendor/golang.org/x/sync/errgroup/go120.go @@ -0,0 +1,13 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.20 + +package errgroup + +import "context" + +func withCancelCause(parent context.Context) (context.Context, func(error)) { + return context.WithCancelCause(parent) +} diff --git a/vendor/golang.org/x/sync/errgroup/pre_go120.go b/vendor/golang.org/x/sync/errgroup/pre_go120.go new file mode 100644 index 000000000..88ce33434 --- /dev/null +++ b/vendor/golang.org/x/sync/errgroup/pre_go120.go @@ -0,0 +1,14 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build !go1.20 + +package errgroup + +import "context" + +func withCancelCause(parent context.Context) (context.Context, func(error)) { + ctx, cancel := context.WithCancel(parent) + return ctx, func(error) { cancel() } +} diff --git a/vendor/modules.txt b/vendor/modules.txt index 6d93c80b7..9cda98b36 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -255,6 +255,9 @@ golang.org/x/crypto/ssh/internal/bcrypt_pbkdf # golang.org/x/net v0.24.0 ## explicit; go 1.18 golang.org/x/net/context +# golang.org/x/sync v0.8.0 +## explicit; go 1.18 +golang.org/x/sync/errgroup # golang.org/x/sys v0.21.0 ## explicit; go 1.18 golang.org/x/sys/cpu