From ab5dc522f434ea8ff552d370ec931e68c434d97f Mon Sep 17 00:00:00 2001 From: dabasov Date: Fri, 4 Aug 2023 15:53:50 +0300 Subject: [PATCH] cleaned up transaction logic --- .../0chain.net/blobbercore/allocation/dao.go | 8 +- .../blobbercore/allocation/workers.go | 10 -- .../automigration/automigration.go | 124 ------------------ .../blobbercore/datastore/mocket.go | 16 +++ .../blobbercore/datastore/postgres.go | 15 +++ .../blobbercore/datastore/sqlmock.go | 16 +++ .../0chain.net/blobbercore/datastore/store.go | 1 + .../0chain.net/blobbercore/handler/context.go | 59 +++++---- .../0chain.net/blobbercore/handler/handler.go | 43 +++--- .../handler/handler_hashnode_test.go | 2 +- .../handler/handler_playlist_test.go | 4 +- .../handler/handler_writemarker_test.go | 4 +- .../blobbercore/reference/referencepath.go | 37 ++++-- 13 files changed, 122 insertions(+), 217 deletions(-) delete mode 100644 code/go/0chain.net/blobbercore/automigration/automigration.go diff --git a/code/go/0chain.net/blobbercore/allocation/dao.go b/code/go/0chain.net/blobbercore/allocation/dao.go index 68e6dc10d..a62005e4f 100644 --- a/code/go/0chain.net/blobbercore/allocation/dao.go +++ b/code/go/0chain.net/blobbercore/allocation/dao.go @@ -3,7 +3,6 @@ package allocation import ( "context" - "github.com/0chain/blobber/code/go/0chain.net/blobbercore/datastore" "github.com/0chain/blobber/code/go/0chain.net/core/common" "github.com/0chain/errors" "github.com/0chain/gosdk/constants" @@ -12,16 +11,11 @@ import ( // GetOrCreate, get allocation if it exists in db. if not, try to sync it from blockchain, and insert it in db. func GetOrCreate(ctx context.Context, allocationId string) (*Allocation, error) { - - db := datastore.GetStore().CreateTransaction(ctx) - tx := datastore.GetStore().GetTransaction(ctx) - if len(allocationId) == 0 { return nil, errors.Throw(constants.ErrInvalidParameter, "tx") } - alloc, err := Repo.GetById(db, allocationId) - tx.Rollback() + alloc, err := Repo.GetById(ctx, allocationId) if err == nil { return alloc, nil diff --git a/code/go/0chain.net/blobbercore/allocation/workers.go b/code/go/0chain.net/blobbercore/allocation/workers.go index c2ac4fee6..44306b911 100644 --- a/code/go/0chain.net/blobbercore/allocation/workers.go +++ b/code/go/0chain.net/blobbercore/allocation/workers.go @@ -13,8 +13,6 @@ import ( "github.com/0chain/blobber/code/go/0chain.net/core/logging" "github.com/0chain/blobber/code/go/0chain.net/core/transaction" - "gorm.io/gorm" - "go.uber.org/zap" ) @@ -169,14 +167,6 @@ func requestAllocation(allocID string) (sa *transaction.StorageAllocation, err e return } -func commit(tx *gorm.DB, err *error) { - if (*err) != nil { - tx.Rollback() - return - } - (*err) = tx.Commit().Error -} - func updateAllocationInDB(ctx context.Context, a *Allocation, sa *transaction.StorageAllocation) (ua *Allocation, err error) { var tx = datastore.GetStore().GetTransaction(ctx) var changed bool = a.Tx != sa.Tx diff --git a/code/go/0chain.net/blobbercore/automigration/automigration.go b/code/go/0chain.net/blobbercore/automigration/automigration.go deleted file mode 100644 index 7e76a62b0..000000000 --- a/code/go/0chain.net/blobbercore/automigration/automigration.go +++ /dev/null @@ -1,124 +0,0 @@ -//This file is used to create table schemas using gorm's automigration feature which takes information from -//struct's fields and functions - -package automigration - -import ( - "fmt" - - "github.com/0chain/blobber/code/go/0chain.net/blobbercore/allocation" - "github.com/0chain/blobber/code/go/0chain.net/blobbercore/challenge" - "github.com/0chain/blobber/code/go/0chain.net/blobbercore/config" - "github.com/0chain/blobber/code/go/0chain.net/blobbercore/readmarker" - "github.com/0chain/blobber/code/go/0chain.net/blobbercore/reference" - "github.com/0chain/blobber/code/go/0chain.net/blobbercore/writemarker" - "gorm.io/gorm" -) - -type tableNameI interface { - TableName() string -} - -var tableModels = []tableNameI{ - new(reference.Ref), - new(reference.ShareInfo), - new(challenge.ChallengeEntity), - new(challenge.ChallengeTiming), - new(allocation.Allocation), - new(allocation.AllocationChange), - new(allocation.AllocationChangeCollector), - new(allocation.Pending), - new(allocation.Terms), - new(allocation.ReadPool), - new(allocation.WritePool), - new(readmarker.ReadMarkerEntity), - new(writemarker.WriteMarkerEntity), - new(writemarker.WriteLock), - new(reference.FileStats), - new(config.Settings), -} - -func createDB(db *gorm.DB) (err error) { - // check if db exists - dbstmt := fmt.Sprintf("SELECT datname, oid FROM pg_database WHERE datname = '%s';", config.Configuration.DBName) - rs := db.Raw(dbstmt) - if rs.Error != nil { - return rs.Error - } - - var result struct { - Datname string - } - - if rs.Scan(&result); len(result.Datname) == 0 { - stmt := fmt.Sprintf("CREATE DATABASE %s;", config.Configuration.DBName) - if rs := db.Exec(stmt); rs.Error != nil { - return rs.Error - } - if rs := db.Exec("CREATE EXTENSION IF NOT EXISTS pg_trgm;"); rs.Error != nil { - return rs.Error - } - } - return -} - -func createUser(db *gorm.DB) error { - usrstmt := fmt.Sprintf("SELECT usename, usesysid FROM pg_catalog.pg_user WHERE usename = '%s';", config.Configuration.DBUserName) - rs := db.Raw(usrstmt) - if rs.Error != nil { - return rs.Error - } - - var result struct { - Usename string - } - - if rs.Scan(&result); len(result.Usename) == 0 { - stmt := fmt.Sprintf("CREATE USER %s WITH ENCRYPTED PASSWORD '%s';", config.Configuration.DBUserName, config.Configuration.DBPassword) - if rs := db.Exec(stmt); rs.Error != nil && rs.Error.Error() != fmt.Sprintf("pq: role \"%s\" already exists", config.Configuration.DBUserName) { - return rs.Error - } - } - return nil -} - -func grantPrivileges(db *gorm.DB) error { - stmts := []string{ - fmt.Sprintf("GRANT ALL PRIVILEGES ON DATABASE %s TO %s;", config.Configuration.DBName, config.Configuration.DBUserName), - fmt.Sprintf("GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO %s;", config.Configuration.DBUserName), - fmt.Sprintf("GRANT ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public TO %s;", config.Configuration.DBUserName), - } - for _, stmt := range stmts { - err := db.Exec(stmt).Error - if err != nil { - return err - } - } - return nil -} - -func MigrateSchema(db *gorm.DB) error { - var tables []interface{} // Put in new slice to resolve type mismatch - for _, tbl := range tableModels { - tables = append(tables, tbl) - } - - if err := db.AutoMigrate(tables...); err != nil { - return err - } - err := db.Exec(`ALTER TABLE reference_objects ALTER COLUMN path TYPE varchar(1000) COLLATE "POSIX"`).Error - if err != nil { - return err - } - return nil -} - -// DropSchemas is used for integration tests to clear DB. -func DropSchemas(db *gorm.DB) error { - var tables []interface{} // Put in new slice to resolve type mismatch - for _, tbl := range tableModels { - tables = append(tables, tbl) - } - - return db.Migrator().DropTable(tables...) -} diff --git a/code/go/0chain.net/blobbercore/datastore/mocket.go b/code/go/0chain.net/blobbercore/datastore/mocket.go index 56d25151e..e5f0ef47d 100644 --- a/code/go/0chain.net/blobbercore/datastore/mocket.go +++ b/code/go/0chain.net/blobbercore/datastore/mocket.go @@ -99,6 +99,22 @@ func (store *Mocket) WithNewTransaction(f func(ctx context.Context) error) error return nil } +func (store *Mocket) WithTransaction(ctx context.Context, f func(ctx context.Context) error) error { + tx := store.GetTransaction(ctx) + if tx == nil { + ctx = store.CreateTransaction(ctx) + tx = store.GetTransaction(ctx) + } + + err := f(ctx) + if err != nil { + tx.Rollback() + return err + } + tx.Commit() + return nil +} + func (store *Mocket) GetDB() *gorm.DB { return store.db } diff --git a/code/go/0chain.net/blobbercore/datastore/postgres.go b/code/go/0chain.net/blobbercore/datastore/postgres.go index 48668c404..2bfcde8e3 100644 --- a/code/go/0chain.net/blobbercore/datastore/postgres.go +++ b/code/go/0chain.net/blobbercore/datastore/postgres.go @@ -116,6 +116,21 @@ func (store *postgresStore) WithNewTransaction(f func(ctx context.Context) error tx.Commit() return nil } +func (store *postgresStore) WithTransaction(ctx context.Context, f func(ctx context.Context) error) error { + tx := store.GetTransaction(ctx) + if tx == nil { + ctx = store.CreateTransaction(ctx) + tx = store.GetTransaction(ctx) + } + + err := f(ctx) + if err != nil { + tx.Rollback() + return err + } + tx.Commit() + return nil +} func (store *postgresStore) GetDB() *gorm.DB { return store.db diff --git a/code/go/0chain.net/blobbercore/datastore/sqlmock.go b/code/go/0chain.net/blobbercore/datastore/sqlmock.go index 25096d114..b528ec73b 100644 --- a/code/go/0chain.net/blobbercore/datastore/sqlmock.go +++ b/code/go/0chain.net/blobbercore/datastore/sqlmock.go @@ -94,6 +94,22 @@ func (store *Sqlmock) WithNewTransaction(f func(ctx context.Context) error) erro return nil } +func (store *Sqlmock) WithTransaction(ctx context.Context, f func(ctx context.Context) error) error { + tx := store.GetTransaction(ctx) + if tx == nil { + ctx = store.CreateTransaction(ctx) + tx = store.GetTransaction(ctx) + } + + err := f(ctx) + if err != nil { + tx.Rollback() + return err + } + tx.Commit() + return nil +} + func (store *Sqlmock) GetDB() *gorm.DB { return store.db } diff --git a/code/go/0chain.net/blobbercore/datastore/store.go b/code/go/0chain.net/blobbercore/datastore/store.go index 384911861..866ace2d4 100644 --- a/code/go/0chain.net/blobbercore/datastore/store.go +++ b/code/go/0chain.net/blobbercore/datastore/store.go @@ -32,6 +32,7 @@ type Store interface { // GetTransaction get transaction from context GetTransaction(ctx context.Context) *EnhancedDB WithNewTransaction(f func(ctx context.Context) error) error + WithTransaction(ctx context.Context, f func(ctx context.Context) error) error // Get db connection with user that creates roles and databases. Its dialactor does not contain database name GetPgDB() (*gorm.DB, error) Open() error diff --git a/code/go/0chain.net/blobbercore/handler/context.go b/code/go/0chain.net/blobbercore/handler/context.go index 43c7a4bdf..ead3df7b8 100644 --- a/code/go/0chain.net/blobbercore/handler/context.go +++ b/code/go/0chain.net/blobbercore/handler/context.go @@ -130,8 +130,8 @@ type ErrorResponse struct { Error string } -// WithHandler process handler to respond request -func WithHandler(handler func(ctx *Context) (interface{}, error)) func(w http.ResponseWriter, r *http.Request) { +// WithTxHandler process handler to respond request +func WithTxHandler(handler func(ctx *Context) (interface{}, error)) func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Access-Control-Allow-Origin", "*") // CORS for all. w.Header().Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE") @@ -143,23 +143,39 @@ func WithHandler(handler func(ctx *Context) (interface{}, error)) func(w http.Re } common.TryParseForm(r) - w.Header().Set("Content-Type", "application/json") - ctx, err := WithVerify(r) - statusCode := ctx.StatusCode + statusCode := 0 + var result interface{} + err := datastore.GetStore().WithNewTransaction(func(c context.Context) error { + ctx := &Context{ + Context: c, + Request: r, + Store: datastore.GetStore(), + } - if err != nil { - if statusCode == 0 { - statusCode = http.StatusInternalServerError + ctx.Vars = mux.Vars(r) + if ctx.Vars == nil { + ctx.Vars = make(map[string]string) } - http.Error(w, err.Error(), statusCode) - return - } + ctx.ClientID = r.Header.Get(common.ClientHeader) + ctx.ClientKey = r.Header.Get(common.ClientKeyHeader) + ctx.AllocationId = r.Header.Get(common.AllocationIdHeader) + ctx.Signature = r.Header.Get(common.ClientSignatureHeader) - result, err := handler(ctx) - statusCode = ctx.StatusCode + ctx, err := WithVerify(ctx, r) + statusCode = ctx.StatusCode + + if err != nil { + return err + } + + result, err = handler(ctx) + statusCode = ctx.StatusCode + + return nil + }) if err != nil { if statusCode == 0 { @@ -183,23 +199,8 @@ func WithHandler(handler func(ctx *Context) (interface{}, error)) func(w http.Re } // WithVerify verify allocation and signature -func WithVerify(r *http.Request) (*Context, error) { - - ctx := &Context{ - Context: context.TODO(), - Request: r, - Store: datastore.GetStore(), - } - - ctx.Vars = mux.Vars(r) - if ctx.Vars == nil { - ctx.Vars = make(map[string]string) - } +func WithVerify(ctx *Context, r *http.Request) (*Context, error) { - ctx.ClientID = r.Header.Get(common.ClientHeader) - ctx.ClientKey = r.Header.Get(common.ClientKeyHeader) - ctx.AllocationId = r.Header.Get(common.AllocationIdHeader) - ctx.Signature = r.Header.Get(common.ClientSignatureHeader) allocationTx := ctx.Vars["allocation"] if len(ctx.AllocationId) > 0 { diff --git a/code/go/0chain.net/blobbercore/handler/handler.go b/code/go/0chain.net/blobbercore/handler/handler.go index 4ec30595a..7ad981eec 100644 --- a/code/go/0chain.net/blobbercore/handler/handler.go +++ b/code/go/0chain.net/blobbercore/handler/handler.go @@ -248,23 +248,23 @@ func setupHandlers(r *mux.Router) { // lightweight http handler without heavy postgres transaction to improve performance r.HandleFunc("/v1/writemarker/lock/{allocation}", - RateLimitByGeneralRL(WithHandler(LockWriteMarker))). + RateLimitByGeneralRL(WithTxHandler(LockWriteMarker))). Methods(http.MethodPost, http.MethodOptions) r.HandleFunc("/v1/writemarker/lock/{allocation}/{connection}", - RateLimitByGeneralRL(WithHandler(UnlockWriteMarker))). + RateLimitByGeneralRL(WithTxHandler(UnlockWriteMarker))). Methods(http.MethodDelete, http.MethodOptions) r.HandleFunc("/v1/hashnode/root/{allocation}", - RateLimitByObjectRL(WithHandler(LoadRootHashnode))). + RateLimitByObjectRL(WithTxHandler(LoadRootHashnode))). Methods(http.MethodGet, http.MethodOptions) r.HandleFunc("/v1/playlist/latest/{allocation}", - RateLimitByGeneralRL(WithHandler(LoadPlaylist))). + RateLimitByGeneralRL(WithTxHandler(LoadPlaylist))). Methods(http.MethodGet, http.MethodOptions) r.HandleFunc("/v1/playlist/file/{allocation}", - RateLimitByGeneralRL(WithHandler(LoadPlaylistFile))). + RateLimitByGeneralRL(WithTxHandler(LoadPlaylistFile))). Methods(http.MethodGet, http.MethodOptions) } @@ -281,31 +281,18 @@ func WithReadOnlyConnection(handler common.JSONResponderF) common.JSONResponderF } func WithConnection(handler common.JSONResponderF) common.JSONResponderF { - return func(ctx context.Context, r *http.Request) (resp interface{}, err error) { - ctx = GetMetaDataStore().CreateTransaction(ctx) - - resp, err = handler(ctx, r) + return func(ctx context.Context, r *http.Request) (interface{}, error) { + var ( + resp interface{} + err error + ) + err = datastore.GetStore().WithNewTransaction(func(ctx context.Context) error { + resp, err = handler(ctx, r) - defer func() { - if err != nil { - var rollErr = GetMetaDataStore().GetTransaction(ctx). - Rollback().Error - if rollErr != nil { - Logger.Error("couldn't rollback", zap.Error(err)) - } - } - }() + return err + }) - if err != nil { - Logger.Error("Error in handling the request." + err.Error()) - return - } - err = GetMetaDataStore().GetTransaction(ctx).Commit().Error - if err != nil { - return resp, common.NewErrorf("commit_error", - "error committing to meta store: %v", err) - } - return + return resp, err } } diff --git a/code/go/0chain.net/blobbercore/handler/handler_hashnode_test.go b/code/go/0chain.net/blobbercore/handler/handler_hashnode_test.go index c6b7d0279..a44204ad5 100644 --- a/code/go/0chain.net/blobbercore/handler/handler_hashnode_test.go +++ b/code/go/0chain.net/blobbercore/handler/handler_hashnode_test.go @@ -85,7 +85,7 @@ FROM reference_objects`). } rr := httptest.NewRecorder() - handler := http.HandlerFunc(WithHandler(func(ctx *Context) (interface{}, error) { + handler := http.HandlerFunc(WithTxHandler(func(ctx *Context) (interface{}, error) { ctx.AllocationId = "allocation_handler_load_root" return LoadRootHashnode(ctx) })) diff --git a/code/go/0chain.net/blobbercore/handler/handler_playlist_test.go b/code/go/0chain.net/blobbercore/handler/handler_playlist_test.go index adcb4700c..f79bffb27 100644 --- a/code/go/0chain.net/blobbercore/handler/handler_playlist_test.go +++ b/code/go/0chain.net/blobbercore/handler/handler_playlist_test.go @@ -53,7 +53,7 @@ func TestPlaylist_LoadPlaylist(t *testing.T) { } rr := httptest.NewRecorder() - handler := http.HandlerFunc(WithHandler(func(ctx *Context) (interface{}, error) { + handler := http.HandlerFunc(WithTxHandler(func(ctx *Context) (interface{}, error) { ctx.AllocationId = "AllocationId" ctx.ClientID = "ownerid" ctx.Allocation = &allocation.Allocation{ @@ -123,7 +123,7 @@ func TestPlaylist_LoadPlaylistFile(t *testing.T) { } rr := httptest.NewRecorder() - handler := http.HandlerFunc(WithHandler(func(ctx *Context) (interface{}, error) { + handler := http.HandlerFunc(WithTxHandler(func(ctx *Context) (interface{}, error) { ctx.AllocationId = "AllocationId" ctx.ClientID = "ownerid" ctx.Allocation = &allocation.Allocation{ diff --git a/code/go/0chain.net/blobbercore/handler/handler_writemarker_test.go b/code/go/0chain.net/blobbercore/handler/handler_writemarker_test.go index 69e7b1544..9e4232aa5 100644 --- a/code/go/0chain.net/blobbercore/handler/handler_writemarker_test.go +++ b/code/go/0chain.net/blobbercore/handler/handler_writemarker_test.go @@ -41,7 +41,7 @@ func TestWriteMarkerHandlers_Lock(t *testing.T) { req.Header.Set("Content-Type", formWriter.FormDataContentType()) rr := httptest.NewRecorder() - handler := http.HandlerFunc(WithHandler(func(ctx *Context) (interface{}, error) { + handler := http.HandlerFunc(WithTxHandler(func(ctx *Context) (interface{}, error) { ctx.AllocationId = "TestHandlers_Lock_allocation_id" return LockWriteMarker(ctx) })) @@ -81,7 +81,7 @@ func TestWriteMarkerHandlers_Unlock(t *testing.T) { req.Header.Set("Content-Type", formWriter.FormDataContentType()) rr := httptest.NewRecorder() - handler := http.HandlerFunc(WithHandler(func(ctx *Context) (interface{}, error) { + handler := http.HandlerFunc(WithTxHandler(func(ctx *Context) (interface{}, error) { ctx.AllocationId = "TestHandlers_Unlock_allocation_id" ctx.Vars["connection"] = "connection_id" return UnlockWriteMarker(ctx) diff --git a/code/go/0chain.net/blobbercore/reference/referencepath.go b/code/go/0chain.net/blobbercore/reference/referencepath.go index 330ecaaa0..d79499865 100644 --- a/code/go/0chain.net/blobbercore/reference/referencepath.go +++ b/code/go/0chain.net/blobbercore/reference/referencepath.go @@ -335,24 +335,33 @@ func GetUpdatedRefs(ctx context.Context, allocationID, path, offsetPath, _type, return err }) - logging.Logger.Error("error", zap.Error(err)) + if err != nil { + logging.Logger.Error("error", zap.Error(err)) + } }() go func() { - tx := datastore.GetStore().GetTransaction(ctx) - db2 := tx.Model(&Ref{}).Where("allocation_id = ?", allocationID). - Where("path = ?", path).Or("path LIKE ?", path+"%") - if _type != "" { - db2 = db2.Where("type > ?", level) - } - if level != 0 { - db2 = db2.Where("level = ?", level) - } - if updatedDate != "" { - db2 = db2.Where("updated_at > ?", updatedDate) + err := datastore.GetStore().WithNewTransaction(func(ctx context.Context) error { + tx := datastore.GetStore().GetTransaction(ctx) + db2 := tx.Model(&Ref{}).Where("allocation_id = ?", allocationID). + Where("path = ?", path).Or("path LIKE ?", path+"%") + if _type != "" { + db2 = db2.Where("type > ?", level) + } + if level != 0 { + db2 = db2.Where("level = ?", level) + } + if updatedDate != "" { + db2 = db2.Where("updated_at > ?", updatedDate) + } + err = db2.Count(&totalRows).Error + wg.Done() + + return err + }) + if err != nil { + logging.Logger.Error("error", zap.Error(err)) } - db2 = db2.Count(&totalRows) - wg.Done() }() wg.Wait() if err != nil {