diff --git a/pkg/check/check.go b/pkg/check/check.go index 4882715..91bc087 100644 --- a/pkg/check/check.go +++ b/pkg/check/check.go @@ -26,13 +26,14 @@ const ( ) type Resources struct { - DB *sql.DB - Replica *sql.DB - Table *table.TableInfo - Alter string - TargetChunkTime time.Duration - Threads int - ReplicaMaxLag time.Duration + DB *sql.DB + Replica *sql.DB + Table *table.TableInfo + Alter string + TargetChunkTime time.Duration + Threads int + ReplicaMaxLag time.Duration + SkipDropAfterCutover bool // The following resources are only used by the // pre-run checks Host string diff --git a/pkg/check/tablename.go b/pkg/check/tablename.go new file mode 100644 index 0000000..4c62205 --- /dev/null +++ b/pkg/check/tablename.go @@ -0,0 +1,63 @@ +package check + +import ( + "context" + "errors" + "fmt" + "strings" + + "github.com/siddontang/loggers" +) + +const ( + // Max table name length in MySQL + maxTableNameLength = 64 + + // Formats for table names + NameFormatSentinel = "_%s_sentinel" + NameFormatCheckpoint = "_%s_chkpnt" + NameFormatNew = "_%s_new" + NameFormatOld = "_%s_old" + NameFormatOldTimeStamp = "_%s_old_%s" + NameFormatTimestamp = "20060102_150405" +) + +var ( + // The number of extra characters needed for table names with all possible + // formats. These vars are calculated in the `init` function below. + NameFormatNormalExtraChars = 0 + NameFormatTimestampExtraChars = 0 +) + +func init() { + registerCheck("tablename", tableNameCheck, ScopePreflight) + + // Calculate the number of extra characters needed table names with all possible formats + for _, format := range []string{NameFormatSentinel, NameFormatCheckpoint, NameFormatNew, NameFormatOld} { + extraChars := len(strings.Replace(format, "%s", "", -1)) + if extraChars > NameFormatNormalExtraChars { + NameFormatNormalExtraChars = extraChars + } + } + + // Calculate the number of extra characters needed for table names with the old timestamp format + NameFormatTimestampExtraChars = len(strings.Replace(NameFormatOldTimeStamp, "%s", "", -1)) + len(NameFormatTimestamp) +} + +func tableNameCheck(ctx context.Context, r Resources, logger loggers.Advanced) error { + tableName := r.Table.TableName + if len(tableName) < 1 { + return errors.New("table name must be at least 1 character") + } + + timestampTableNameLength := maxTableNameLength - NameFormatTimestampExtraChars + if r.SkipDropAfterCutover && len(tableName) > timestampTableNameLength { + return fmt.Errorf("table name must be less than %d characters when --skip-drop-after-cutover is set", timestampTableNameLength) + } + + normalTableNameLength := maxTableNameLength - NameFormatNormalExtraChars + if len(tableName) > normalTableNameLength { + return fmt.Errorf("table name must be less than %d characters", normalTableNameLength) + } + return nil +} diff --git a/pkg/check/tablename_test.go b/pkg/check/tablename_test.go new file mode 100644 index 0000000..f5dc56b --- /dev/null +++ b/pkg/check/tablename_test.go @@ -0,0 +1,42 @@ +package check + +import ( + "context" + "testing" + + "github.com/cashapp/spirit/pkg/table" + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" +) + +func TestCheckTableNameConstants(t *testing.T) { + // Calculated extra chars should always be greater than 0 + assert.Positive(t, NameFormatNormalExtraChars) + assert.Positive(t, NameFormatTimestampExtraChars) + + // Calculated extra chars should be less than the max table name length + assert.Less(t, NameFormatNormalExtraChars, maxTableNameLength) + assert.Less(t, NameFormatTimestampExtraChars, maxTableNameLength) +} + +func TestCheckTableName(t *testing.T) { + testTableName := func(name string, skipDropAfterCutover bool) error { + r := Resources{ + Table: &table.TableInfo{ + TableName: name, + }, + SkipDropAfterCutover: skipDropAfterCutover, + } + return tableNameCheck(context.Background(), r, logrus.New()) + } + + assert.NoError(t, testTableName("a", false)) + assert.NoError(t, testTableName("a", true)) + + assert.ErrorContains(t, testTableName("", false), "table name must be at least 1 character") + assert.ErrorContains(t, testTableName("", true), "table name must be at least 1 character") + + longName := "thisisareallylongtablenamethisisareallylongtablenamethisisareallylongtablename" + assert.ErrorContains(t, testTableName(longName, false), "table name must be less than") + assert.ErrorContains(t, testTableName(longName, true), "table name must be less than") +} diff --git a/pkg/migration/cutover.go b/pkg/migration/cutover.go index 871ea47..e9e6aaa 100644 --- a/pkg/migration/cutover.go +++ b/pkg/migration/cutover.go @@ -14,30 +14,35 @@ import ( ) type CutOver struct { - db *sql.DB - table *table.TableInfo - newTable *table.TableInfo - feed *repl.Client - dbConfig *dbconn.DBConfig - logger loggers.Advanced + db *sql.DB + table *table.TableInfo + newTable *table.TableInfo + oldTableName string + feed *repl.Client + dbConfig *dbconn.DBConfig + logger loggers.Advanced } // NewCutOver contains the logic to perform the final cut over. It requires the original table, // new table, and a replication feed which is used to ensure consistency before the cut over. -func NewCutOver(db *sql.DB, table, newTable *table.TableInfo, feed *repl.Client, dbConfig *dbconn.DBConfig, logger loggers.Advanced) (*CutOver, error) { +func NewCutOver(db *sql.DB, table, newTable *table.TableInfo, oldTableName string, feed *repl.Client, dbConfig *dbconn.DBConfig, logger loggers.Advanced) (*CutOver, error) { if feed == nil { return nil, errors.New("feed must be non-nil") } if table == nil || newTable == nil { return nil, errors.New("table and newTable must be non-nil") } + if oldTableName == "" { + return nil, errors.New("oldTableName must be non-empty") + } return &CutOver{ - db: db, - table: table, - newTable: newTable, - feed: feed, - dbConfig: dbConfig, - logger: logger, + db: db, + table: table, + newTable: newTable, + oldTableName: oldTableName, + feed: feed, + dbConfig: dbConfig, + logger: logger, }, nil } @@ -93,7 +98,7 @@ func (c *CutOver) algorithmRenameUnderLock(ctx context.Context) error { if !c.feed.AllChangesFlushed() { return errors.New("not all changes flushed, final flush might be broken") } - oldName := fmt.Sprintf("_%s_old", c.table.TableName) + oldName := c.oldTableName oldQuotedName := fmt.Sprintf("`%s`.`%s`", c.table.SchemaName, oldName) renameStatement := fmt.Sprintf("RENAME TABLE %s TO %s, %s TO %s", c.table.QuotedName, oldQuotedName, diff --git a/pkg/migration/cutover_test.go b/pkg/migration/cutover_test.go index 8e9e3c5..e7f9095 100644 --- a/pkg/migration/cutover_test.go +++ b/pkg/migration/cutover_test.go @@ -41,6 +41,7 @@ func TestCutOver(t *testing.T) { t1 := table.NewTableInfo(db, "test", "cutovert1") t1new := table.NewTableInfo(db, "test", "_cutovert1_new") + t1old := "_cutovert1_old" logger := logrus.New() cfg, err := mysql.ParseDSN(testutils.DSN()) assert.NoError(t, err) @@ -52,7 +53,7 @@ func TestCutOver(t *testing.T) { // the feed must be started. assert.NoError(t, feed.Run()) - cutover, err := NewCutOver(db, t1, t1new, feed, dbconn.NewDBConfig(), logger) + cutover, err := NewCutOver(db, t1, t1new, t1old, feed, dbconn.NewDBConfig(), logger) assert.NoError(t, err) err = cutover.Run(context.Background()) @@ -100,6 +101,7 @@ func TestMDLLockFails(t *testing.T) { t1 := table.NewTableInfo(db, "test", "mdllocks") t1new := table.NewTableInfo(db, "test", "_mdllocks_new") + t1old := "test_old" logger := logrus.New() cfg, err := mysql.ParseDSN(testutils.DSN()) assert.NoError(t, err) @@ -111,7 +113,7 @@ func TestMDLLockFails(t *testing.T) { // the feed must be started. assert.NoError(t, feed.Run()) - cutover, err := NewCutOver(db, t1, t1new, feed, config, logger) + cutover, err := NewCutOver(db, t1, t1new, t1old, feed, config, logger) assert.NoError(t, err) // Before we cutover, we READ LOCK the table. @@ -135,10 +137,11 @@ func TestInvalidOptions(t *testing.T) { logger := logrus.New() // Invalid options - _, err = NewCutOver(db, nil, nil, nil, dbconn.NewDBConfig(), logger) + _, err = NewCutOver(db, nil, nil, "", nil, dbconn.NewDBConfig(), logger) assert.Error(t, err) t1 := table.NewTableInfo(db, "test", "t1") t1new := table.NewTableInfo(db, "test", "t1_new") + t1old := "test_old" cfg, err := mysql.ParseDSN(testutils.DSN()) assert.NoError(t, err) feed := repl.NewClient(db, cfg.Addr, t1, t1new, cfg.User, cfg.Passwd, &repl.ClientConfig{ @@ -146,6 +149,8 @@ func TestInvalidOptions(t *testing.T) { Concurrency: 4, TargetBatchTime: time.Second, }) - _, err = NewCutOver(db, nil, t1new, feed, dbconn.NewDBConfig(), logger) + _, err = NewCutOver(db, nil, t1new, t1old, feed, dbconn.NewDBConfig(), logger) + assert.Error(t, err) + _, err = NewCutOver(db, nil, t1new, "", feed, dbconn.NewDBConfig(), logger) assert.Error(t, err) } diff --git a/pkg/migration/runner.go b/pkg/migration/runner.go index a81cf8c..8bd30dd 100644 --- a/pkg/migration/runner.go +++ b/pkg/migration/runner.go @@ -285,7 +285,7 @@ func (r *Runner) Run(originalCtx context.Context) error { // It's time for the final cut-over, where // the tables are swapped under a lock. r.setCurrentState(stateCutOver) - cutover, err := NewCutOver(r.db, r.table, r.newTable, r.replClient, r.dbConfig, r.logger) + cutover, err := NewCutOver(r.db, r.table, r.newTable, r.oldTableName(), r.replClient, r.dbConfig, r.logger) if err != nil { return err } @@ -301,10 +301,12 @@ func (r *Runner) Run(originalCtx context.Context) error { if err := r.dropOldTable(ctx); err != nil { // Don't return the error because our automation // will retry the migration (but it's already happened) - r.logger.Errorf("migration successful but failed to drop old table: %v", err) + r.logger.Errorf("migration successful but failed to drop old table: %s - %v", r.oldTableName(), err) } else { - r.logger.Info("successfully dropped old table") + r.logger.Info("successfully dropped old table: ", r.oldTableName()) } + } else { + r.logger.Info("skipped dropping old table: ", r.oldTableName()) } checksumTime := time.Duration(0) if r.checker != nil { @@ -391,9 +393,10 @@ func (r *Runner) runChecks(ctx context.Context, scope check.ScopeFlag) error { ReplicaMaxLag: r.migration.ReplicaMaxLag, // For the pre-run checks we don't have a DB connection yet. // Instead we check the credentials provided. - Host: r.migration.Host, - Username: r.migration.Username, - Password: r.migration.Password, + Host: r.migration.Host, + Username: r.migration.Username, + Password: r.migration.Password, + SkipDropAfterCutover: r.migration.SkipDropAfterCutover, }, r.logger, scope) } @@ -564,10 +567,7 @@ func (r *Runner) dropCheckpoint(ctx context.Context) error { } func (r *Runner) createNewTable(ctx context.Context) error { - newName := fmt.Sprintf("_%s_new", r.table.TableName) - if len(newName) > 64 { - return fmt.Errorf("table name is too long: '%s'. new table name will exceed 64 characters", r.table.TableName) - } + newName := fmt.Sprintf(check.NameFormatNew, r.table.TableName) // drop both if we've decided to call this func. if err := dbconn.Exec(ctx, r.db, "DROP TABLE IF EXISTS %n.%n", r.table.SchemaName, newName); err != nil { return err @@ -603,8 +603,16 @@ func (r *Runner) alterNewTable(ctx context.Context) error { } func (r *Runner) dropOldTable(ctx context.Context) error { - oldName := fmt.Sprintf("_%s_old", r.table.TableName) - return dbconn.Exec(ctx, r.db, "DROP TABLE IF EXISTS %n.%n", r.table.SchemaName, oldName) + return dbconn.Exec(ctx, r.db, "DROP TABLE IF EXISTS %n.%n", r.table.SchemaName, r.oldTableName()) +} + +func (r *Runner) oldTableName() string { + // By default we just set the old table name to __old + // but if they've enabled SkipDropAfterCutover, we add a timestamp + if !r.migration.SkipDropAfterCutover { + return fmt.Sprintf(check.NameFormatOld, r.table.TableName) + } + return fmt.Sprintf(check.NameFormatOldTimeStamp, r.table.TableName, r.startTime.UTC().Format(check.NameFormatTimestamp)) } func (r *Runner) attemptInstantDDL(ctx context.Context) error { @@ -616,7 +624,7 @@ func (r *Runner) attemptInplaceDDL(ctx context.Context) error { } func (r *Runner) createCheckpointTable(ctx context.Context) error { - cpName := fmt.Sprintf("_%s_chkpnt", r.table.TableName) + cpName := fmt.Sprintf(check.NameFormatCheckpoint, r.table.TableName) // drop both if we've decided to call this func. if err := dbconn.Exec(ctx, r.db, "DROP TABLE IF EXISTS %n.%n", r.table.SchemaName, cpName); err != nil { return err @@ -653,7 +661,7 @@ func (r *Runner) GetProgress() Progress { } func (r *Runner) sentinelTableName() string { - return fmt.Sprintf("_%s_sentinel", r.table.TableName) + return fmt.Sprintf(check.NameFormatSentinel, r.table.TableName) } func (r *Runner) createSentinelTable(ctx context.Context) error { @@ -724,8 +732,8 @@ func (r *Runner) resumeFromCheckpoint(ctx context.Context) error { // The objects for these are not available until we confirm // tables exist and we - newName := fmt.Sprintf("_%s_new", r.table.TableName) - cpName := fmt.Sprintf("_%s_chkpnt", r.table.TableName) + newName := fmt.Sprintf(check.NameFormatNew, r.table.TableName) + cpName := fmt.Sprintf(check.NameFormatCheckpoint, r.table.TableName) // Make sure we can read from the new table. if err := dbconn.Exec(ctx, r.db, "SELECT * FROM %n.%n LIMIT 1", diff --git a/pkg/migration/runner_test.go b/pkg/migration/runner_test.go index 75f3c90..59f4be6 100644 --- a/pkg/migration/runner_test.go +++ b/pkg/migration/runner_test.go @@ -2336,7 +2336,6 @@ func TestVarcharE2E(t *testing.T) { func TestSkipDropAfterCutover(t *testing.T) { tableName := `drop_test` - oldName := fmt.Sprintf("_%s_old", tableName) testutils.RunSQL(t, "DROP TABLE IF EXISTS "+tableName) table := fmt.Sprintf(`CREATE TABLE %s ( @@ -2364,7 +2363,7 @@ func TestSkipDropAfterCutover(t *testing.T) { sql := fmt.Sprintf( `SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLES - WHERE TABLE_SCHEMA=DATABASE() AND TABLE_NAME='%s'`, oldName) + WHERE TABLE_SCHEMA=DATABASE() AND TABLE_NAME='%s'`, m.oldTableName()) var tableCount int err = m.db.QueryRow(sql).Scan(&tableCount) assert.NoError(t, err) @@ -2376,7 +2375,6 @@ func TestDropAfterCutover(t *testing.T) { sentinelWaitLimit = 10 * time.Second tableName := `drop_test` - oldName := fmt.Sprintf("_%s_old", tableName) testutils.RunSQL(t, "DROP TABLE IF EXISTS "+tableName) table := fmt.Sprintf(`CREATE TABLE %s ( @@ -2404,7 +2402,7 @@ func TestDropAfterCutover(t *testing.T) { sql := fmt.Sprintf( `SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLES - WHERE TABLE_SCHEMA=DATABASE() AND TABLE_NAME='%s'`, oldName) + WHERE TABLE_SCHEMA=DATABASE() AND TABLE_NAME='%s'`, m.oldTableName()) var tableCount int err = m.db.QueryRow(sql).Scan(&tableCount) assert.NoError(t, err) @@ -2474,7 +2472,6 @@ func TestDeferCutOverE2E(t *testing.T) { c := make(chan error) tableName := `deferred_cutover_e2e` - oldName := fmt.Sprintf("_%s_old", tableName) sentinelTableName := fmt.Sprintf("_%s_sentinel", tableName) checkpointTableName := fmt.Sprintf("_%s_chkpnt", tableName) @@ -2532,7 +2529,7 @@ func TestDeferCutOverE2E(t *testing.T) { sql := fmt.Sprintf( `SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLES - WHERE TABLE_SCHEMA=DATABASE() AND TABLE_NAME='%s'`, oldName) + WHERE TABLE_SCHEMA=DATABASE() AND TABLE_NAME='%s'`, m.oldTableName()) var tableCount int err = db.QueryRow(sql).Scan(&tableCount) assert.NoError(t, err) @@ -2549,7 +2546,6 @@ func TestDeferCutOverE2EBinlogAdvance(t *testing.T) { c := make(chan error) tableName := `deferred_cutover_e2e_stage` - oldName := fmt.Sprintf("_%s_old", tableName) sentinelTableName := fmt.Sprintf("_%s_sentinel", tableName) checkpointTableName := fmt.Sprintf("_%s_chkpnt", tableName) @@ -2611,7 +2607,7 @@ func TestDeferCutOverE2EBinlogAdvance(t *testing.T) { sql := fmt.Sprintf( `SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLES - WHERE TABLE_SCHEMA=DATABASE() AND TABLE_NAME='%s'`, oldName) + WHERE TABLE_SCHEMA=DATABASE() AND TABLE_NAME='%s'`, m.oldTableName()) var tableCount int err = db.QueryRow(sql).Scan(&tableCount) assert.NoError(t, err) @@ -2628,7 +2624,7 @@ func TestResumeFromCheckpointE2EWithManualSentinel(t *testing.T) { sentinelWaitLimit = 10 * time.Second statusInterval = 500 * time.Millisecond - tableName := `resume_from_checkpoint_e2e_with_sentinel` + tableName := `resume_checkpoint_e2e_w_sentinel` testutils.RunSQL(t, fmt.Sprintf(`DROP TABLE IF EXISTS %s, _%s_old, _%s_chkpnt, _%s_sentinel`, tableName, tableName, tableName, tableName)) table := fmt.Sprintf(`CREATE TABLE %s ( id int(11) NOT NULL AUTO_INCREMENT, @@ -2667,7 +2663,7 @@ func TestResumeFromCheckpointE2EWithManualSentinel(t *testing.T) { go func() { err := runner.Run(ctx) - assert.Error(t, err) // it gets interrupted as soon as there is a checkpoint saved. + assert.ErrorContains(t, err, "context canceled") // it gets interrupted as soon as there is a checkpoint saved. }() // wait until a checkpoint is saved (which means copy is in progress)