diff --git a/pkg/spanner/migration.go b/pkg/spanner/migration.go index 2472e5e..6950898 100644 --- a/pkg/spanner/migration.go +++ b/pkg/spanner/migration.go @@ -20,6 +20,7 @@ package spanner import ( + "embed" "errors" "fmt" "os" @@ -141,6 +142,68 @@ func LoadMigrations(dir string) (Migrations, error) { return migrations, nil } +func LoadMigrationsFromEmbeddedFS(dir string, scripts embed.FS) (Migrations, error) { + files, err := scripts.ReadDir(dir) + if err != nil { + return nil, err + } + + var migrations Migrations + + versions := map[uint64]string{} + + for _, f := range files { + if f.IsDir() { + continue + } + + filename := f.Name() + + matches := migrationFileRegex.FindStringSubmatch(filename) + if len(matches) != 4 { + continue + } + + version, err := strconv.ParseUint(matches[1], 10, 64) + if err != nil { + continue + } + + file, err := scripts.ReadFile(filepath.Join(dir, filename)) + if err != nil { + continue + } + + statements, err := ddlToStatements(f.Name(), file) + if err != nil { + nstatements, nerr := dmlToStatements(f.Name(), file) + if nerr != nil { + return nil, fmt.Errorf("failed to parse DDL/DML statements: %v, %v", err, nerr) + } + statements = nstatements + } + + kind, err := inspectStatementsKind(statements) + if err != nil { + return nil, err + } + + migrations = append(migrations, &Migration{ + Version: uint(version), + Name: matches[2], + Statements: statements, + kind: kind, + }) + + if prevFileName, ok := versions[version]; ok { + return nil, fmt.Errorf("colliding version number \"%d\" between file names \"%s\" and \"%s\"", version, prevFileName, filename) + } + versions[version] = filename + } + + return migrations, nil +} + func ddlToStatements(filename string, data []byte) ([]string, error) { ddl, err := spansql.ParseDDL(filename, string(data)) if err != nil {