diff --git a/internal/db/declarative/debug.go b/internal/db/declarative/debug.go index f274ed39dd..30ffaa4c7b 100644 --- a/internal/db/declarative/debug.go +++ b/internal/db/declarative/debug.go @@ -104,9 +104,9 @@ func CollectMigrationsList(fsys afero.Fs) []string { if err != nil { return nil } - // Strip directory prefix to return just filenames + // Strip directory prefix to return display names for i, m := range migrations { - migrations[i] = filepath.Base(m) + migrations[i] = migration.MigrationName(m) } return migrations } diff --git a/internal/db/push/push.go b/internal/db/push/push.go index 4084cd0800..79af124f33 100644 --- a/internal/db/push/push.go +++ b/internal/db/push/push.go @@ -4,7 +4,6 @@ import ( "context" "fmt" "os" - "path/filepath" "github.com/go-errors/errors" "github.com/jackc/pgconn" @@ -118,8 +117,7 @@ func Run(ctx context.Context, dryRun, ignoreVersionMismatch bool, includeRoles, func confirmPushAll(pending []string) (msg string) { for _, path := range pending { - filename := filepath.Base(path) - msg += fmt.Sprintf(" • %s\n", utils.Bold(filename)) + msg += fmt.Sprintf(" • %s\n", utils.Bold(migration.MigrationName(path))) } return msg } diff --git a/internal/migration/repair/repair.go b/internal/migration/repair/repair.go index a7dfee274e..1ab6fd34d6 100644 --- a/internal/migration/repair/repair.go +++ b/internal/migration/repair/repair.go @@ -88,15 +88,28 @@ func UpdateMigrationTable(ctx context.Context, conn *pgx.Conn, version []string, } func GetMigrationFile(version string, fsys afero.Fs) (string, error) { + // Try flat file first: version_*.sql path := filepath.Join(utils.MigrationsDir, version+"_*.sql") matches, err := afero.Glob(fsys, path) if err != nil { return "", errors.Errorf("failed to glob migration files: %w", err) } - if len(matches) == 0 { - return "", errors.Errorf("glob %s: %w", path, os.ErrNotExist) + if len(matches) > 0 { + return matches[0], nil } - return matches[0], nil + // Try folder-based migration: version_*/*.sql + dirPath := filepath.Join(utils.MigrationsDir, version+"_*", "*.sql") + dirMatches, err := afero.Glob(fsys, dirPath) + if err != nil { + return "", errors.Errorf("failed to glob migration directories: %w", err) + } + if len(dirMatches) == 1 { + return dirMatches[0], nil + } + if len(dirMatches) > 1 { + return "", errors.Errorf("multiple .sql files found for version %s", version) + } + return "", errors.Errorf("no migration found for version %s: %w", version, os.ErrNotExist) } func NewMigrationFromVersion(version string, fsys afero.Fs) (*migration.MigrationFile, error) { diff --git a/internal/migration/repair/repair_test.go b/internal/migration/repair/repair_test.go index f8eb33cce8..b415069c5d 100644 --- a/internal/migration/repair/repair_test.go +++ b/internal/migration/repair/repair_test.go @@ -61,6 +61,23 @@ func TestRepairCommand(t *testing.T) { assert.NoError(t, err) }) + t.Run("applies folder-based migration version", func(t *testing.T) { + // Setup in-memory fs + fsys := afero.NewMemMapFs() + sqlPath := filepath.Join(utils.MigrationsDir, "20242409125510_premium_mister_fear", "schema.sql") + require.NoError(t, afero.WriteFile(fsys, sqlPath, []byte("select 1"), 0644)) + // Setup mock postgres + conn := pgtest.NewConn() + defer conn.Close(t) + helper.MockMigrationHistory(conn). + Query(migration.UPSERT_MIGRATION_VERSION, "20242409125510", "premium_mister_fear", []string{"select 1"}). + Reply("INSERT 0 1") + // Run test + err := Run(context.Background(), dbConfig, []string{"20242409125510"}, Applied, fsys, conn.Intercept) + // Check error + assert.NoError(t, err) + }) + t.Run("throws error on invalid version", func(t *testing.T) { // Setup in-memory fs fsys := afero.NewMemMapFs() @@ -97,6 +114,45 @@ func TestRepairCommand(t *testing.T) { }) } +func TestGetMigrationFile(t *testing.T) { + t.Run("finds flat migration file", func(t *testing.T) { + fsys := afero.NewMemMapFs() + path := filepath.Join(utils.MigrationsDir, "0_test.sql") + require.NoError(t, afero.WriteFile(fsys, path, []byte("select 1"), 0644)) + // Run test + result, err := GetMigrationFile("0", fsys) + assert.NoError(t, err) + assert.Equal(t, path, result) + }) + + t.Run("finds folder-based migration file", func(t *testing.T) { + fsys := afero.NewMemMapFs() + sqlPath := filepath.Join(utils.MigrationsDir, "20242409125510_premium_mister_fear", "schema.sql") + require.NoError(t, afero.WriteFile(fsys, sqlPath, []byte("select 1"), 0644)) + // Run test + result, err := GetMigrationFile("20242409125510", fsys) + assert.NoError(t, err) + assert.Equal(t, sqlPath, result) + }) + + t.Run("returns error for multiple .sql files in directory", func(t *testing.T) { + fsys := afero.NewMemMapFs() + dir := filepath.Join(utils.MigrationsDir, "20242409125510_premium_mister_fear") + require.NoError(t, afero.WriteFile(fsys, filepath.Join(dir, "schema.sql"), []byte("select 1"), 0644)) + require.NoError(t, afero.WriteFile(fsys, filepath.Join(dir, "extra.sql"), []byte("select 2"), 0644)) + // Run test + _, err := GetMigrationFile("20242409125510", fsys) + assert.ErrorContains(t, err, "multiple .sql files found") + }) + + t.Run("returns error when version not found", func(t *testing.T) { + fsys := afero.NewMemMapFs() + // Run test + _, err := GetMigrationFile("99999", fsys) + assert.ErrorIs(t, err, os.ErrNotExist) + }) +} + func TestRepairAll(t *testing.T) { t.Run("repairs whole history", func(t *testing.T) { t.Cleanup(fstest.MockStdin(t, "y")) diff --git a/internal/migration/squash/squash.go b/internal/migration/squash/squash.go index afc52c687e..53e63e9494 100644 --- a/internal/migration/squash/squash.go +++ b/internal/migration/squash/squash.go @@ -7,6 +7,7 @@ import ( "fmt" "io" "os" + "path/filepath" "strconv" "time" @@ -73,6 +74,14 @@ func squashToVersion(ctx context.Context, version string, fsys afero.Fs, options for _, path := range migrations[:len(migrations)-1] { if err := fsys.Remove(path); err != nil { fmt.Fprintln(os.Stderr, err) + continue + } + // For folder-based migrations, remove the parent directory and all its contents + dir := filepath.Dir(path) + if dir != utils.MigrationsDir { + if err := fsys.RemoveAll(dir); err != nil { + fmt.Fprintln(os.Stderr, err) + } } } return nil diff --git a/pkg/migration/apply.go b/pkg/migration/apply.go index e40b58be2e..80cdebc487 100644 --- a/pkg/migration/apply.go +++ b/pkg/migration/apply.go @@ -5,7 +5,6 @@ import ( "fmt" "io/fs" "os" - "path/filepath" "github.com/go-errors/errors" "github.com/jackc/pgx/v4" @@ -23,9 +22,8 @@ func FindPendingMigrations(localMigrations, remoteMigrations []string) ([]string i, j := 0, 0 for i < len(remoteMigrations) && j < len(localMigrations) { remote := remoteMigrations[i] - filename := filepath.Base(localMigrations[j]) - // Check if migration has been applied before, LoadLocalMigrations guarantees a match - local := migrateFilePattern.FindStringSubmatch(filename)[1] + // Extract version from path, supporting both flat files and folder-based migrations + local, _, _ := ParseVersion(localMigrations[j]) if remote == local { j++ i++ @@ -60,8 +58,7 @@ func ApplyMigrations(ctx context.Context, pending []string, conn *pgx.Conn, fsys } } for _, path := range pending { - filename := filepath.Base(path) - fmt.Fprintf(os.Stderr, "Applying migration %s...\n", filename) + fmt.Fprintf(os.Stderr, "Applying migration %s...\n", MigrationName(path)) // Reset all connection settings that might have been modified by another statement on the same connection // eg: `SELECT pg_catalog.set_config('search_path', '', false);` if _, err := conn.Exec(ctx, "RESET ALL"); err != nil { diff --git a/pkg/migration/apply_test.go b/pkg/migration/apply_test.go index e6df97721f..6fbe43e806 100644 --- a/pkg/migration/apply_test.go +++ b/pkg/migration/apply_test.go @@ -91,6 +91,55 @@ func TestPendingMigrations(t *testing.T) { assert.ErrorIs(t, err, ErrMissingLocal) assert.ElementsMatch(t, []string{remote[1], remote[3], remote[4]}, missing) }) + + t.Run("finds pending folder-based migrations", func(t *testing.T) { + local := []string{ + "20221201000000_test.sql", + "20221201000001_test.sql", + "20221201000002_create_users/schema.sql", + "20221201000003_add_indexes/schema.sql", + } + remote := []string{ + "20221201000000", + "20221201000001", + } + // Run test + pending, err := FindPendingMigrations(local, remote) + // Check error + assert.NoError(t, err) + assert.ElementsMatch(t, local[2:], pending) + }) + + t.Run("matches remote with local folder-based migration", func(t *testing.T) { + local := []string{ + "20221201000000_test.sql", + "20221201000001_create_users/schema.sql", + } + remote := []string{ + "20221201000000", + "20221201000001", + } + // Run test + pending, err := FindPendingMigrations(local, remote) + // Check error + assert.NoError(t, err) + assert.Empty(t, pending) + }) + + t.Run("detects missing local for folder-based migrations", func(t *testing.T) { + local := []string{ + "20221201000000_create_users/schema.sql", + } + remote := []string{ + "20221201000000", + "20221201000001", + } + // Run test + missing, err := FindPendingMigrations(local, remote) + // Check error + assert.ErrorIs(t, err, ErrMissingLocal) + assert.ElementsMatch(t, []string{"20221201000001"}, missing) + }) } var ( diff --git a/pkg/migration/file.go b/pkg/migration/file.go index 540c129e33..674d32bb3c 100644 --- a/pkg/migration/file.go +++ b/pkg/migration/file.go @@ -28,6 +28,7 @@ type MigrationFile struct { var ( migrateFilePattern = regexp.MustCompile(`^([0-9]+)_(.*)\.sql$`) + migrateDirPattern = regexp.MustCompile(`^([0-9]+)_(.+)$`) typeNamePattern = regexp.MustCompile(`type "([^"]+)" does not exist`) ) @@ -37,16 +38,43 @@ func NewMigrationFromFile(path string, fsys fs.FS) (*MigrationFile, error) { return nil, err } file := MigrationFile{Statements: lines} - // Parse version from file name - filename := filepath.Base(path) - matches := migrateFilePattern.FindStringSubmatch(filename) - if len(matches) > 2 { - file.Version = matches[1] - file.Name = matches[2] + // Parse version from file path (supports both flat files and folder-based migrations) + if version, name, ok := ParseVersion(path); ok { + file.Version = version + file.Name = name } return &file, nil } +// ParseVersion extracts the version and name from a migration path. +// Handles both flat files (20220727064247_create_table.sql) and +// folder-based migrations (20242409125510_premium_mister_fear/.sql). +func ParseVersion(path string) (version, name string, ok bool) { + filename := filepath.Base(path) + if matches := migrateFilePattern.FindStringSubmatch(filename); len(matches) > 2 { + return matches[1], matches[2], true + } + // Try parent directory for folder-based migrations + dirName := filepath.Base(filepath.Dir(path)) + if matches := migrateDirPattern.FindStringSubmatch(dirName); len(matches) > 2 { + return matches[1], matches[2], true + } + return "", "", false +} + +// MigrationName returns a human-readable display name for a migration path. +// For flat files: "20220727064247_create_table.sql" +// For folder migrations: "20242409125510_premium_mister_fear/.sql" +func MigrationName(path string) string { + filename := filepath.Base(path) + if migrateFilePattern.MatchString(filename) { + return filename + } + // For folder-based migrations, show "dirname/filename" + dir := filepath.Base(filepath.Dir(path)) + return filepath.Join(dir, filename) +} + func parseFile(path string, fsys fs.FS) ([]string, error) { sql, err := fsys.Open(path) if err != nil { diff --git a/pkg/migration/file_test.go b/pkg/migration/file_test.go index 703f26954c..fd137725db 100644 --- a/pkg/migration/file_test.go +++ b/pkg/migration/file_test.go @@ -100,6 +100,21 @@ func TestMigrationFile(t *testing.T) { assert.ErrorContains(t, err, "At statement: 0") }) + t.Run("new from folder-based migration", func(t *testing.T) { + // Setup in-memory fs + fsys := fs.MapFS{ + "20242409125510_premium_mister_fear/schema.sql": &fs.MapFile{Data: []byte("CREATE TABLE foo (id int)")}, + "20242409125510_premium_mister_fear/snapshot.json": &fs.MapFile{Data: []byte("{}")}, + } + // Run test + migration, err := NewMigrationFromFile("20242409125510_premium_mister_fear/schema.sql", fsys) + // Check error + assert.NoError(t, err) + assert.Equal(t, "20242409125510", migration.Version) + assert.Equal(t, "premium_mister_fear", migration.Name) + assert.Len(t, migration.Statements, 1) + }) + t.Run("skips hint for schema-qualified type errors", func(t *testing.T) { migration := MigrationFile{ Statements: []string{"CREATE TABLE test (path extensions.ltree NOT NULL)"}, @@ -158,3 +173,60 @@ func TestIsSchemaQualified(t *testing.T) { assert.False(t, IsSchemaQualified("ltree")) assert.False(t, IsSchemaQualified("")) } + +func TestParseVersion(t *testing.T) { + t.Run("extracts version from flat file", func(t *testing.T) { + version, name, ok := ParseVersion("20220727064247_create_table.sql") + assert.True(t, ok) + assert.Equal(t, "20220727064247", version) + assert.Equal(t, "create_table", name) + }) + + t.Run("extracts version from flat file with path", func(t *testing.T) { + version, name, ok := ParseVersion("supabase/migrations/20220727064247_create_table.sql") + assert.True(t, ok) + assert.Equal(t, "20220727064247", version) + assert.Equal(t, "create_table", name) + }) + + t.Run("extracts version from folder-based migration", func(t *testing.T) { + version, name, ok := ParseVersion("supabase/migrations/20242409125510_premium_mister_fear/schema.sql") + assert.True(t, ok) + assert.Equal(t, "20242409125510", version) + assert.Equal(t, "premium_mister_fear", name) + }) + + t.Run("extracts version from folder-based migration without parent path", func(t *testing.T) { + version, name, ok := ParseVersion("20242409125510_premium_mister_fear/schema.sql") + assert.True(t, ok) + assert.Equal(t, "20242409125510", version) + assert.Equal(t, "premium_mister_fear", name) + }) + + t.Run("returns false for non-matching path", func(t *testing.T) { + _, _, ok := ParseVersion("random_file.txt") + assert.False(t, ok) + }) + + t.Run("returns false for .sql without matching parent dir", func(t *testing.T) { + _, _, ok := ParseVersion("some_dir/schema.sql") + assert.False(t, ok) + }) +} + +func TestMigrationName(t *testing.T) { + t.Run("returns filename for flat migration", func(t *testing.T) { + assert.Equal(t, "20220727064247_create_table.sql", MigrationName("supabase/migrations/20220727064247_create_table.sql")) + }) + + t.Run("returns dir/file for folder-based migration", func(t *testing.T) { + assert.Equal(t, + "20242409125510_premium_mister_fear/schema.sql", + MigrationName("supabase/migrations/20242409125510_premium_mister_fear/schema.sql"), + ) + }) + + t.Run("returns filename when no parent directory", func(t *testing.T) { + assert.Equal(t, "20220727064247_create_table.sql", MigrationName("20220727064247_create_table.sql")) + }) +} diff --git a/pkg/migration/list.go b/pkg/migration/list.go index 8972bb2594..ce36312a3b 100644 --- a/pkg/migration/list.go +++ b/pkg/migration/list.go @@ -37,23 +37,55 @@ func ListLocalMigrations(migrationsDir string, fsys fs.FS, filter ...func(string } var clean []string OUTER: - for i, migration := range localMigrations { - if migration.IsDir() { - continue - } - filename := migration.Name() - if i == 0 && shouldSkip(filename) { - fmt.Fprintf(os.Stderr, "Skipping migration %s... (replace \"init\" with a different file name to apply this migration)\n", filename) - continue - } - matches := migrateFilePattern.FindStringSubmatch(filename) - if len(matches) == 0 { - fmt.Fprintf(os.Stderr, "Skipping migration %s... (file name must match pattern \"_name.sql\")\n", filename) - continue + for i, entry := range localMigrations { + var path, version string + + if entry.IsDir() { + dirName := entry.Name() + matches := migrateDirPattern.FindStringSubmatch(dirName) + if len(matches) == 0 { + continue + } + // Look for exactly one .sql file inside the directory + dirPath := filepath.Join(migrationsDir, dirName) + entries, err := fs.ReadDir(fsys, dirPath) + if err != nil { + fmt.Fprintf(os.Stderr, "Skipping migration directory %s... (%v)\n", dirName, err) + continue + } + var sqlFiles []string + for _, e := range entries { + if !e.IsDir() && filepath.Ext(e.Name()) == ".sql" { + sqlFiles = append(sqlFiles, e.Name()) + } + } + if len(sqlFiles) != 1 { + if len(sqlFiles) == 0 { + fmt.Fprintf(os.Stderr, "Skipping migration directory %s... (no .sql file found)\n", dirName) + } else { + fmt.Fprintf(os.Stderr, "Skipping migration directory %s... (multiple .sql files found)\n", dirName) + } + continue + } + path = filepath.Join(migrationsDir, dirName, sqlFiles[0]) + version = matches[1] + } else { + filename := entry.Name() + if i == 0 && shouldSkip(filename) { + fmt.Fprintf(os.Stderr, "Skipping migration %s... (replace \"init\" with a different file name to apply this migration)\n", filename) + continue + } + matches := migrateFilePattern.FindStringSubmatch(filename) + if len(matches) == 0 { + fmt.Fprintf(os.Stderr, "Skipping migration %s... (file name must match pattern \"_name.sql\")\n", filename) + continue + } + path = filepath.Join(migrationsDir, filename) + version = matches[1] } - path := filepath.Join(migrationsDir, filename) + for _, keep := range filter { - if version := matches[1]; !keep(version) { + if !keep(version) { continue OUTER } } diff --git a/pkg/migration/list_test.go b/pkg/migration/list_test.go index 4654fa876a..afe0e61a0c 100644 --- a/pkg/migration/list_test.go +++ b/pkg/migration/list_test.go @@ -89,4 +89,58 @@ func TestLocalMigrations(t *testing.T) { // Check error assert.ErrorContains(t, err, "failed to read directory:") }) + + t.Run("loads folder-based migrations", func(t *testing.T) { + // Setup in-memory fs + fsys := fs.MapFS{ + "20220727064246_test.sql": &fs.MapFile{}, + "20242409125510_premium_mister_fear/schema.sql": &fs.MapFile{}, + "20242409125510_premium_mister_fear/snapshot.json": &fs.MapFile{}, + } + // Run test + versions, err := ListLocalMigrations(".", fsys) + // Check error + assert.NoError(t, err) + assert.Equal(t, []string{ + "20220727064246_test.sql", + "20242409125510_premium_mister_fear/schema.sql", + }, versions) + }) + + t.Run("skips directory without .sql file", func(t *testing.T) { + // Setup in-memory fs — directory with only snapshot.json, no .sql file + fsys := fs.MapFS{ + "20242409125510_premium_mister_fear/snapshot.json": &fs.MapFile{}, + } + // Run test + versions, err := ListLocalMigrations(".", fsys) + // Check error + assert.NoError(t, err) + assert.Empty(t, versions) + }) + + t.Run("skips directory with multiple .sql files", func(t *testing.T) { + // Setup in-memory fs — directory with more than one .sql file + fsys := fs.MapFS{ + "20242409125510_premium_mister_fear/schema.sql": &fs.MapFile{}, + "20242409125510_premium_mister_fear/extra.sql": &fs.MapFile{}, + } + // Run test + versions, err := ListLocalMigrations(".", fsys) + // Check error + assert.NoError(t, err) + assert.Empty(t, versions) + }) + + t.Run("skips directory with non-matching name", func(t *testing.T) { + // Setup in-memory fs — directory name doesn't start with digits + fsys := fs.MapFS{ + "some_random_dir/schema.sql": &fs.MapFile{}, + } + // Run test + versions, err := ListLocalMigrations(".", fsys) + // Check error + assert.NoError(t, err) + assert.Empty(t, versions) + }) }