Skip to content

Commit

Permalink
Merge pull request #307 from cashapp/jayj/old-table-timestamped
Browse files Browse the repository at this point in the history
Jayj/old table timestamped
  • Loading branch information
jayjanssen authored Jun 27, 2024
2 parents 8c08e63 + 5d6c20c commit 77dca2b
Show file tree
Hide file tree
Showing 7 changed files with 171 additions and 51 deletions.
15 changes: 8 additions & 7 deletions pkg/check/check.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,14 @@ const (
)

type Resources struct {
DB *sql.DB
Replica *sql.DB
Table *table.TableInfo
Alter string
TargetChunkTime time.Duration
Threads int
ReplicaMaxLag time.Duration
DB *sql.DB
Replica *sql.DB
Table *table.TableInfo
Alter string
TargetChunkTime time.Duration
Threads int
ReplicaMaxLag time.Duration
SkipDropAfterCutover bool
// The following resources are only used by the
// pre-run checks
Host string
Expand Down
63 changes: 63 additions & 0 deletions pkg/check/tablename.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package check

import (
"context"
"errors"
"fmt"
"strings"

"github.com/siddontang/loggers"
)

const (
// Max table name length in MySQL
maxTableNameLength = 64

// Formats for table names
NameFormatSentinel = "_%s_sentinel"
NameFormatCheckpoint = "_%s_chkpnt"
NameFormatNew = "_%s_new"
NameFormatOld = "_%s_old"
NameFormatOldTimeStamp = "_%s_old_%s"
NameFormatTimestamp = "20060102_150405"
)

var (
// The number of extra characters needed for table names with all possible
// formats. These vars are calculated in the `init` function below.
NameFormatNormalExtraChars = 0
NameFormatTimestampExtraChars = 0
)

func init() {
registerCheck("tablename", tableNameCheck, ScopePreflight)

// Calculate the number of extra characters needed table names with all possible formats
for _, format := range []string{NameFormatSentinel, NameFormatCheckpoint, NameFormatNew, NameFormatOld} {
extraChars := len(strings.Replace(format, "%s", "", -1))
if extraChars > NameFormatNormalExtraChars {
NameFormatNormalExtraChars = extraChars
}
}

// Calculate the number of extra characters needed for table names with the old timestamp format
NameFormatTimestampExtraChars = len(strings.Replace(NameFormatOldTimeStamp, "%s", "", -1)) + len(NameFormatTimestamp)
}

func tableNameCheck(ctx context.Context, r Resources, logger loggers.Advanced) error {
tableName := r.Table.TableName
if len(tableName) < 1 {
return errors.New("table name must be at least 1 character")
}

timestampTableNameLength := maxTableNameLength - NameFormatTimestampExtraChars
if r.SkipDropAfterCutover && len(tableName) > timestampTableNameLength {
return fmt.Errorf("table name must be less than %d characters when --skip-drop-after-cutover is set", timestampTableNameLength)
}

normalTableNameLength := maxTableNameLength - NameFormatNormalExtraChars
if len(tableName) > normalTableNameLength {
return fmt.Errorf("table name must be less than %d characters", normalTableNameLength)
}
return nil
}
42 changes: 42 additions & 0 deletions pkg/check/tablename_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package check

import (
"context"
"testing"

"github.com/cashapp/spirit/pkg/table"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
)

func TestCheckTableNameConstants(t *testing.T) {
// Calculated extra chars should always be greater than 0
assert.Positive(t, NameFormatNormalExtraChars)
assert.Positive(t, NameFormatTimestampExtraChars)

// Calculated extra chars should be less than the max table name length
assert.Less(t, NameFormatNormalExtraChars, maxTableNameLength)
assert.Less(t, NameFormatTimestampExtraChars, maxTableNameLength)
}

func TestCheckTableName(t *testing.T) {
testTableName := func(name string, skipDropAfterCutover bool) error {
r := Resources{
Table: &table.TableInfo{
TableName: name,
},
SkipDropAfterCutover: skipDropAfterCutover,
}
return tableNameCheck(context.Background(), r, logrus.New())
}

assert.NoError(t, testTableName("a", false))
assert.NoError(t, testTableName("a", true))

assert.ErrorContains(t, testTableName("", false), "table name must be at least 1 character")
assert.ErrorContains(t, testTableName("", true), "table name must be at least 1 character")

longName := "thisisareallylongtablenamethisisareallylongtablenamethisisareallylongtablename"
assert.ErrorContains(t, testTableName(longName, false), "table name must be less than")
assert.ErrorContains(t, testTableName(longName, true), "table name must be less than")
}
33 changes: 19 additions & 14 deletions pkg/migration/cutover.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,30 +14,35 @@ import (
)

type CutOver struct {
db *sql.DB
table *table.TableInfo
newTable *table.TableInfo
feed *repl.Client
dbConfig *dbconn.DBConfig
logger loggers.Advanced
db *sql.DB
table *table.TableInfo
newTable *table.TableInfo
oldTableName string
feed *repl.Client
dbConfig *dbconn.DBConfig
logger loggers.Advanced
}

// NewCutOver contains the logic to perform the final cut over. It requires the original table,
// new table, and a replication feed which is used to ensure consistency before the cut over.
func NewCutOver(db *sql.DB, table, newTable *table.TableInfo, feed *repl.Client, dbConfig *dbconn.DBConfig, logger loggers.Advanced) (*CutOver, error) {
func NewCutOver(db *sql.DB, table, newTable *table.TableInfo, oldTableName string, feed *repl.Client, dbConfig *dbconn.DBConfig, logger loggers.Advanced) (*CutOver, error) {
if feed == nil {
return nil, errors.New("feed must be non-nil")
}
if table == nil || newTable == nil {
return nil, errors.New("table and newTable must be non-nil")
}
if oldTableName == "" {
return nil, errors.New("oldTableName must be non-empty")
}
return &CutOver{
db: db,
table: table,
newTable: newTable,
feed: feed,
dbConfig: dbConfig,
logger: logger,
db: db,
table: table,
newTable: newTable,
oldTableName: oldTableName,
feed: feed,
dbConfig: dbConfig,
logger: logger,
}, nil
}

Expand Down Expand Up @@ -93,7 +98,7 @@ func (c *CutOver) algorithmRenameUnderLock(ctx context.Context) error {
if !c.feed.AllChangesFlushed() {
return errors.New("not all changes flushed, final flush might be broken")
}
oldName := fmt.Sprintf("_%s_old", c.table.TableName)
oldName := c.oldTableName
oldQuotedName := fmt.Sprintf("`%s`.`%s`", c.table.SchemaName, oldName)
renameStatement := fmt.Sprintf("RENAME TABLE %s TO %s, %s TO %s",
c.table.QuotedName, oldQuotedName,
Expand Down
13 changes: 9 additions & 4 deletions pkg/migration/cutover_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ func TestCutOver(t *testing.T) {

t1 := table.NewTableInfo(db, "test", "cutovert1")
t1new := table.NewTableInfo(db, "test", "_cutovert1_new")
t1old := "_cutovert1_old"
logger := logrus.New()
cfg, err := mysql.ParseDSN(testutils.DSN())
assert.NoError(t, err)
Expand All @@ -52,7 +53,7 @@ func TestCutOver(t *testing.T) {
// the feed must be started.
assert.NoError(t, feed.Run())

cutover, err := NewCutOver(db, t1, t1new, feed, dbconn.NewDBConfig(), logger)
cutover, err := NewCutOver(db, t1, t1new, t1old, feed, dbconn.NewDBConfig(), logger)
assert.NoError(t, err)

err = cutover.Run(context.Background())
Expand Down Expand Up @@ -100,6 +101,7 @@ func TestMDLLockFails(t *testing.T) {

t1 := table.NewTableInfo(db, "test", "mdllocks")
t1new := table.NewTableInfo(db, "test", "_mdllocks_new")
t1old := "test_old"
logger := logrus.New()
cfg, err := mysql.ParseDSN(testutils.DSN())
assert.NoError(t, err)
Expand All @@ -111,7 +113,7 @@ func TestMDLLockFails(t *testing.T) {
// the feed must be started.
assert.NoError(t, feed.Run())

cutover, err := NewCutOver(db, t1, t1new, feed, config, logger)
cutover, err := NewCutOver(db, t1, t1new, t1old, feed, config, logger)
assert.NoError(t, err)

// Before we cutover, we READ LOCK the table.
Expand All @@ -135,17 +137,20 @@ func TestInvalidOptions(t *testing.T) {
logger := logrus.New()

// Invalid options
_, err = NewCutOver(db, nil, nil, nil, dbconn.NewDBConfig(), logger)
_, err = NewCutOver(db, nil, nil, "", nil, dbconn.NewDBConfig(), logger)
assert.Error(t, err)
t1 := table.NewTableInfo(db, "test", "t1")
t1new := table.NewTableInfo(db, "test", "t1_new")
t1old := "test_old"
cfg, err := mysql.ParseDSN(testutils.DSN())
assert.NoError(t, err)
feed := repl.NewClient(db, cfg.Addr, t1, t1new, cfg.User, cfg.Passwd, &repl.ClientConfig{
Logger: logger,
Concurrency: 4,
TargetBatchTime: time.Second,
})
_, err = NewCutOver(db, nil, t1new, feed, dbconn.NewDBConfig(), logger)
_, err = NewCutOver(db, nil, t1new, t1old, feed, dbconn.NewDBConfig(), logger)
assert.Error(t, err)
_, err = NewCutOver(db, nil, t1new, "", feed, dbconn.NewDBConfig(), logger)
assert.Error(t, err)
}
40 changes: 24 additions & 16 deletions pkg/migration/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ func (r *Runner) Run(originalCtx context.Context) error {
// It's time for the final cut-over, where
// the tables are swapped under a lock.
r.setCurrentState(stateCutOver)
cutover, err := NewCutOver(r.db, r.table, r.newTable, r.replClient, r.dbConfig, r.logger)
cutover, err := NewCutOver(r.db, r.table, r.newTable, r.oldTableName(), r.replClient, r.dbConfig, r.logger)
if err != nil {
return err
}
Expand All @@ -301,10 +301,12 @@ func (r *Runner) Run(originalCtx context.Context) error {
if err := r.dropOldTable(ctx); err != nil {
// Don't return the error because our automation
// will retry the migration (but it's already happened)
r.logger.Errorf("migration successful but failed to drop old table: %v", err)
r.logger.Errorf("migration successful but failed to drop old table: %s - %v", r.oldTableName(), err)
} else {
r.logger.Info("successfully dropped old table")
r.logger.Info("successfully dropped old table: ", r.oldTableName())
}
} else {
r.logger.Info("skipped dropping old table: ", r.oldTableName())
}
checksumTime := time.Duration(0)
if r.checker != nil {
Expand Down Expand Up @@ -391,9 +393,10 @@ func (r *Runner) runChecks(ctx context.Context, scope check.ScopeFlag) error {
ReplicaMaxLag: r.migration.ReplicaMaxLag,
// For the pre-run checks we don't have a DB connection yet.
// Instead we check the credentials provided.
Host: r.migration.Host,
Username: r.migration.Username,
Password: r.migration.Password,
Host: r.migration.Host,
Username: r.migration.Username,
Password: r.migration.Password,
SkipDropAfterCutover: r.migration.SkipDropAfterCutover,
}, r.logger, scope)
}

Expand Down Expand Up @@ -564,10 +567,7 @@ func (r *Runner) dropCheckpoint(ctx context.Context) error {
}

func (r *Runner) createNewTable(ctx context.Context) error {
newName := fmt.Sprintf("_%s_new", r.table.TableName)
if len(newName) > 64 {
return fmt.Errorf("table name is too long: '%s'. new table name will exceed 64 characters", r.table.TableName)
}
newName := fmt.Sprintf(check.NameFormatNew, r.table.TableName)
// drop both if we've decided to call this func.
if err := dbconn.Exec(ctx, r.db, "DROP TABLE IF EXISTS %n.%n", r.table.SchemaName, newName); err != nil {
return err
Expand Down Expand Up @@ -603,8 +603,16 @@ func (r *Runner) alterNewTable(ctx context.Context) error {
}

func (r *Runner) dropOldTable(ctx context.Context) error {
oldName := fmt.Sprintf("_%s_old", r.table.TableName)
return dbconn.Exec(ctx, r.db, "DROP TABLE IF EXISTS %n.%n", r.table.SchemaName, oldName)
return dbconn.Exec(ctx, r.db, "DROP TABLE IF EXISTS %n.%n", r.table.SchemaName, r.oldTableName())
}

func (r *Runner) oldTableName() string {
// By default we just set the old table name to _<table>_old
// but if they've enabled SkipDropAfterCutover, we add a timestamp
if !r.migration.SkipDropAfterCutover {
return fmt.Sprintf(check.NameFormatOld, r.table.TableName)
}
return fmt.Sprintf(check.NameFormatOldTimeStamp, r.table.TableName, r.startTime.UTC().Format(check.NameFormatTimestamp))
}

func (r *Runner) attemptInstantDDL(ctx context.Context) error {
Expand All @@ -616,7 +624,7 @@ func (r *Runner) attemptInplaceDDL(ctx context.Context) error {
}

func (r *Runner) createCheckpointTable(ctx context.Context) error {
cpName := fmt.Sprintf("_%s_chkpnt", r.table.TableName)
cpName := fmt.Sprintf(check.NameFormatCheckpoint, r.table.TableName)
// drop both if we've decided to call this func.
if err := dbconn.Exec(ctx, r.db, "DROP TABLE IF EXISTS %n.%n", r.table.SchemaName, cpName); err != nil {
return err
Expand Down Expand Up @@ -653,7 +661,7 @@ func (r *Runner) GetProgress() Progress {
}

func (r *Runner) sentinelTableName() string {
return fmt.Sprintf("_%s_sentinel", r.table.TableName)
return fmt.Sprintf(check.NameFormatSentinel, r.table.TableName)
}

func (r *Runner) createSentinelTable(ctx context.Context) error {
Expand Down Expand Up @@ -724,8 +732,8 @@ func (r *Runner) resumeFromCheckpoint(ctx context.Context) error {

// The objects for these are not available until we confirm
// tables exist and we
newName := fmt.Sprintf("_%s_new", r.table.TableName)
cpName := fmt.Sprintf("_%s_chkpnt", r.table.TableName)
newName := fmt.Sprintf(check.NameFormatNew, r.table.TableName)
cpName := fmt.Sprintf(check.NameFormatCheckpoint, r.table.TableName)

// Make sure we can read from the new table.
if err := dbconn.Exec(ctx, r.db, "SELECT * FROM %n.%n LIMIT 1",
Expand Down
Loading

0 comments on commit 77dca2b

Please sign in to comment.