From 62230f6c8485aa4ad8aa80bf7dfa134ee1c16507 Mon Sep 17 00:00:00 2001 From: black Date: Mon, 13 Mar 2023 13:39:28 +0800 Subject: [PATCH 1/5] fix: nested scan (#6136) --- callbacks/row.go | 3 +-- finisher_api.go | 4 ++-- gorm.go | 16 ++++++++++++++++ statement.go | 6 ++++++ tests/scopes_test.go | 11 +++++++++++ 5 files changed, 36 insertions(+), 4 deletions(-) diff --git a/callbacks/row.go b/callbacks/row.go index beaa189e1..19510716c 100644 --- a/callbacks/row.go +++ b/callbacks/row.go @@ -11,8 +11,7 @@ func RowQuery(db *gorm.DB) { return } - if isRows, ok := db.Get("rows"); ok && isRows.(bool) { - db.Statement.Settings.Delete("rows") + if isRows := db.PopQueryType(); isRows { db.Statement.Dest, db.Error = db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) } else { db.Statement.Dest = db.Statement.ConnPool.QueryRowContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) diff --git a/finisher_api.go b/finisher_api.go index e6fe46663..0fa7d79f6 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -499,7 +499,7 @@ func (db *DB) Count(count *int64) (tx *DB) { } func (db *DB) Row() *sql.Row { - tx := db.getInstance().Set("rows", false) + tx := db.getInstance().PushQueryType(false) tx = tx.callbacks.Row().Execute(tx) row, ok := tx.Statement.Dest.(*sql.Row) if !ok && tx.DryRun { @@ -509,7 +509,7 @@ func (db *DB) Row() *sql.Row { } func (db *DB) Rows() (*sql.Rows, error) { - tx := db.getInstance().Set("rows", true) + tx := db.getInstance().PushQueryType(true) tx = tx.callbacks.Row().Execute(tx) rows, ok := tx.Statement.Dest.(*sql.Rows) if !ok && tx.DryRun && tx.Error == nil { diff --git a/gorm.go b/gorm.go index 9a70c3d21..25502b664 100644 --- a/gorm.go +++ b/gorm.go @@ -340,6 +340,22 @@ func (db *DB) InstanceGet(key string) (interface{}, bool) { return db.Statement.Settings.Load(fmt.Sprintf("%p", db.Statement) + key) } +func (db *DB) PushQueryType(rows bool) *DB { + tx := db.getInstance() + tx.Statement.queryTypes = append(tx.Statement.queryTypes, rows) + return tx +} + +func (db *DB) PopQueryType() bool { + length := len(db.Statement.queryTypes) + if length == 0 { + return false + } + value := db.Statement.queryTypes[length-1] + db.Statement.queryTypes = db.Statement.queryTypes[:length-1] + return value +} + // Callback returns callback manager func (db *DB) Callback() *callbacks { return db.callbacks diff --git a/statement.go b/statement.go index bc959f0b6..095a3c136 100644 --- a/statement.go +++ b/statement.go @@ -46,6 +46,7 @@ type Statement struct { attrs []interface{} assigns []interface{} scopes []func(*DB) *DB + queryTypes []bool } type join struct { @@ -543,6 +544,11 @@ func (stmt *Statement) clone() *Statement { copy(newStmt.scopes, stmt.scopes) } + if len(stmt.queryTypes) > 0 { + newStmt.queryTypes = make([]bool, len(stmt.queryTypes)) + copy(newStmt.queryTypes, stmt.queryTypes) + } + stmt.Settings.Range(func(k, v interface{}) bool { newStmt.Settings.Store(k, v) return true diff --git a/tests/scopes_test.go b/tests/scopes_test.go index ab3807ea2..61f4ef3c8 100644 --- a/tests/scopes_test.go +++ b/tests/scopes_test.go @@ -71,4 +71,15 @@ func TestScopes(t *testing.T) { if err := DB.Scopes(userTable).Select("max(id)").Scan(&maxId).Error; err != nil { t.Errorf("select max(id)") } + + var user User + if err := DB.Scopes(func(db *gorm.DB) *gorm.DB { + var maxID int64 + if err := db.Raw("select max(id) from users").Scan(&maxID).Error; err != nil { + return db + } + return db.Raw("select * from users where id = ?", maxID) + }).Scan(&user).Error; err != nil { + t.Errorf("failed to find user, got err: %v", err) + } } From aec9023a104189539996c23521cfffc66200f367 Mon Sep 17 00:00:00 2001 From: black Date: Mon, 13 Mar 2023 16:34:50 +0800 Subject: [PATCH 2/5] add mutex --- callbacks/row.go | 2 +- finisher_api.go | 6 ++++-- gorm.go | 16 ---------------- statement.go | 44 +++++++++++++++++++++++++++++++++++++++----- tests/scopes_test.go | 10 ++++++++++ 5 files changed, 54 insertions(+), 24 deletions(-) diff --git a/callbacks/row.go b/callbacks/row.go index 19510716c..77c93e78f 100644 --- a/callbacks/row.go +++ b/callbacks/row.go @@ -11,7 +11,7 @@ func RowQuery(db *gorm.DB) { return } - if isRows := db.PopQueryType(); isRows { + if isRows := db.Statement.QueryTypes.Pop(); isRows { db.Statement.Dest, db.Error = db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) } else { db.Statement.Dest = db.Statement.ConnPool.QueryRowContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) diff --git a/finisher_api.go b/finisher_api.go index 0fa7d79f6..935a0268e 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -499,7 +499,8 @@ func (db *DB) Count(count *int64) (tx *DB) { } func (db *DB) Row() *sql.Row { - tx := db.getInstance().PushQueryType(false) + tx := db.getInstance() + tx.Statement.QueryTypes.Push(false) tx = tx.callbacks.Row().Execute(tx) row, ok := tx.Statement.Dest.(*sql.Row) if !ok && tx.DryRun { @@ -509,7 +510,8 @@ func (db *DB) Row() *sql.Row { } func (db *DB) Rows() (*sql.Rows, error) { - tx := db.getInstance().PushQueryType(true) + tx := db.getInstance() + tx.Statement.QueryTypes.Push(true) tx = tx.callbacks.Row().Execute(tx) rows, ok := tx.Statement.Dest.(*sql.Rows) if !ok && tx.DryRun && tx.Error == nil { diff --git a/gorm.go b/gorm.go index 25502b664..9a70c3d21 100644 --- a/gorm.go +++ b/gorm.go @@ -340,22 +340,6 @@ func (db *DB) InstanceGet(key string) (interface{}, bool) { return db.Statement.Settings.Load(fmt.Sprintf("%p", db.Statement) + key) } -func (db *DB) PushQueryType(rows bool) *DB { - tx := db.getInstance() - tx.Statement.queryTypes = append(tx.Statement.queryTypes, rows) - return tx -} - -func (db *DB) PopQueryType() bool { - length := len(db.Statement.queryTypes) - if length == 0 { - return false - } - value := db.Statement.queryTypes[length-1] - db.Statement.queryTypes = db.Statement.queryTypes[:length-1] - return value -} - // Callback returns callback manager func (db *DB) Callback() *callbacks { return db.callbacks diff --git a/statement.go b/statement.go index 095a3c136..497cfedcc 100644 --- a/statement.go +++ b/statement.go @@ -34,6 +34,7 @@ type Statement struct { Omits []string // omit columns Joins []join Preloads map[string][]interface{} + QueryTypes QueryTypes Settings sync.Map ConnPool ConnPool Schema *schema.Schema @@ -46,7 +47,6 @@ type Statement struct { attrs []interface{} assigns []interface{} scopes []func(*DB) *DB - queryTypes []bool } type join struct { @@ -58,6 +58,43 @@ type join struct { JoinType clause.JoinType } +type QueryTypes struct { + mux sync.Mutex + values []bool +} + +func (q *QueryTypes) Push(isRows bool) { + q.mux.Lock() + defer q.mux.Unlock() + q.values = append(q.values, isRows) +} + +func (q *QueryTypes) Pop() bool { + q.mux.Lock() + defer q.mux.Unlock() + + if len(q.values) == 0 { + return false + } + + value := q.values[len(q.values)-1] + q.values = q.values[:len(q.values)-1] + return value +} + +func (q *QueryTypes) clone() QueryTypes { + q.mux.Lock() + defer q.mux.Unlock() + + if len(q.values) == 0 { + return QueryTypes{} + } + + values := make([]bool, len(q.values)) + copy(values, q.values) + return QueryTypes{values: values} +} + // StatementModifier statement modifier interface type StatementModifier interface { ModifyStatement(*Statement) @@ -544,10 +581,7 @@ func (stmt *Statement) clone() *Statement { copy(newStmt.scopes, stmt.scopes) } - if len(stmt.queryTypes) > 0 { - newStmt.queryTypes = make([]bool, len(stmt.queryTypes)) - copy(newStmt.queryTypes, stmt.queryTypes) - } + newStmt.QueryTypes = stmt.QueryTypes.clone() stmt.Settings.Range(func(k, v interface{}) bool { newStmt.Settings.Store(k, v) diff --git a/tests/scopes_test.go b/tests/scopes_test.go index 61f4ef3c8..25257918e 100644 --- a/tests/scopes_test.go +++ b/tests/scopes_test.go @@ -82,4 +82,14 @@ func TestScopes(t *testing.T) { }).Scan(&user).Error; err != nil { t.Errorf("failed to find user, got err: %v", err) } + + if err := DB.Scopes(func(db *gorm.DB) *gorm.DB { + var maxID int64 + if err := db.Model(&User{}).Select("max(id)").Scan(&maxID).Error; err != nil { + return db + } + return db.Where("id = ?", maxID) + }).Scan(&user).Error; err != nil { + t.Errorf("failed to find user, got err: %v", err) + } } From 5727808a2074755c0337bbdb642d9fdecadf4a8d Mon Sep 17 00:00:00 2001 From: black Date: Mon, 13 Mar 2023 16:49:52 +0800 Subject: [PATCH 3/5] fix test --- tests/scopes_test.go | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/tests/scopes_test.go b/tests/scopes_test.go index 25257918e..61f4ef3c8 100644 --- a/tests/scopes_test.go +++ b/tests/scopes_test.go @@ -82,14 +82,4 @@ func TestScopes(t *testing.T) { }).Scan(&user).Error; err != nil { t.Errorf("failed to find user, got err: %v", err) } - - if err := DB.Scopes(func(db *gorm.DB) *gorm.DB { - var maxID int64 - if err := db.Model(&User{}).Select("max(id)").Scan(&maxID).Error; err != nil { - return db - } - return db.Where("id = ?", maxID) - }).Scan(&user).Error; err != nil { - t.Errorf("failed to find user, got err: %v", err) - } } From 92a360708d0f201857a801bf9f5848be827e7877 Mon Sep 17 00:00:00 2001 From: black Date: Mon, 13 Mar 2023 17:04:55 +0800 Subject: [PATCH 4/5] use container/list to prevent memory leaks --- statement.go | 32 ++++++++++++++++++++------------ statement_test.go | 20 ++++++++++++++++++++ 2 files changed, 40 insertions(+), 12 deletions(-) diff --git a/statement.go b/statement.go index 497cfedcc..25029c485 100644 --- a/statement.go +++ b/statement.go @@ -1,6 +1,7 @@ package gorm import ( + "container/list" "context" "database/sql" "database/sql/driver" @@ -59,40 +60,47 @@ type join struct { } type QueryTypes struct { - mux sync.Mutex - values []bool + mux sync.Mutex + list *list.List } func (q *QueryTypes) Push(isRows bool) { q.mux.Lock() defer q.mux.Unlock() - q.values = append(q.values, isRows) + if q.list == nil { + q.list = list.New() + } + q.list.PushBack(isRows) } func (q *QueryTypes) Pop() bool { q.mux.Lock() defer q.mux.Unlock() - if len(q.values) == 0 { + if q.list == nil { return false } - - value := q.values[len(q.values)-1] - q.values = q.values[:len(q.values)-1] - return value + element := q.list.Back() + if element == nil { + return false + } + q.list.Remove(element) + return element.Value.(bool) } func (q *QueryTypes) clone() QueryTypes { q.mux.Lock() defer q.mux.Unlock() - if len(q.values) == 0 { + if q.list == nil { return QueryTypes{} } - values := make([]bool, len(q.values)) - copy(values, q.values) - return QueryTypes{values: values} + cloneList := list.New() + for e := q.list.Front(); e != nil; e = e.Next() { + cloneList.PushFront(e.Value) + } + return QueryTypes{list: cloneList} } // StatementModifier statement modifier interface diff --git a/statement_test.go b/statement_test.go index 648bc875d..a6b5f1c50 100644 --- a/statement_test.go +++ b/statement_test.go @@ -62,3 +62,23 @@ func TestNameMatcher(t *testing.T) { } } } + +func TestQueryTypes(t *testing.T) { + types := QueryTypes{} + values := []bool{true, false, false, true} + for _, value := range values { + types.Push(value) + } + + clone := types.clone() + for _, value := range values { + actual := clone.Pop() + if actual != value { + t.Errorf("failed to pop, got %v, expect %v", actual, value) + } + } + + if clone.list.Len() != 0 || clone.Pop() { + t.Errorf("clone list should be empty") + } +} From 672c48b74c5099c5a5f80ea7ffed2289d5508ec6 Mon Sep 17 00:00:00 2001 From: black Date: Mon, 13 Mar 2023 17:33:57 +0800 Subject: [PATCH 5/5] avoid adding attributes --- callbacks/row.go | 2 +- finisher_api.go | 10 ++++++++-- statement.go | 19 ++++++++++++------- statement_test.go | 4 ++-- 4 files changed, 23 insertions(+), 12 deletions(-) diff --git a/callbacks/row.go b/callbacks/row.go index 77c93e78f..5893ee2a2 100644 --- a/callbacks/row.go +++ b/callbacks/row.go @@ -11,7 +11,7 @@ func RowQuery(db *gorm.DB) { return } - if isRows := db.Statement.QueryTypes.Pop(); isRows { + if types, ok := db.Statement.Settings.Load("rows"); ok && types.(*gorm.QueryTypes).Pop() { db.Statement.Dest, db.Error = db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) } else { db.Statement.Dest = db.Statement.ConnPool.QueryRowContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) diff --git a/finisher_api.go b/finisher_api.go index 935a0268e..62e6523fa 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -500,7 +500,10 @@ func (db *DB) Count(count *int64) (tx *DB) { func (db *DB) Row() *sql.Row { tx := db.getInstance() - tx.Statement.QueryTypes.Push(false) + + value, _ := tx.Statement.Settings.LoadOrStore("rows", &QueryTypes{}) + value.(*QueryTypes).Push(false) + tx = tx.callbacks.Row().Execute(tx) row, ok := tx.Statement.Dest.(*sql.Row) if !ok && tx.DryRun { @@ -511,7 +514,10 @@ func (db *DB) Row() *sql.Row { func (db *DB) Rows() (*sql.Rows, error) { tx := db.getInstance() - tx.Statement.QueryTypes.Push(true) + + value, _ := tx.Statement.Settings.LoadOrStore("rows", &QueryTypes{}) + value.(*QueryTypes).Push(true) + tx = tx.callbacks.Row().Execute(tx) rows, ok := tx.Statement.Dest.(*sql.Rows) if !ok && tx.DryRun && tx.Error == nil { diff --git a/statement.go b/statement.go index 25029c485..7f1954b9c 100644 --- a/statement.go +++ b/statement.go @@ -35,7 +35,6 @@ type Statement struct { Omits []string // omit columns Joins []join Preloads map[string][]interface{} - QueryTypes QueryTypes Settings sync.Map ConnPool ConnPool Schema *schema.Schema @@ -88,19 +87,19 @@ func (q *QueryTypes) Pop() bool { return element.Value.(bool) } -func (q *QueryTypes) clone() QueryTypes { +func (q *QueryTypes) Clone() interface{} { q.mux.Lock() defer q.mux.Unlock() if q.list == nil { - return QueryTypes{} + return &QueryTypes{} } cloneList := list.New() for e := q.list.Front(); e != nil; e = e.Next() { cloneList.PushFront(e.Value) } - return QueryTypes{list: cloneList} + return &QueryTypes{list: cloneList} } // StatementModifier statement modifier interface @@ -589,16 +588,22 @@ func (stmt *Statement) clone() *Statement { copy(newStmt.scopes, stmt.scopes) } - newStmt.QueryTypes = stmt.QueryTypes.clone() - stmt.Settings.Range(func(k, v interface{}) bool { - newStmt.Settings.Store(k, v) + if cloneable, ok := v.(Cloneable); ok { + newStmt.Settings.Store(k, cloneable.Clone()) + } else { + newStmt.Settings.Store(k, v) + } return true }) return newStmt } +type Cloneable interface { + Clone() interface{} +} + // SetColumn set column's value // // stmt.SetColumn("Name", "jinzhu") // Hooks Method diff --git a/statement_test.go b/statement_test.go index a6b5f1c50..84ccc2a56 100644 --- a/statement_test.go +++ b/statement_test.go @@ -64,13 +64,13 @@ func TestNameMatcher(t *testing.T) { } func TestQueryTypes(t *testing.T) { - types := QueryTypes{} + types := &QueryTypes{} values := []bool{true, false, false, true} for _, value := range values { types.Push(value) } - clone := types.clone() + clone := types.Clone().(*QueryTypes) for _, value := range values { actual := clone.Pop() if actual != value {