Skip to content

Commit

Permalink
Allow using Session instead of Engine (#56)
Browse files Browse the repository at this point in the history
  • Loading branch information
qwerty287 authored Oct 26, 2023
1 parent 948eeb3 commit 4ba8b09
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 8 deletions.
59 changes: 51 additions & 8 deletions xormigrate.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,25 +11,35 @@ const (
initSchemaMigrationId = "SCHEMA_INIT"
)

// MigrateFunc is the func signature for migratinx.
// MigrateFunc is the func signature for migrating.
type MigrateFunc func(*xorm.Engine) error

// RollbackFunc is the func signature for rollbackinx.
// RollbackFunc is the func signature for rollbacking.
type RollbackFunc func(*xorm.Engine) error

// InitSchemaFunc is the func signature for initializing the schema.
type InitSchemaFunc func(*xorm.Engine) error

// MigrateFunc is the func signature for migrating.
type MigrateFuncSession func(*xorm.Session) error

// RollbackFunc is the func signature for rollbacking.
type RollbackFuncSession func(*xorm.Session) error

// Migration represents a database migration (a modification to be made on the database).
type Migration struct {
// ID is the migration identifier. Usually a timestamp like "201601021504".
ID string `xorm:"id"`
// Description is the migration description, which is optionally printed out when the migration is ran.
Description string
// Migrate is a function that will br executed while running this migration.
// Migrate is a function that will be executed while running this migration.
Migrate MigrateFunc `xorm:"-"`
// Rollback will be executed on rollback. Can be nil.
Rollback RollbackFunc `xorm:"-"`
// MigrateSession is a function that will be executed while running this migration, using xorm.Session.
MigrateSession MigrateFuncSession `xorm:"-"`
// RollbackSession will be executed on rollback, using xorm.Session. Can be nil.
RollbackSession RollbackFuncSession `xorm:"-"`
// Long marks the migration an non-required migration that will likely take a long time. Must use Xormigrate.AllowLong() to be enabled.
Long bool `xorm:"-"`
}
Expand Down Expand Up @@ -251,14 +261,25 @@ func (x *Xormigrate) RollbackMigration(m *Migration) error {
}

func (x *Xormigrate) rollbackMigration(m *Migration) error {
if m.Rollback == nil {
if m.Rollback == nil && m.RollbackSession == nil {
return ErrRollbackImpossible
}
if len(m.Description) > 0 {
logger.Errorf("Rolling back migration: %s", m.Description)
}
if err := m.Rollback(x.db); err != nil {
return err
if m.Rollback != nil {
if err := m.Rollback(x.db); err != nil {
return err
}
} else {
sess := x.db.NewSession()
if err := m.RollbackSession(sess); err != nil {
rollbackSession(sess)
return err
}
if err := sess.Commit(); err != nil {
return err
}
}
if _, err := x.db.In("id", m.ID).Delete(&Migration{}); err != nil {
return err
Expand All @@ -268,7 +289,12 @@ func (x *Xormigrate) rollbackMigration(m *Migration) error {

func (x *Xormigrate) runInitSchema() error {
logger.Info("Initializing Schema")
sess := x.db.NewSession()
if err := x.initSchema(x.db); err != nil {
rollbackSession(sess)
return err
}
if err := sess.Commit(); err != nil {
return err
}
if err := x.insertMigration(initSchemaMigrationId); err != nil {
Expand All @@ -293,8 +319,19 @@ func (x *Xormigrate) runMigration(migration *Migration) error {
if len(migration.Description) > 0 {
logger.Info(migration.Description)
}
if err := migration.Migrate(x.db); err != nil {
return fmt.Errorf("migration %s failed: %s", migration.ID, err.Error())
if migration.Migrate != nil {
if err := migration.Migrate(x.db); err != nil {
return fmt.Errorf("migration %s failed: %s", migration.ID, err.Error())
}
} else {
sess := x.db.NewSession()
if err := migration.MigrateSession(sess); err != nil {
rollbackSession(sess)
return fmt.Errorf("migration %s failed: %s", migration.ID, err.Error())
}
if err := sess.Commit(); err != nil {
return err
}
}

if err := x.insertMigration(migration.ID); err != nil {
Expand Down Expand Up @@ -339,3 +376,9 @@ func (x *Xormigrate) insertMigration(id string) error {
_, err := x.db.Insert(&Migration{ID: id})
return err
}

func rollbackSession(sess *xorm.Session) {
if err := sess.Rollback(); err != nil {
logger.Errorf("Failed to rollback session: %v", err)
}
}
58 changes: 58 additions & 0 deletions xormigrate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,28 @@ var migrations = []*Migration{
},
}

var migrationsSession = []*Migration{
{
ID: "201608301400",
Description: "Add Person",
MigrateSession: func(tx *xorm.Session) error {
return tx.Sync(&Person{})
},
RollbackSession: func(tx *xorm.Session) error {
return tx.DropTable(&Person{})
},
},
{
ID: "201608301430",
MigrateSession: func(tx *xorm.Session) error {
return tx.Sync2(&Pet{})
},
RollbackSession: func(tx *xorm.Session) error {
return tx.DropTable(&Pet{})
},
},
}

var extendedMigrations = append(migrations, &Migration{
ID: "201807221927",
Migrate: func(tx *xorm.Engine) error {
Expand Down Expand Up @@ -355,6 +377,42 @@ func TestAllowLong(t *testing.T) {
})
}

func TestMigrationSession(t *testing.T) {
forEachDatabase(t, func(db *xorm.Engine) {
m := New(db, migrationsSession)

err := m.Migrate()
assert.NoError(t, err)
has, err := db.IsTableExist(&Person{})
assert.NoError(t, err)
assert.True(t, has)
has, err = db.IsTableExist(&Pet{})
assert.NoError(t, err)
assert.True(t, has)
assert.Equal(t, int64(2), tableCount(t, db))

err = m.RollbackLast()
assert.NoError(t, err)
has, err = db.IsTableExist(&Person{})
assert.NoError(t, err)
assert.True(t, has)
has, err = db.Exist(&Pet{})
assert.Error(t, err)
assert.False(t, has)
assert.Equal(t, int64(1), tableCount(t, db))

err = m.RollbackLast()
assert.NoError(t, err)
has, err = db.IsTableExist(&Person{})
assert.NoError(t, err)
assert.False(t, has)
has, err = db.IsTableExist(&Pet{})
assert.NoError(t, err)
assert.False(t, has)
assert.Equal(t, int64(0), tableCount(t, db))
})
}

func tableCount(t *testing.T, db *xorm.Engine) (count int64) {
count, err := db.Count(&Migration{})
assert.NoError(t, err)
Expand Down

0 comments on commit 4ba8b09

Please sign in to comment.