Skip to content

Commit

Permalink
Remove migrate.FS interface and Migrator.Path field
Browse files Browse the repository at this point in the history
Turns out these are not needed when using `fs.Sub`, `fs.ReadFile`, and `fs.ReadDir` functions.
  • Loading branch information
markuswustenberg committed Jun 16, 2021
1 parent 3465fe8 commit e152045
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 37 deletions.
8 changes: 6 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"context"
"database/sql"
"embed"
"io/fs"

_ "github.com/jackc/pgx/v4/stdlib"
"github.com/maragudk/migrate"
Expand All @@ -26,15 +27,18 @@ import (
// migrations/2.up.sql
// migrations/2.down.sql
//go:embed migrations
var migrations embed.FS
var dir embed.FS

func main() {
db, err := sql.Open("pgx", "postgresql://postgres:123@localhost:5432/postgres?sslmode=disable")
if err != nil {
panic(err)
}
migrations, err := fs.Sub(dir, "migrations")
if err != nil {
panic(err)
}
m := migrate.New(db, migrations)
m.Path = "migrations"
if err := m.MigrateUp(context.Background()); err != nil {
panic(err)
}
Expand Down
22 changes: 7 additions & 15 deletions migrate.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"errors"
"fmt"
"io/fs"
"path"
"regexp"
)

Expand All @@ -16,23 +15,16 @@ var (
downMatcher = regexp.MustCompile(`^([\w]+).down.sql`)
)

type FS interface {
fs.ReadDirFS
fs.ReadFileFS
}

type Migrator struct {
DB *sql.DB
FS FS
Path string
DB *sql.DB
FS fs.FS
}

// New Migrator with default options.
func New(db *sql.DB, fs FS) *Migrator {
func New(db *sql.DB, fs fs.FS) *Migrator {
return &Migrator{
DB: db,
FS: fs,
Path: ".",
DB: db,
FS: fs,
}
}

Expand Down Expand Up @@ -104,7 +96,7 @@ func (m *Migrator) MigrateDown(ctx context.Context) error {

// apply a file identified by name and update to version.
func (m *Migrator) apply(ctx context.Context, name, version string) error {
content, err := m.FS.ReadFile(path.Join(m.Path, name))
content, err := fs.ReadFile(m.FS, name)
if err != nil {
return err
}
Expand All @@ -122,7 +114,7 @@ func (m *Migrator) apply(ctx context.Context, name, version string) error {
// getFilenames alphabetically where the name matches the given matcher.
func (m *Migrator) getFilenames(matcher *regexp.Regexp) ([]string, error) {
var names []string
entries, err := m.FS.ReadDir(m.Path)
entries, err := fs.ReadDir(m.FS, ".")
if err != nil {
return names, err
}
Expand Down
48 changes: 28 additions & 20 deletions migrate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"database/sql"
"embed"
"io/fs"
"testing"
"testing/fstest"

Expand All @@ -23,9 +24,8 @@ func TestMigrator_MigrateUp(t *testing.T) {
is := is.New(t)

m := migrate.Migrator{
DB: db,
FS: fstest.MapFS{},
Path: ".",
DB: db,
FS: fstest.MapFS{},
}

err := m.MigrateUp(context.Background())
Expand All @@ -43,9 +43,8 @@ func TestMigrator_MigrateUp(t *testing.T) {
is := is.New(t)

m := migrate.Migrator{
DB: db,
FS: testdata,
Path: "testdata/two",
DB: db,
FS: mustSub(t, testdata, "testdata/two"),
}

err := m.MigrateUp(context.Background())
Expand All @@ -63,9 +62,8 @@ func TestMigrator_MigrateUp(t *testing.T) {
is := is.New(t)

m := migrate.Migrator{
DB: db,
FS: testdata,
Path: "testdata/two",
DB: db,
FS: mustSub(t, testdata, "testdata/two"),
}

err := m.MigrateUp(context.Background())
Expand All @@ -81,9 +79,8 @@ func TestMigrator_MigrateUp(t *testing.T) {
is := is.New(t)

m := migrate.Migrator{
DB: db,
FS: testdata,
Path: "testdata/bad",
DB: db,
FS: mustSub(t, testdata, "testdata/bad"),
}

err := m.MigrateUp(context.Background())
Expand All @@ -104,9 +101,8 @@ func TestMigrator_MigrateDown(t *testing.T) {
is := is.New(t)

m := migrate.Migrator{
DB: db,
FS: testdata,
Path: "testdata/two",
DB: db,
FS: mustSub(t, testdata, "testdata/two"),
}

err := m.MigrateUp(context.Background())
Expand All @@ -132,9 +128,8 @@ func TestMigrator_MigrateDown(t *testing.T) {
is := is.New(t)

m := migrate.Migrator{
DB: db,
FS: testdata,
Path: "testdata/two",
DB: db,
FS: mustSub(t, testdata, "testdata/two"),
}

err := m.MigrateDown(context.Background())
Expand All @@ -150,8 +145,11 @@ func Example() {
if err != nil {
panic(err)
}
m := migrate.New(db, exampleFS)
m.Path = "testdata/example"
migrations, err := fs.Sub(exampleFS, "testdata/example")
if err != nil {
panic(err)
}
m := migrate.New(db, migrations)
if err := m.MigrateUp(context.Background()); err != nil {
panic(err)
}
Expand All @@ -175,3 +173,13 @@ func createDatabase(t *testing.T) (*sql.DB, func()) {
}
}
}

func mustSub(t *testing.T, fsys fs.FS, path string) fs.FS {
t.Helper()
fsys, err := fs.Sub(fsys, path)
if err != nil {
t.Log(err)
t.FailNow()
}
return fsys
}

0 comments on commit e152045

Please sign in to comment.