diff --git a/cmd/cmd.go b/cmd/cmd.go index c3aa385..f7a4ff2 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -45,10 +45,11 @@ const ( func newSpannerClient(ctx context.Context, c *cobra.Command) (*spanner.Client, error) { config := &spanner.Config{ - Project: c.Flag(flagNameProject).Value.String(), - Instance: c.Flag(flagNameInstance).Value.String(), - Database: c.Flag(flagNameDatabase).Value.String(), - CredentialsFile: c.Flag(flagCredentialsFile).Value.String(), + Project: c.Flag(flagNameProject).Value.String(), + Instance: c.Flag(flagNameInstance).Value.String(), + Database: c.Flag(flagNameDatabase).Value.String(), + CredentialsFile: c.Flag(flagCredentialsFile).Value.String(), + MigrationTableName: "", // use pkg.spanner default } client, err := spanner.NewClient(ctx, config) diff --git a/cmd/migrate.go b/cmd/migrate.go index 48fd565..811fcdb 100644 --- a/cmd/migrate.go +++ b/cmd/migrate.go @@ -33,7 +33,6 @@ import ( const ( migrationsDirName = "migrations" - migrationTableName = "SchemaMigrations" ) // migrateCmd represents the migrate command @@ -126,7 +125,7 @@ func migrateUp(c *cobra.Command, args []string) error { } defer client.Close() - if err = client.EnsureMigrationTable(ctx, migrationTableName); err != nil { + if err = client.EnsureMigrationTable(ctx); err != nil { return &Error{ cmd: c, err: err, @@ -142,7 +141,7 @@ func migrateUp(c *cobra.Command, args []string) error { } } - return client.ExecuteMigrations(ctx, migrations, limit, migrationTableName) + return client.ExecuteMigrations(ctx, migrations, limit) } func migrateVersion(c *cobra.Command, args []string) error { @@ -154,14 +153,14 @@ func migrateVersion(c *cobra.Command, args []string) error { } defer client.Close() - if err = client.EnsureMigrationTable(ctx, migrationTableName); err != nil { + if err = client.EnsureMigrationTable(ctx); err != nil { return &Error{ cmd: c, err: err, } } - v, _, err := client.GetSchemaMigrationVersion(ctx, migrationTableName) + v, _, err := client.GetSchemaMigrationVersion(ctx) if err != nil { var se *spanner.Error if errors.As(err, &se) && se.Code == spanner.ErrorCodeNoMigration { @@ -202,14 +201,14 @@ func migrateSet(c *cobra.Command, args []string) error { } defer client.Close() - if err = client.EnsureMigrationTable(ctx, migrationTableName); err != nil { + if err = client.EnsureMigrationTable(ctx); err != nil { return &Error{ cmd: c, err: err, } } - if err := client.SetSchemaMigrationVersion(ctx, uint(version), false, migrationTableName); err != nil { + if err := client.SetSchemaMigrationVersion(ctx, uint(version), false); err != nil { return &Error{ cmd: c, err: err, diff --git a/pkg/spanner/client.go b/pkg/spanner/client.go index 81e058c..bfdee4e 100644 --- a/pkg/spanner/client.go +++ b/pkg/spanner/client.go @@ -36,6 +36,7 @@ import ( const ( ddlStatementsSeparator = ";" + defaultMigrationTableName = "SchemaMigrations" ) type table struct { @@ -78,6 +79,13 @@ func NewClient(ctx context.Context, config *Config) (*Client, error) { }, nil } +func (c *Client) migrationTableName() string { + if c.config == nil || c.config.MigrationTableName == "" { + return defaultMigrationTableName + } + return c.config.MigrationTableName +} + func (c *Client) CreateDatabase(ctx context.Context, ddl []byte) error { statements := toStatements(ddl) @@ -131,7 +139,7 @@ func (c *Client) TruncateAllTables(ctx context.Context) error { return err } - if t.TableName == "SchemaMigrations" { + if t.TableName == c.migrationTableName() { return nil } @@ -295,10 +303,10 @@ func (c *Client) ApplyPartitionedDML(ctx context.Context, statements []string, p return numAffectedRows, nil } -func (c *Client) ExecuteMigrations(ctx context.Context, migrations Migrations, limit int, tableName string) error { +func (c *Client) ExecuteMigrations(ctx context.Context, migrations Migrations, limit int) error { sort.Sort(migrations) - version, dirty, err := c.GetSchemaMigrationVersion(ctx, tableName) + version, dirty, err := c.GetSchemaMigrationVersion(ctx) if err != nil { var se *Error if !errors.As(err, &se) || se.Code != ErrorCodeNoMigration { @@ -326,7 +334,7 @@ func (c *Client) ExecuteMigrations(ctx context.Context, migrations Migrations, l continue } - if err := c.SetSchemaMigrationVersion(ctx, m.Version, true, tableName); err != nil { + if err := c.SetSchemaMigrationVersion(ctx, m.Version, true); err != nil { return &Error{ Code: ErrorCodeExecuteMigrations, err: err, @@ -361,7 +369,7 @@ func (c *Client) ExecuteMigrations(ctx context.Context, migrations Migrations, l fmt.Printf("%d/up\n", m.Version) } - if err := c.SetSchemaMigrationVersion(ctx, m.Version, false, tableName); err != nil { + if err := c.SetSchemaMigrationVersion(ctx, m.Version, false); err != nil { return &Error{ Code: ErrorCodeExecuteMigrations, err: err, @@ -381,7 +389,8 @@ func (c *Client) ExecuteMigrations(ctx context.Context, migrations Migrations, l return nil } -func (c *Client) GetSchemaMigrationVersion(ctx context.Context, tableName string) (uint, bool, error) { +func (c *Client) GetSchemaMigrationVersion(ctx context.Context) (uint, bool, error) { + tableName := c.migrationTableName() stmt := spanner.Statement{ SQL: `SELECT Version, Dirty FROM ` + tableName + ` LIMIT 1`, } @@ -416,7 +425,8 @@ func (c *Client) GetSchemaMigrationVersion(ctx context.Context, tableName string return uint(v), dirty, nil } -func (c *Client) SetSchemaMigrationVersion(ctx context.Context, version uint, dirty bool, tableName string) error { +func (c *Client) SetSchemaMigrationVersion(ctx context.Context, version uint, dirty bool) error { + tableName := c.migrationTableName() _, err := c.spannerClient.ReadWriteTransaction(ctx, func(_ context.Context, tx *spanner.ReadWriteTransaction) error { m := []*spanner.Mutation{ spanner.Delete(tableName, spanner.AllKeys()), @@ -438,7 +448,8 @@ func (c *Client) SetSchemaMigrationVersion(ctx context.Context, version uint, di return nil } -func (c *Client) EnsureMigrationTable(ctx context.Context, tableName string) error { +func (c *Client) EnsureMigrationTable(ctx context.Context) error { + tableName := c.migrationTableName() iter := c.spannerClient.Single().Read(ctx, tableName, spanner.AllKeys(), []string{"Version"}) err := iter.Do(func(r *spanner.Row) error { return nil diff --git a/pkg/spanner/client_test.go b/pkg/spanner/client_test.go index 99fbaa4..ab43d16 100644 --- a/pkg/spanner/client_test.go +++ b/pkg/spanner/client_test.go @@ -221,7 +221,7 @@ func TestExecuteMigrations(t *testing.T) { } // only apply 000002.sql by specifying limit 1. - if err := client.ExecuteMigrations(ctx, migrations, 1, migrationTable); err != nil { + if err := client.ExecuteMigrations(ctx, migrations, 1); err != nil { t.Fatalf("failed to execute migration: %v", err) } @@ -229,7 +229,7 @@ func TestExecuteMigrations(t *testing.T) { ensureMigrationColumn(t, ctx, client, "LastName", "STRING(MAX)", "YES") ensureMigrationVersionRecord(t, ctx, client, 2, false) - if err := client.ExecuteMigrations(ctx, migrations, len(migrations), migrationTable); err != nil { + if err := client.ExecuteMigrations(ctx, migrations, len(migrations)); err != nil { t.Fatalf("failed to execute migration: %v", err) } @@ -314,7 +314,7 @@ func TestGetSchemaMigrationVersion(t *testing.T) { t.Fatalf("failed to apply mutation: %v", err) } - v, d, err := client.GetSchemaMigrationVersion(ctx, migrationTable) + v, d, err := client.GetSchemaMigrationVersion(ctx) if err != nil { t.Fatalf("failed to get version: %v", err) } @@ -350,7 +350,7 @@ func TestSetSchemaMigrationVersion(t *testing.T) { nextVersion := 2 nextDirty := true - if err := client.SetSchemaMigrationVersion(ctx, uint(nextVersion), nextDirty, migrationTable); err != nil { + if err := client.SetSchemaMigrationVersion(ctx, uint(nextVersion), nextDirty); err != nil { t.Fatalf("failed to set version: %v", err) } @@ -360,8 +360,6 @@ func TestSetSchemaMigrationVersion(t *testing.T) { func TestEnsureMigrationTable(t *testing.T) { ctx := context.Background() - client, done := testClientWithDatabase(t, ctx) - defer done() tests := map[string]struct { table string @@ -372,7 +370,12 @@ func TestEnsureMigrationTable(t *testing.T) { for name, test := range tests { t.Run(name, func(t *testing.T) { - if err := client.EnsureMigrationTable(ctx, test.table); err != nil { + cfg := &Config{ + MigrationTableName: test.table, + } + client, done := testConfiguredClientWithDatabase(t, ctx, cfg) + defer done() + if err := client.EnsureMigrationTable(ctx); err != nil { t.Fatalf("failed to ensure migration table: %v", err) } @@ -434,7 +437,12 @@ func TestPriorityPBOf(t *testing.T) { } + func testClientWithDatabase(t *testing.T, ctx context.Context) (*Client, func()) { + return testConfiguredClientWithDatabase(t, ctx, &Config{}) +} + +func testConfiguredClientWithDatabase(t *testing.T, ctx context.Context, config *Config) (*Client, func()) { t.Helper() project := os.Getenv(envSpannerProjectID) @@ -454,13 +462,25 @@ func testClientWithDatabase(t *testing.T, ctx context.Context) (*Client, func()) database = fmt.Sprintf("wrench-test-%s", id.String()[:8]) } - config := &Config{ + mergedConfig := &Config{ Project: project, Instance: instance, Database: database, } + if config != nil && config.Project != "" { + mergedConfig.Project = config.Project + } + if config != nil && config.Instance != "" { + mergedConfig.Instance = config.Instance + } + if config != nil && config.Database != "" { + mergedConfig.Database = config.Database + } + if config != nil && config.MigrationTableName != "" { + mergedConfig.MigrationTableName = config.MigrationTableName + } - client, err := NewClient(ctx, config) + client, err := NewClient(ctx, mergedConfig) if err != nil { t.Fatalf("failed to create spanner client: %v", err) } diff --git a/pkg/spanner/config.go b/pkg/spanner/config.go index 7b657ff..80c286d 100644 --- a/pkg/spanner/config.go +++ b/pkg/spanner/config.go @@ -22,10 +22,11 @@ package spanner import "fmt" type Config struct { - Project string - Instance string - Database string - CredentialsFile string + Project string + Instance string + Database string + CredentialsFile string + MigrationTableName string } func (c *Config) URL() string {