diff --git a/cmd/squibble/squibble.go b/cmd/squibble/squibble.go index 27194f7..d893c1f 100644 --- a/cmd/squibble/squibble.go +++ b/cmd/squibble/squibble.go @@ -6,6 +6,7 @@ package main import ( "bytes" + "context" "database/sql" "encoding/json" "errors" @@ -141,30 +142,11 @@ var digestFlags struct { } func runDigest(env *command.Env, path string) error { - if digestFlags.SQL || filepath.Ext(path) == ".sql" { - text, err := os.ReadFile(path) - if err != nil { - return err - } - hash, err := squibble.SQLDigest(string(text)) - if err != nil { - return err - } - fmt.Println("sql:", hash) - return nil - } - - db, err := sql.Open("sqlite", path) + kind, digest, err := loadDigest(env.Context(), path, digestFlags.SQL) if err != nil { - return fmt.Errorf("open db: %w", err) + return fmt.Errorf("compute %s digest: %w", kind, err) } - defer db.Close() - - hash, err := squibble.DBDigest(env.Context(), db) - if err != nil { - return err - } - fmt.Println("db: ", hash) + fmt.Printf("%s: %s\n", kind, digest) return nil } @@ -193,3 +175,23 @@ func runHistory(env *command.Env, dbPath string) error { } return nil } + +func loadDigest(ctx context.Context, path string, forceSQL bool) (kind, digest string, _ error) { + if !forceSQL && filepath.Ext(path) != ".sql" { + if _, err := os.Stat(path); err == nil { + db, err := sql.Open("sqlite", path) + if err == nil { + defer db.Close() + d, err := squibble.DBDigest(ctx, db) + return "db", d, err + } + } + // fallthrough + } + text, err := os.ReadFile(path) + if err != nil { + return "sql", "", err + } + d, err := squibble.SQLDigest(string(text)) + return "sql", d, err +}