diff --git a/api/config.go b/api/config.go index e5d8053f..5a6f8086 100644 --- a/api/config.go +++ b/api/config.go @@ -46,6 +46,10 @@ type Config struct { SkipMigrationFiles []string MigrationMode MigrationMode + // List of scripts that must run even if their hash hasn't changed. + // Need just the filename without the `functions/` or `views/` prefix. + MustRun []string + // If we are using Kratos auth, some migrations // depend on kratos migrations being ran or not and // can cause problems if mission-control mirations run diff --git a/migrate/migrate.go b/migrate/migrate.go index 931069ea..6e0d71ed 100644 --- a/migrate/migrate.go +++ b/migrate/migrate.go @@ -9,7 +9,6 @@ import ( "path/filepath" "sort" - "github.com/flanksource/commons/collections" "github.com/flanksource/commons/logger" "github.com/flanksource/commons/properties" "github.com/flanksource/duty/api" @@ -17,6 +16,7 @@ import ( "github.com/flanksource/duty/functions" "github.com/flanksource/duty/schema" "github.com/flanksource/duty/views" + "github.com/samber/lo" "github.com/samber/oops" ) @@ -51,13 +51,13 @@ func RunMigrations(pool *sql.DB, config api.Config) error { return fmt.Errorf("failed to create migration log table: %w", err) } - allFunctions, allViews, err := GetExecutableScripts(pool) + allFunctions, allViews, err := GetExecutableScripts(pool, config.MustRun, config.SkipMigrationFiles) if err != nil { return fmt.Errorf("failed to get executable scripts: %w", err) } - l.V(3).Infof("Running scripts") - if err := runScripts(pool, allFunctions, config.SkipMigrationFiles); err != nil { + l.V(3).Infof("Running %d scripts (functions)", len(allFunctions)) + if err := runScripts(pool, allFunctions); err != nil { return fmt.Errorf("failed to run scripts: %w", err) } @@ -72,8 +72,8 @@ func RunMigrations(pool *sql.DB, config api.Config) error { return fmt.Errorf("failed to apply schema migrations: %w", err) } - l.V(3).Infof("Running scripts for views") - if err := runScripts(pool, allViews, config.SkipMigrationFiles); err != nil { + l.V(3).Infof("Running %d scripts (views)", len(allViews)) + if err := runScripts(pool, allViews); err != nil { return fmt.Errorf("failed to run scripts for views: %w", err) } @@ -82,7 +82,7 @@ func RunMigrations(pool *sql.DB, config api.Config) error { // GetExecutableScripts returns functions & views that must be applied. // It takes dependencies into account & excludes any unchanged scripts. -func GetExecutableScripts(pool *sql.DB) (map[string]string, map[string]string, error) { +func GetExecutableScripts(pool *sql.DB, mustRun, skip []string) (map[string]string, map[string]string, error) { l := logger.GetLogger("migrate") var ( @@ -113,14 +113,16 @@ func GetExecutableScripts(pool *sql.DB) (map[string]string, map[string]string, e } for path, content := range funcs { - hash := sha1.Sum([]byte(content)) - if ch, ok := currentMigrationHashes[path]; ok && ch == string(hash[:]) { + if lo.Contains(mustRun, path) { + // proceeed. do not check hash + } else if lo.Contains(skip, path) { + continue + } else if hasMatchingHash(path, content, currentMigrationHashes) { continue } allFunctions[path] = content - // other scripts that depend on this should also be executed for _, dependent := range depGraph[filepath.Join("functions", path)] { baseDir := filepath.Dir(dependent) filename := filepath.Base(dependent) @@ -131,21 +133,22 @@ func GetExecutableScripts(pool *sql.DB) (map[string]string, map[string]string, e case "views": allViews[filename] = views[filename] default: - panic("unhandled base dir") + panic(fmt.Sprintf("unhandled base directory: %s", baseDir)) } } } for path, content := range views { - hash := sha1.Sum([]byte(content)) - if ch, ok := currentMigrationHashes[path]; ok && ch == string(hash[:]) { + if lo.Contains(mustRun, path) { + // proceeed. do not check hash + } else if lo.Contains(skip, path) { + continue + } else if hasMatchingHash(path, content, currentMigrationHashes) { continue } allViews[path] = content - - // other scripts that depend on this should also be executed - for _, dependent := range depGraph[filepath.Join("functions", path)] { + for _, dependent := range depGraph[filepath.Join("views", path)] { baseDir := filepath.Dir(dependent) filename := filepath.Base(dependent) @@ -155,7 +158,7 @@ func GetExecutableScripts(pool *sql.DB) (map[string]string, map[string]string, e case "views": allViews[filename] = views[filename] default: - panic("unhandled base dir") + panic(fmt.Sprintf("unhandled base directory: %s", baseDir)) } } } @@ -265,14 +268,11 @@ func checkIfRoleIsGranted(pool *sql.DB, group, member string) (bool, error) { return true, nil } -func runScripts(pool *sql.DB, scripts map[string]string, ignoreFiles []string) error { +func runScripts(pool *sql.DB, scripts map[string]string) error { l := logger.GetLogger("migrate") var filenames []string for name := range scripts { - if collections.Contains(ignoreFiles, name) { - continue - } filenames = append(filenames, name) } sort.Strings(filenames) @@ -307,3 +307,9 @@ func createMigrationLogTable(pool *sql.DB) error { _, err := pool.Exec(query) return err } + +func hasMatchingHash(path, content string, currentHashes map[string]string) bool { + hash := sha1.Sum([]byte(content)) + currentHash, exists := currentHashes[path] + return exists && currentHash == string(hash[:]) +} diff --git a/tests/migration_dependency_test.go b/tests/migration_dependency_test.go index de510f50..cb7deed7 100644 --- a/tests/migration_dependency_test.go +++ b/tests/migration_dependency_test.go @@ -20,12 +20,42 @@ var _ = Describe("migration dependency", Ordered, func() { db, err := DefaultContext.DB().DB() Expect(err).To(BeNil()) - funcs, views, err := migrate.GetExecutableScripts(db) + funcs, views, err := migrate.GetExecutableScripts(db, nil, nil) Expect(err).To(BeNil()) Expect(len(funcs)).To(BeZero()) Expect(len(views)).To(BeZero()) }) + It("should explicitly run script", func() { + db, err := DefaultContext.DB().DB() + Expect(err).To(BeNil()) + + funcs, views, err := migrate.GetExecutableScripts(db, []string{"incident_ids.sql"}, nil) + Expect(err).To(BeNil()) + Expect(len(funcs)).To(Equal(1)) + Expect(len(views)).To(BeZero()) + }) + + It("should ignore changed hash run script", func() { + var currentHash string + err := DefaultContext.DB().Raw(`SELECT hash FROM migration_logs WHERE path = 'incident_ids.sql'`).Scan(¤tHash).Error + Expect(err).To(BeNil()) + + err = DefaultContext.DB().Exec(`UPDATE migration_logs SET hash = 'dummy' WHERE path = 'incident_ids.sql'`).Error + Expect(err).To(BeNil()) + + db, err := DefaultContext.DB().DB() + Expect(err).To(BeNil()) + + funcs, views, err := migrate.GetExecutableScripts(db, nil, []string{"incident_ids.sql"}) + Expect(err).To(BeNil()) + Expect(len(funcs)).To(BeZero()) + Expect(len(views)).To(BeZero()) + + err = DefaultContext.DB().Exec(`UPDATE migration_logs SET hash = ? WHERE path = 'incident_ids.sql'`, []byte(currentHash)[:]).Error + Expect(err).To(BeNil(), "failed to restore hash for incidents_ids.sql") + }) + It("should get correct executable scripts", func() { err := DefaultContext.DB().Exec(`UPDATE migration_logs SET hash = 'dummy' WHERE path = 'drop.sql'`).Error Expect(err).To(BeNil()) @@ -33,7 +63,7 @@ var _ = Describe("migration dependency", Ordered, func() { sqlDB, err := DefaultContext.DB().DB() Expect(err).To(BeNil()) - funcs, views, err := migrate.GetExecutableScripts(sqlDB) + funcs, views, err := migrate.GetExecutableScripts(sqlDB, nil, nil) Expect(err).To(BeNil()) Expect(len(funcs)).To(Equal(1)) Expect(len(views)).To(Equal(2)) @@ -50,7 +80,7 @@ var _ = Describe("migration dependency", Ordered, func() { db, err := DefaultContext.DB().DB() Expect(err).To(BeNil()) - funcs, views, err := migrate.GetExecutableScripts(db) + funcs, views, err := migrate.GetExecutableScripts(db, nil, nil) Expect(err).To(BeNil()) Expect(len(funcs)).To(BeZero()) Expect(len(views)).To(BeZero())