From c9d26f22702192fe052dc928d3ab45d5fb5a83d9 Mon Sep 17 00:00:00 2001
From: Daisuke Maki <lestrrat+github@gmail.com>
Date: Thu, 21 Sep 2017 13:47:57 +0900
Subject: [PATCH 1/6] Use context

---
 cmd/git-schemalex/main.go | 22 +++++++++++++++++-
 gitschemalex.go           | 49 ++++++++++++++++++---------------------
 gitschemalex_test.go      |  8 +++----
 query.go                  |  9 +++----
 4 files changed, 53 insertions(+), 35 deletions(-)

diff --git a/cmd/git-schemalex/main.go b/cmd/git-schemalex/main.go
index 49ba239..3965822 100644
--- a/cmd/git-schemalex/main.go
+++ b/cmd/git-schemalex/main.go
@@ -1,9 +1,13 @@
 package main
 
 import (
+	"context"
 	"flag"
 	"fmt"
 	"log"
+	"os"
+	"os/signal"
+	"syscall"
 
 	"github.com/schemalex/git-schemalex"
 )
@@ -24,6 +28,22 @@ func main() {
 }
 
 func _main() error {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	sigCh := make(chan os.Signal, 1)
+	signal.Notify(sigCh, syscall.SIGTERM, syscall.SIGINT, syscall.SIGQUIT)
+
+	go func() {
+		select {
+		case <-ctx.Done():
+			return
+		case <-sigCh:
+			cancel()
+			return
+		}
+	}()
+
 	r := &gitschemalex.Runner{
 		Workspace: *workspace,
 		Deploy:    *deploy,
@@ -31,7 +51,7 @@ func _main() error {
 		Table:     *table,
 		Schema:    *schema,
 	}
-	err := r.Run()
+	err := r.Run(ctx)
 	if err == gitschemalex.ErrEqualVersion {
 		fmt.Println(err.Error())
 		return nil
diff --git a/gitschemalex.go b/gitschemalex.go
index a0b4c3b..0e7d55e 100644
--- a/gitschemalex.go
+++ b/gitschemalex.go
@@ -2,6 +2,7 @@ package gitschemalex
 
 import (
 	"bytes"
+	"context"
 	"database/sql"
 	"errors"
 	"fmt"
@@ -28,7 +29,7 @@ type Runner struct {
 	Schema    string
 }
 
-func (r *Runner) Run() error {
+func (r *Runner) Run(ctx context.Context) error {
 	db, err := r.DB()
 
 	if err != nil {
@@ -37,25 +38,24 @@ func (r *Runner) Run() error {
 
 	defer db.Close()
 
-	schemaVersion, err := r.SchemaVersion()
+	schemaVersion, err := r.SchemaVersion(ctx)
 	if err != nil {
 		return err
 	}
 
-	dbVersion, err := r.DatabaseVersion(db)
-
-	if err != nil {
+	var dbVersion string
+	if err := r.DatabaseVersion(ctx, db, &dbVersion); err != nil {
 		if !strings.Contains(err.Error(), "doesn't exist") {
 			return err
 		}
-		return r.DeploySchema(db, schemaVersion)
+		return r.DeploySchema(ctx, db, schemaVersion)
 	}
 
 	if dbVersion == schemaVersion {
 		return ErrEqualVersion
 	}
 
-	if err := r.UpgradeSchema(db, schemaVersion, dbVersion); err != nil {
+	if err := r.UpgradeSchema(ctx, db, schemaVersion, dbVersion); err != nil {
 		return err
 	}
 
@@ -66,14 +66,12 @@ func (r *Runner) DB() (*sql.DB, error) {
 	return sql.Open("mysql", r.DSN)
 }
 
-func (r *Runner) DatabaseVersion(db *sql.DB) (version string, err error) {
-	err = db.QueryRow(fmt.Sprintf("SELECT version FROM `%s`", r.Table)).Scan(&version)
-	return
+func (r *Runner) DatabaseVersion(ctx context.Context, db *sql.DB, version *string) error {
+	return db.QueryRowContext(ctx, fmt.Sprintf("SELECT version FROM `%s`", r.Table)).Scan(version)
 }
 
-func (r *Runner) SchemaVersion() (string, error) {
-
-	byt, err := r.execGitCmd("log", "-n", "1", "--pretty=format:%H", "--", r.Schema)
+func (r *Runner) SchemaVersion(ctx context.Context) (string, error) {
+	byt, err := r.execGitCmd(ctx, "log", "-n", "1", "--pretty=format:%H", "--", r.Schema)
 	if err != nil {
 		return "", err
 	}
@@ -81,7 +79,7 @@ func (r *Runner) SchemaVersion() (string, error) {
 	return string(byt), nil
 }
 
-func (r *Runner) DeploySchema(db *sql.DB, version string) error {
+func (r *Runner) DeploySchema(ctx context.Context, db *sql.DB, version string) error {
 	content, err := r.schemaContent()
 	if err != nil {
 		return err
@@ -89,12 +87,11 @@ func (r *Runner) DeploySchema(db *sql.DB, version string) error {
 	queries := queryListFromString(content)
 	queries.AppendStmt(fmt.Sprintf("CREATE TABLE `%s` ( version VARCHAR(40) NOT NULL )", r.Table))
 	queries.AppendStmt(fmt.Sprintf("INSERT INTO `%s` (version) VALUES (?)", r.Table), version)
-	return r.execSql(db, queries)
+	return r.execSql(ctx, db, queries)
 }
 
-func (r *Runner) UpgradeSchema(db *sql.DB, schemaVersion string, dbVersion string) error {
-
-	lastSchema, err := r.schemaSpecificCommit(dbVersion)
+func (r *Runner) UpgradeSchema(ctx context.Context, db *sql.DB, schemaVersion string, dbVersion string) error {
+	lastSchema, err := r.schemaSpecificCommit(ctx, dbVersion)
 	if err != nil {
 		return err
 	}
@@ -113,13 +110,13 @@ func (r *Runner) UpgradeSchema(db *sql.DB, schemaVersion string, dbVersion strin
 	queries := queryListFromString(stmts.String())
 	queries.AppendStmt(fmt.Sprintf("UPDATE %s SET version = ?", r.Table), schemaVersion)
 
-	return r.execSql(db, queries)
+	return r.execSql(ctx, db, queries)
 }
 
 // private
 
-func (r *Runner) schemaSpecificCommit(commit string) (string, error) {
-	byt, err := r.execGitCmd("ls-tree", commit, "--", r.Schema)
+func (r *Runner) schemaSpecificCommit(ctx context.Context, commit string) (string, error) {
+	byt, err := r.execGitCmd(ctx, "ls-tree", commit, "--", r.Schema)
 
 	if err != nil {
 		return "", err
@@ -127,7 +124,7 @@ func (r *Runner) schemaSpecificCommit(commit string) (string, error) {
 
 	fields := strings.Fields(string(byt))
 
-	byt, err = r.execGitCmd("cat-file", "blob", fields[2])
+	byt, err = r.execGitCmd(ctx, "cat-file", "blob", fields[2])
 	if err != nil {
 		return "", err
 	}
@@ -135,11 +132,11 @@ func (r *Runner) schemaSpecificCommit(commit string) (string, error) {
 	return string(byt), nil
 }
 
-func (r *Runner) execSql(db *sql.DB, queries queryList) error {
+func (r *Runner) execSql(ctx context.Context, db *sql.DB, queries queryList) error {
 	if !r.Deploy {
 		return queries.dump(os.Stdout)
 	}
-	return queries.execute(db)
+	return queries.execute(ctx, db)
 }
 
 func (r *Runner) schemaContent() (string, error) {
@@ -150,8 +147,8 @@ func (r *Runner) schemaContent() (string, error) {
 	return string(byt), nil
 }
 
-func (r *Runner) execGitCmd(args ...string) ([]byte, error) {
-	cmd := exec.Command("git", args...)
+func (r *Runner) execGitCmd(ctx context.Context, args ...string) ([]byte, error) {
+	cmd := exec.CommandContext(ctx, "git", args...)
 	if r.Workspace != "" {
 		cmd.Dir = r.Workspace
 	}
diff --git a/gitschemalex_test.go b/gitschemalex_test.go
index 8059b93..267d42d 100644
--- a/gitschemalex_test.go
+++ b/gitschemalex_test.go
@@ -1,6 +1,7 @@
 package gitschemalex
 
 import (
+	"context"
 	"database/sql"
 	"io/ioutil"
 	"os"
@@ -91,7 +92,7 @@ func TestRunner(t *testing.T) {
 		Table:     "git_schemalex_version",
 		Schema:    "schema.sql",
 	}
-	if err := r.Run(); err != nil {
+	if err := r.Run(context.TODO()); err != nil {
 		t.Fatal(err)
 	}
 
@@ -113,7 +114,7 @@ func TestRunner(t *testing.T) {
 		t.Fatal(err)
 	}
 
-	if err := r.Run(); err != nil {
+	if err := r.Run(context.TODO()); err != nil {
 		t.Fatal(err)
 	}
 
@@ -122,8 +123,7 @@ func TestRunner(t *testing.T) {
 	}
 
 	// equal version
-
-	if e, g := ErrEqualVersion, r.Run(); e != g {
+	if e, g := ErrEqualVersion, r.Run(context.TODO()); e != g {
 		t.Fatal("should %v got %v", e, g)
 	}
 }
diff --git a/query.go b/query.go
index 0c5c26a..b3a74db 100644
--- a/query.go
+++ b/query.go
@@ -1,6 +1,7 @@
 package gitschemalex
 
 import (
+	"context"
 	"database/sql"
 	"fmt"
 	"io"
@@ -14,8 +15,8 @@ type query struct {
 	args []interface{}
 }
 
-func (q *query) execute(db *sql.DB) error {
-	_, err := db.Exec(q.stmt, q.args...)
+func (q *query) execute(ctx context.Context, db *sql.DB) error {
+	_, err := db.ExecContext(ctx, q.stmt, q.args...)
 	return errors.Wrap(err, `failed to execute query`)
 }
 
@@ -58,9 +59,9 @@ func (l *queryList) dump(dst io.Writer) error {
 	return nil
 }
 
-func (l *queryList) execute(db *sql.DB) error {
+func (l *queryList) execute(ctx context.Context, db *sql.DB) error {
 	for i, q := range *l {
-		if err := q.execute(db); err != nil {
+		if err := q.execute(ctx, db); err != nil {
 			return errors.Wrapf(err, `failed to execute query %d`, i+1)
 		}
 	}

From 4a2aae3e30913e9f5b626ed9800703f3506179a3 Mon Sep 17 00:00:00 2001
From: Daisuke Maki <lestrrat+github@gmail.com>
Date: Thu, 21 Sep 2017 13:49:11 +0900
Subject: [PATCH 2/6] Why use this tricky bit of code??

---
 gitschemalex_test.go | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/gitschemalex_test.go b/gitschemalex_test.go
index 267d42d..6bee871 100644
--- a/gitschemalex_test.go
+++ b/gitschemalex_test.go
@@ -123,7 +123,7 @@ func TestRunner(t *testing.T) {
 	}
 
 	// equal version
-	if e, g := ErrEqualVersion, r.Run(context.TODO()); e != g {
-		t.Fatal("should %v got %v", e, g)
+	if err := r.Run(context.TODO()); err != ErrEqualVersion {
+		t.Fatal("should %v got %v", err, ErrEqualVersion)
 	}
 }

From 00642c2af767d33f442b9ad3a96d08a20311abfa Mon Sep 17 00:00:00 2001
From: Daisuke Maki <lestrrat+github@gmail.com>
Date: Thu, 21 Sep 2017 14:22:15 +0900
Subject: [PATCH 3/6] add -commit

---
 cmd/git-schemalex/main.go | 16 +++++----
 gitschemalex.go           | 74 +++++++++++++++++++--------------------
 gitschemalex_test.go      | 14 ++++----
 3 files changed, 52 insertions(+), 52 deletions(-)

diff --git a/cmd/git-schemalex/main.go b/cmd/git-schemalex/main.go
index 3965822..fba8b9e 100644
--- a/cmd/git-schemalex/main.go
+++ b/cmd/git-schemalex/main.go
@@ -14,6 +14,7 @@ import (
 
 var (
 	workspace = flag.String("workspace", "", "workspace of git")
+	commit    = flag.String("commit", "HEAD", "target git commit hash")
 	deploy    = flag.Bool("deploy", false, "deploy")
 	dsn       = flag.String("dsn", "", "")
 	table     = flag.String("table", "git_schemalex_version", "table of git revision")
@@ -44,13 +45,14 @@ func _main() error {
 		}
 	}()
 
-	r := &gitschemalex.Runner{
-		Workspace: *workspace,
-		Deploy:    *deploy,
-		DSN:       *dsn,
-		Table:     *table,
-		Schema:    *schema,
-	}
+	r := gitschemalex.New()
+	r.Workspace = *workspace
+	r.Commit = *commit
+	r.Deploy = *deploy
+	r.DSN = *dsn
+	r.Table = *table
+	r.Schema = *schema
+
 	err := r.Run(ctx)
 	if err == gitschemalex.ErrEqualVersion {
 		fmt.Println(err.Error())
diff --git a/gitschemalex.go b/gitschemalex.go
index 0e7d55e..2cd0478 100644
--- a/gitschemalex.go
+++ b/gitschemalex.go
@@ -6,10 +6,8 @@ import (
 	"database/sql"
 	"errors"
 	"fmt"
-	"io/ioutil"
 	"os"
 	"os/exec"
-	"path/filepath"
 	"strings"
 
 	_ "github.com/go-sql-driver/mysql"
@@ -23,23 +21,29 @@ var (
 
 type Runner struct {
 	Workspace string
+	Commit    string
 	Deploy    bool
 	DSN       string
 	Table     string
 	Schema    string
 }
 
+func New() *Runner {
+	return &Runner{
+		Commit: "HEAD",
+		Deploy: false,
+	}
+}
+
 func (r *Runner) Run(ctx context.Context) error {
 	db, err := r.DB()
-
 	if err != nil {
 		return err
 	}
-
 	defer db.Close()
 
-	schemaVersion, err := r.SchemaVersion(ctx)
-	if err != nil {
+	var schemaVersion string
+	if err := r.SchemaVersion(ctx, &schemaVersion); err != nil {
 		return err
 	}
 
@@ -70,20 +74,24 @@ func (r *Runner) DatabaseVersion(ctx context.Context, db *sql.DB, version *strin
 	return db.QueryRowContext(ctx, fmt.Sprintf("SELECT version FROM `%s`", r.Table)).Scan(version)
 }
 
-func (r *Runner) SchemaVersion(ctx context.Context) (string, error) {
-	byt, err := r.execGitCmd(ctx, "log", "-n", "1", "--pretty=format:%H", "--", r.Schema)
+func (r *Runner) SchemaVersion(ctx context.Context, version *string) error {
+	// git rev-parse takes things like "HEAD" or commit hash, and gives
+	// us the corresponding commit hash
+	v, err := r.execGitCmd(ctx, "rev-parse", r.Commit)
 	if err != nil {
-		return "", err
+		return err
 	}
 
-	return string(byt), nil
+	*version = string(v)
+	return nil
 }
 
 func (r *Runner) DeploySchema(ctx context.Context, db *sql.DB, version string) error {
-	content, err := r.schemaContent()
-	if err != nil {
+	var content string
+	if err := r.schemaSpecificCommit(ctx, version, &content); err != nil {
 		return err
 	}
+
 	queries := queryListFromString(content)
 	queries.AppendStmt(fmt.Sprintf("CREATE TABLE `%s` ( version VARCHAR(40) NOT NULL )", r.Table))
 	queries.AppendStmt(fmt.Sprintf("INSERT INTO `%s` (version) VALUES (?)", r.Table), version)
@@ -91,19 +99,18 @@ func (r *Runner) DeploySchema(ctx context.Context, db *sql.DB, version string) e
 }
 
 func (r *Runner) UpgradeSchema(ctx context.Context, db *sql.DB, schemaVersion string, dbVersion string) error {
-	lastSchema, err := r.schemaSpecificCommit(ctx, dbVersion)
-	if err != nil {
+	var lastSchema string
+	if err := r.schemaSpecificCommit(ctx, dbVersion, &lastSchema); err != nil {
 		return err
 	}
 
-	currentSchema, err := r.schemaContent()
-	if err != nil {
+	var currentSchema string
+	if err := r.schemaSpecificCommit(ctx, schemaVersion, &currentSchema); err != nil {
 		return err
 	}
 	stmts := &bytes.Buffer{}
 	p := schemalex.New()
-	err = diff.Strings(stmts, lastSchema, currentSchema, diff.WithTransaction(true), diff.WithParser(p))
-	if err != nil {
+	if err := diff.Strings(stmts, lastSchema, currentSchema, diff.WithTransaction(true), diff.WithParser(p)); err != nil {
 		return err
 	}
 
@@ -115,21 +122,20 @@ func (r *Runner) UpgradeSchema(ctx context.Context, db *sql.DB, schemaVersion st
 
 // private
 
-func (r *Runner) schemaSpecificCommit(ctx context.Context, commit string) (string, error) {
-	byt, err := r.execGitCmd(ctx, "ls-tree", commit, "--", r.Schema)
-
-	if err != nil {
-		return "", err
-	}
-
-	fields := strings.Fields(string(byt))
-
-	byt, err = r.execGitCmd(ctx, "cat-file", "blob", fields[2])
+func (r *Runner) schemaSpecificCommit(ctx context.Context, commit string, dst *string) error {
+	// Old code used to do ls-tree and then cat-file, but I don't see why
+	// you need to do this.
+	// Doing
+	// > fields := git ls-tree $commit -- $schema_file
+	// And then taking fields[2] just gives us back $commit.
+	// showing the contents at the point of commit using "git show" is much simpler
+	v, err := r.execGitCmd(ctx, "show", fmt.Sprintf("%s:%s", commit, r.Schema))
 	if err != nil {
-		return "", err
+		return err
 	}
 
-	return string(byt), nil
+	*dst = string(v)
+	return nil
 }
 
 func (r *Runner) execSql(ctx context.Context, db *sql.DB, queries queryList) error {
@@ -139,14 +145,6 @@ func (r *Runner) execSql(ctx context.Context, db *sql.DB, queries queryList) err
 	return queries.execute(ctx, db)
 }
 
-func (r *Runner) schemaContent() (string, error) {
-	byt, err := ioutil.ReadFile(filepath.Join(r.Workspace, r.Schema))
-	if err != nil {
-		return "", err
-	}
-	return string(byt), nil
-}
-
 func (r *Runner) execGitCmd(ctx context.Context, args ...string) ([]byte, error) {
 	cmd := exec.CommandContext(ctx, "git", args...)
 	if r.Workspace != "" {
diff --git a/gitschemalex_test.go b/gitschemalex_test.go
index 6bee871..8ddd242 100644
--- a/gitschemalex_test.go
+++ b/gitschemalex_test.go
@@ -13,6 +13,7 @@ import (
 
 	_ "github.com/go-sql-driver/mysql"
 	"github.com/lestrrat/go-test-mysqld"
+	gitschemalex "github.com/schemalex/git-schemalex"
 )
 
 func TestRunner(t *testing.T) {
@@ -85,13 +86,12 @@ func TestRunner(t *testing.T) {
 	// whatever to "test"
 	re := regexp.MustCompile(`/[^/]+$`)
 	dsn = re.ReplaceAllString(dsn, `/test`)
-	r := &Runner{
-		Workspace: dir,
-		Deploy:    true,
-		DSN:       dsn,
-		Table:     "git_schemalex_version",
-		Schema:    "schema.sql",
-	}
+	r := gitschemalex.New()
+	r.Workspace = dir
+	r.Deploy = true
+	r.DSN = dsn
+	r.Table = "git_schemalex_version"
+	r.Schema = "schema.sql"
 	if err := r.Run(context.TODO()); err != nil {
 		t.Fatal(err)
 	}

From 2003acc0176289add33a510f360a85029a87eb63 Mon Sep 17 00:00:00 2001
From: Daisuke Maki <lestrrat+github@gmail.com>
Date: Thu, 21 Sep 2017 14:29:22 +0900
Subject: [PATCH 4/6] avoid cycle

---
 gitschemalex_test.go | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/gitschemalex_test.go b/gitschemalex_test.go
index 8ddd242..2f862d0 100644
--- a/gitschemalex_test.go
+++ b/gitschemalex_test.go
@@ -13,7 +13,6 @@ import (
 
 	_ "github.com/go-sql-driver/mysql"
 	"github.com/lestrrat/go-test-mysqld"
-	gitschemalex "github.com/schemalex/git-schemalex"
 )
 
 func TestRunner(t *testing.T) {
@@ -86,7 +85,7 @@ func TestRunner(t *testing.T) {
 	// whatever to "test"
 	re := regexp.MustCompile(`/[^/]+$`)
 	dsn = re.ReplaceAllString(dsn, `/test`)
-	r := gitschemalex.New()
+	r := New()
 	r.Workspace = dir
 	r.Deploy = true
 	r.DSN = dsn

From 24c7542f0b1d1c7201a73cb43771745b67da2865 Mon Sep 17 00:00:00 2001
From: Daisuke Maki <lestrrat+github@gmail.com>
Date: Thu, 21 Sep 2017 14:33:09 +0900
Subject: [PATCH 5/6] trim

---
 gitschemalex.go | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/gitschemalex.go b/gitschemalex.go
index 2cd0478..2e8db9e 100644
--- a/gitschemalex.go
+++ b/gitschemalex.go
@@ -82,7 +82,7 @@ func (r *Runner) SchemaVersion(ctx context.Context, version *string) error {
 		return err
 	}
 
-	*version = string(v)
+	*version = string(bytes.TrimSpace(v))
 	return nil
 }
 

From e594a237dc2d5b6664e5339a920fbb7319028cf1 Mon Sep 17 00:00:00 2001
From: Daisuke Maki <lestrrat+github@gmail.com>
Date: Thu, 21 Sep 2017 14:42:29 +0900
Subject: [PATCH 6/6] Add a sane default for -dsn

---
 cmd/git-schemalex/main.go | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/cmd/git-schemalex/main.go b/cmd/git-schemalex/main.go
index fba8b9e..3d6bba1 100644
--- a/cmd/git-schemalex/main.go
+++ b/cmd/git-schemalex/main.go
@@ -16,7 +16,7 @@ var (
 	workspace = flag.String("workspace", "", "workspace of git")
 	commit    = flag.String("commit", "HEAD", "target git commit hash")
 	deploy    = flag.Bool("deploy", false, "deploy")
-	dsn       = flag.String("dsn", "", "")
+	dsn       = flag.String("dsn", "root:@tcp(127.0.0.1:3306)/test", "DSN of the target mysql instance")
 	table     = flag.String("table", "git_schemalex_version", "table of git revision")
 	schema    = flag.String("schema", "", "path to schema file")
 )