diff --git a/.github/ISSUE_TEMPLATE.md b/.github/ISSUE_TEMPLATE.md index 8b4f03b77..a0b64bfa7 100644 --- a/.github/ISSUE_TEMPLATE.md +++ b/.github/ISSUE_TEMPLATE.md @@ -1,10 +1,4 @@ -Before posting a bug report about a problem, please try to verify that it is a bug and that it has not been reported already, please apply corresponding GitHub labels to the issue, for feature requests, please apply `type:feature`. - -DON'T post usage related questions, ask in https://gitter.im/jinzhu/gorm or http://stackoverflow.com/questions/tagged/go-gorm, - -Please answer these questions before submitting your issue. Thanks! - - +Your issue may already be reported! Please search on the [issue track](https://github.com/jinzhu/gorm/issues) before creating one. ### What version of Go are you using (`go version`)? @@ -12,9 +6,9 @@ Please answer these questions before submitting your issue. Thanks! ### Which database and its version are you using? -### What did you do? +### Please provide a complete runnable program to reproduce your issue. **IMPORTANT** -Please provide a complete runnable program to reproduce your issue. +Need to runnable with [GORM's docker compose config](https://github.com/jinzhu/gorm/blob/master/docker-compose.yml) or please provides your config. ```go package main @@ -32,10 +26,9 @@ var db *gorm.DB func init() { var err error db, err = gorm.Open("sqlite3", "test.db") - // Please use below username, password as your database's account for the script. - // db, err = gorm.Open("postgres", "user=gorm dbname=gorm sslmode=disable") - // db, err = gorm.Open("mysql", "gorm:gorm@/dbname?charset=utf8&parseTime=True") - // db, err = gorm.Open("mssql", "sqlserver://gorm:LoremIpsum86@localhost:1433?database=gorm") + // db, err = gorm.Open("postgres", "user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable") + // db, err = gorm.Open("mysql", "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True") + // db, err = gorm.Open("mssql", "sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm") if err != nil { panic(err) } @@ -43,8 +36,6 @@ func init() { } func main() { - // your code here - if /* failure condition */ { fmt.Println("failed") } else { diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 0ee0d73b3..b467b6ce9 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -3,12 +3,7 @@ Make sure these boxes checked before submitting your pull request. - [] Do only one thing - [] No API-breaking changes - [] New code/logic commented & tested -- [] Write good commit message, try to squash your commits into a single one -- [] Run `./build.sh` in `gh-pages` branch for document changes - -For significant changes like big bug fixes, new features, please open an issue to make a agreement on an implementation design/plan first before starting it. - -Thank you. +For significant changes like big bug fixes, new features, please open an issue to make an agreement on an implementation design/plan first before starting it. ### What did this pull request do? diff --git a/README.md b/README.md index e5c21dc5c..0c5c7ea6c 100644 --- a/README.md +++ b/README.md @@ -2,16 +2,19 @@ The fantastic ORM library for Golang, aims to be developer friendly. -[![Join the chat at https://gitter.im/jinzhu/gorm](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/jinzhu/gorm?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) [![go report card](https://goreportcard.com/badge/github.com/jinzhu/gorm "go report card")](https://goreportcard.com/report/github.com/jinzhu/gorm) -[![wercker status](https://app.wercker.com/status/0cb7bb1039e21b74f8274941428e0921/s/master "wercker status")](https://app.wercker.com/project/bykey/0cb7bb1039e21b74f8274941428e0921) +[![wercker status](https://app.wercker.com/status/8596cace912c9947dd9c8542ecc8cb8b/s/master "wercker status")](https://app.wercker.com/project/byKey/8596cace912c9947dd9c8542ecc8cb8b) +[![Join the chat at https://gitter.im/jinzhu/gorm](https://img.shields.io/gitter/room/jinzhu/gorm.svg)](https://gitter.im/jinzhu/gorm?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) +[![Open Collective Backer](https://opencollective.com/gorm/tiers/backer/badge.svg?label=backer&color=brightgreen "Open Collective Backer")](https://opencollective.com/gorm) +[![Open Collective Sponsor](https://opencollective.com/gorm/tiers/sponsor/badge.svg?label=sponsor&color=brightgreen "Open Collective Sponsor")](https://opencollective.com/gorm) +[![MIT license](http://img.shields.io/badge/license-MIT-brightgreen.svg)](http://opensource.org/licenses/MIT) [![GoDoc](https://godoc.org/github.com/jinzhu/gorm?status.svg)](https://godoc.org/github.com/jinzhu/gorm) ## Overview * Full-Featured ORM (almost) * Associations (Has One, Has Many, Belongs To, Many To Many, Polymorphism) -* Callbacks (Before/After Create/Save/Update/Delete/Find) +* Hooks (Before/After Create/Save/Update/Delete/Find) * Preloading (eager loading) * Transactions * Composite Primary Key @@ -24,28 +27,14 @@ The fantastic ORM library for Golang, aims to be developer friendly. ## Getting Started -* GORM Guides [jinzhu.github.com/gorm](http://jinzhu.github.io/gorm) +* GORM Guides [http://gorm.io](http://gorm.io) -## Upgrading To V1.0 +## Contributing -* [CHANGELOG](http://jinzhu.github.io/gorm/changelog.html) - -## Supporting the project - -[![http://patreon.com/jinzhu](https://c5.patreon.com/external/logo/become_a_patron_button.png)](http://patreon.com/jinzhu) - -## Author - -**jinzhu** - -* -* -* - -## Contributors - -https://github.com/jinzhu/gorm/graphs/contributors +[You can help to deliver a better GORM, check out things you can do](http://gorm.io/contribute.html) ## License -Released under the [MIT License](https://github.com/jinzhu/gorm/blob/master/License). +© Jinzhu, 2013~time.Now + +Released under the [MIT License](https://github.com/jinzhu/gorm/blob/master/License) diff --git a/association.go b/association.go index 3d522ccc2..8c6d9864b 100644 --- a/association.go +++ b/association.go @@ -107,7 +107,7 @@ func (association *Association) Replace(values ...interface{}) *Association { if sourcePrimaryKeys := scope.getColumnAsArray(sourceForeignFieldNames, scope.Value); len(sourcePrimaryKeys) > 0 { newDB = newDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(sourcePrimaryKeys)), toQueryValues(sourcePrimaryKeys)...) - association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB, relationship)) + association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB)) } } else if relationship.Kind == "has_one" || relationship.Kind == "has_many" { // has_one or has_many relations, set foreign key to be nil (TODO or delete them?) @@ -173,7 +173,7 @@ func (association *Association) Delete(values ...interface{}) *Association { sql := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(deletingPrimaryKeys)) newDB = newDB.Where(sql, toQueryValues(deletingPrimaryKeys)...) - association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB, relationship)) + association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB)) } else { var foreignKeyMap = map[string]interface{}{} for _, foreignKey := range relationship.ForeignDBNames { diff --git a/association_test.go b/association_test.go index c84f84ed6..60d0cf485 100644 --- a/association_test.go +++ b/association_test.go @@ -885,7 +885,7 @@ func TestHasManyChildrenWithOneStruct(t *testing.T) { DB.Save(&category) } -func TestSkipSaveAssociation(t *testing.T) { +func TestAutoSaveBelongsToAssociation(t *testing.T) { type Company struct { gorm.Model Name string @@ -895,13 +895,156 @@ func TestSkipSaveAssociation(t *testing.T) { gorm.Model Name string CompanyID uint - Company Company `gorm:"save_associations:false"` + Company Company `gorm:"association_autoupdate:false;association_autocreate:false;"` } + + DB.Where("name = ?", "auto_save_association").Delete(&Company{}) + DB.AutoMigrate(&Company{}, &User{}) + + DB.Save(&User{Name: "jinzhu", Company: Company{Name: "auto_save_association"}}) + + if !DB.Where("name = ?", "auto_save_association").First(&Company{}).RecordNotFound() { + t.Errorf("Company auto_save_association should not have been saved when autosave is false") + } + + // if foreign key is set, this should be saved even if association isn't + company := Company{Name: "auto_save_association"} + DB.Save(&company) + + company.Name = "auto_save_association_new_name" + user := User{Name: "jinzhu", Company: company} + + DB.Save(&user) + + if !DB.Where("name = ?", "auto_save_association_new_name").First(&Company{}).RecordNotFound() { + t.Errorf("Company should not have been updated") + } + + if DB.Where("id = ? AND company_id = ?", user.ID, company.ID).First(&User{}).RecordNotFound() { + t.Errorf("User's foreign key should have been saved") + } + + user2 := User{Name: "jinzhu_2", Company: Company{Name: "auto_save_association_2"}} + DB.Set("gorm:association_autocreate", true).Save(&user2) + if DB.Where("name = ?", "auto_save_association_2").First(&Company{}).RecordNotFound() { + t.Errorf("Company auto_save_association_2 should been created when autocreate is true") + } + + user2.Company.Name = "auto_save_association_2_newname" + DB.Set("gorm:association_autoupdate", true).Save(&user2) + + if DB.Where("name = ?", "auto_save_association_2_newname").First(&Company{}).RecordNotFound() { + t.Errorf("Company should been updated") + } +} + +func TestAutoSaveHasOneAssociation(t *testing.T) { + type Company struct { + gorm.Model + UserID uint + Name string + } + + type User struct { + gorm.Model + Name string + Company Company `gorm:"association_autoupdate:false;association_autocreate:false;"` + } + + DB.Where("name = ?", "auto_save_has_one_association").Delete(&Company{}) + DB.AutoMigrate(&Company{}, &User{}) + + DB.Save(&User{Name: "jinzhu", Company: Company{Name: "auto_save_has_one_association"}}) + + if !DB.Where("name = ?", "auto_save_has_one_association").First(&Company{}).RecordNotFound() { + t.Errorf("Company auto_save_has_one_association should not have been saved when autosave is false") + } + + company := Company{Name: "auto_save_has_one_association"} + DB.Save(&company) + + company.Name = "auto_save_has_one_association_new_name" + user := User{Name: "jinzhu", Company: company} + + DB.Save(&user) + + if !DB.Where("name = ?", "auto_save_has_one_association_new_name").First(&Company{}).RecordNotFound() { + t.Errorf("Company should not have been updated") + } + + if !DB.Where("name = ? AND user_id = ?", "auto_save_has_one_association", user.ID).First(&Company{}).RecordNotFound() { + t.Errorf("Company should not have been updated") + } + + if user.Company.UserID == 0 { + t.Errorf("UserID should be assigned") + } + + company.Name = "auto_save_has_one_association_2_new_name" + DB.Set("gorm:association_autoupdate", true).Save(&user) + + if DB.Where("name = ? AND user_id = ?", "auto_save_has_one_association_new_name", user.ID).First(&Company{}).RecordNotFound() { + t.Errorf("Company should been updated") + } + + user2 := User{Name: "jinzhu_2", Company: Company{Name: "auto_save_has_one_association_2"}} + DB.Set("gorm:association_autocreate", true).Save(&user2) + if DB.Where("name = ?", "auto_save_has_one_association_2").First(&Company{}).RecordNotFound() { + t.Errorf("Company auto_save_has_one_association_2 should been created when autocreate is true") + } +} + +func TestAutoSaveMany2ManyAssociation(t *testing.T) { + type Company struct { + gorm.Model + Name string + } + + type User struct { + gorm.Model + Name string + Companies []Company `gorm:"many2many:user_companies;association_autoupdate:false;association_autocreate:false;"` + } + DB.AutoMigrate(&Company{}, &User{}) - DB.Save(&User{Name: "jinzhu", Company: Company{Name: "skip_save_association"}}) + DB.Save(&User{Name: "jinzhu", Companies: []Company{{Name: "auto_save_m2m_association"}}}) + + if !DB.Where("name = ?", "auto_save_m2m_association").First(&Company{}).RecordNotFound() { + t.Errorf("Company auto_save_m2m_association should not have been saved when autosave is false") + } + + company := Company{Name: "auto_save_m2m_association"} + DB.Save(&company) + + company.Name = "auto_save_m2m_association_new_name" + user := User{Name: "jinzhu", Companies: []Company{company, {Name: "auto_save_m2m_association_new_name_2"}}} + + DB.Save(&user) + + if !DB.Where("name = ?", "auto_save_m2m_association_new_name").First(&Company{}).RecordNotFound() { + t.Errorf("Company should not have been updated") + } + + if !DB.Where("name = ?", "auto_save_m2m_association_new_name_2").First(&Company{}).RecordNotFound() { + t.Errorf("Company should not been created") + } + + if DB.Model(&user).Association("Companies").Count() != 1 { + t.Errorf("Relationship should been saved") + } + + DB.Set("gorm:association_autoupdate", true).Set("gorm:association_autocreate", true).Save(&user) + + if DB.Where("name = ?", "auto_save_m2m_association_new_name").First(&Company{}).RecordNotFound() { + t.Errorf("Company should been updated") + } + + if DB.Where("name = ?", "auto_save_m2m_association_new_name_2").First(&Company{}).RecordNotFound() { + t.Errorf("Company should been created") + } - if !DB.Where("name = ?", "skip_save_association").First(&Company{}).RecordNotFound() { - t.Errorf("Company skip_save_association should not been saved") + if DB.Model(&user).Association("Companies").Count() != 2 { + t.Errorf("Relationship should been updated") } } diff --git a/callback_create.go b/callback_create.go index a4da39e84..e7fe6f868 100644 --- a/callback_create.go +++ b/callback_create.go @@ -97,8 +97,9 @@ func createCallback(scope *Scope) { if len(columns) == 0 { scope.Raw(fmt.Sprintf( - "INSERT INTO %v DEFAULT VALUES%v%v", + "INSERT INTO %v %v%v%v", quotedTableName, + scope.Dialect().DefaultValueStr(), addExtraSpaceIfExist(extraOption), addExtraSpaceIfExist(lastInsertIDReturningSuffix), )) diff --git a/callback_query.go b/callback_query.go index 20e881616..ba10cc7db 100644 --- a/callback_query.go +++ b/callback_query.go @@ -15,6 +15,10 @@ func init() { // queryCallback used to query data from database func queryCallback(scope *Scope) { + if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip { + return + } + defer scope.trace(NowFunc()) var ( diff --git a/callback_query_preload.go b/callback_query_preload.go index 21ab22ce1..30f6b585a 100644 --- a/callback_query_preload.go +++ b/callback_query_preload.go @@ -10,6 +10,9 @@ import ( // preloadCallback used to preload associations func preloadCallback(scope *Scope) { + if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip { + return + } if _, ok := scope.Get("gorm:auto_preload"); ok { autoPreload(scope) @@ -324,6 +327,10 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface scope.scan(rows, columns, append(fields, joinTableFields...)) + scope.New(elem.Addr().Interface()). + InstanceSet("gorm:skip_query_callback", true). + callCallbacks(scope.db.parent.callbacks.queries) + var foreignKeys = make([]interface{}, len(sourceKeys)) // generate hashed forkey keys in join table for idx, joinTableField := range joinTableFields { diff --git a/callback_save.go b/callback_save.go index f4bc918e7..ef2671418 100644 --- a/callback_save.go +++ b/callback_save.go @@ -1,6 +1,9 @@ package gorm -import "reflect" +import ( + "reflect" + "strings" +) func beginTransactionCallback(scope *Scope) { scope.Begin() @@ -10,31 +13,81 @@ func commitOrRollbackTransactionCallback(scope *Scope) { scope.CommitOrRollback() } -func saveFieldAsAssociation(scope *Scope, field *Field) (bool, *Relationship) { +func saveAssociationCheck(scope *Scope, field *Field) (autoUpdate bool, autoCreate bool, saveReference bool, r *Relationship) { + checkTruth := func(value interface{}) bool { + if v, ok := value.(bool); ok && !v { + return false + } + + if v, ok := value.(string); ok { + v = strings.ToLower(v) + if v == "false" || v != "skip" { + return false + } + } + + return true + } + if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored { - if value, ok := field.TagSettings["SAVE_ASSOCIATIONS"]; !ok || (value != "false" && value != "skip") { - if relationship := field.Relationship; relationship != nil { - return true, relationship + if r = field.Relationship; r != nil { + autoUpdate, autoCreate, saveReference = true, true, true + + if value, ok := scope.Get("gorm:save_associations"); ok { + autoUpdate = checkTruth(value) + autoCreate = autoUpdate + } else if value, ok := field.TagSettings["SAVE_ASSOCIATIONS"]; ok { + autoUpdate = checkTruth(value) + autoCreate = autoUpdate + } + + if value, ok := scope.Get("gorm:association_autoupdate"); ok { + autoUpdate = checkTruth(value) + } else if value, ok := field.TagSettings["ASSOCIATION_AUTOUPDATE"]; ok { + autoUpdate = checkTruth(value) + } + + if value, ok := scope.Get("gorm:association_autocreate"); ok { + autoCreate = checkTruth(value) + } else if value, ok := field.TagSettings["ASSOCIATION_AUTOCREATE"]; ok { + autoCreate = checkTruth(value) + } + + if value, ok := scope.Get("gorm:association_save_reference"); ok { + saveReference = checkTruth(value) + } else if value, ok := field.TagSettings["ASSOCIATION_SAVE_REFERENCE"]; ok { + saveReference = checkTruth(value) } } } - return false, nil + + return } func saveBeforeAssociationsCallback(scope *Scope) { - if !scope.shouldSaveAssociations() { - return - } for _, field := range scope.Fields() { - if ok, relationship := saveFieldAsAssociation(scope, field); ok && relationship.Kind == "belongs_to" { + autoUpdate, autoCreate, saveReference, relationship := saveAssociationCheck(scope, field) + + if relationship != nil && relationship.Kind == "belongs_to" { fieldValue := field.Field.Addr().Interface() - scope.Err(scope.NewDB().Save(fieldValue).Error) - if len(relationship.ForeignFieldNames) != 0 { - // set value's foreign key - for idx, fieldName := range relationship.ForeignFieldNames { - associationForeignName := relationship.AssociationForeignDBNames[idx] - if foreignField, ok := scope.New(fieldValue).FieldByName(associationForeignName); ok { - scope.Err(scope.SetColumn(fieldName, foreignField.Field.Interface())) + newScope := scope.New(fieldValue) + + if newScope.PrimaryKeyZero() { + if autoCreate { + scope.Err(scope.NewDB().Save(fieldValue).Error) + } + } else if autoUpdate { + scope.Err(scope.NewDB().Save(fieldValue).Error) + } + + if saveReference { + if len(relationship.ForeignFieldNames) != 0 { + // set value's foreign key + for idx, fieldName := range relationship.ForeignFieldNames { + associationForeignName := relationship.AssociationForeignDBNames[idx] + if foreignField, ok := scope.New(fieldValue).FieldByName(associationForeignName); ok { + scope.Err(scope.SetColumn(fieldName, foreignField.Field.Interface())) + } } } } @@ -43,12 +96,10 @@ func saveBeforeAssociationsCallback(scope *Scope) { } func saveAfterAssociationsCallback(scope *Scope) { - if !scope.shouldSaveAssociations() { - return - } for _, field := range scope.Fields() { - if ok, relationship := saveFieldAsAssociation(scope, field); ok && - (relationship.Kind == "has_one" || relationship.Kind == "has_many" || relationship.Kind == "many_to_many") { + autoUpdate, autoCreate, saveReference, relationship := saveAssociationCheck(scope, field) + + if relationship != nil && (relationship.Kind == "has_one" || relationship.Kind == "has_many" || relationship.Kind == "many_to_many") { value := field.Field switch value.Kind() { @@ -58,41 +109,61 @@ func saveAfterAssociationsCallback(scope *Scope) { elem := value.Index(i).Addr().Interface() newScope := newDB.NewScope(elem) - if relationship.JoinTableHandler == nil && len(relationship.ForeignFieldNames) != 0 { - for idx, fieldName := range relationship.ForeignFieldNames { - associationForeignName := relationship.AssociationForeignDBNames[idx] - if f, ok := scope.FieldByName(associationForeignName); ok { - scope.Err(newScope.SetColumn(fieldName, f.Field.Interface())) + if saveReference { + if relationship.JoinTableHandler == nil && len(relationship.ForeignFieldNames) != 0 { + for idx, fieldName := range relationship.ForeignFieldNames { + associationForeignName := relationship.AssociationForeignDBNames[idx] + if f, ok := scope.FieldByName(associationForeignName); ok { + scope.Err(newScope.SetColumn(fieldName, f.Field.Interface())) + } } } - } - if relationship.PolymorphicType != "" { - scope.Err(newScope.SetColumn(relationship.PolymorphicType, relationship.PolymorphicValue)) + if relationship.PolymorphicType != "" { + scope.Err(newScope.SetColumn(relationship.PolymorphicType, relationship.PolymorphicValue)) + } } - scope.Err(newDB.Save(elem).Error) + if newScope.PrimaryKeyZero() { + if autoCreate { + scope.Err(newDB.Save(elem).Error) + } + } else if autoUpdate { + scope.Err(newDB.Save(elem).Error) + } - if joinTableHandler := relationship.JoinTableHandler; joinTableHandler != nil { - scope.Err(joinTableHandler.Add(joinTableHandler, newDB, scope.Value, newScope.Value)) + if !scope.New(newScope.Value).PrimaryKeyZero() && saveReference { + if joinTableHandler := relationship.JoinTableHandler; joinTableHandler != nil { + scope.Err(joinTableHandler.Add(joinTableHandler, newDB, scope.Value, newScope.Value)) + } } } default: elem := value.Addr().Interface() newScope := scope.New(elem) - if len(relationship.ForeignFieldNames) != 0 { - for idx, fieldName := range relationship.ForeignFieldNames { - associationForeignName := relationship.AssociationForeignDBNames[idx] - if f, ok := scope.FieldByName(associationForeignName); ok { - scope.Err(newScope.SetColumn(fieldName, f.Field.Interface())) + + if saveReference { + if len(relationship.ForeignFieldNames) != 0 { + for idx, fieldName := range relationship.ForeignFieldNames { + associationForeignName := relationship.AssociationForeignDBNames[idx] + if f, ok := scope.FieldByName(associationForeignName); ok { + scope.Err(newScope.SetColumn(fieldName, f.Field.Interface())) + } } } + + if relationship.PolymorphicType != "" { + scope.Err(newScope.SetColumn(relationship.PolymorphicType, relationship.PolymorphicValue)) + } } - if relationship.PolymorphicType != "" { - scope.Err(newScope.SetColumn(relationship.PolymorphicType, relationship.PolymorphicValue)) + if newScope.PrimaryKeyZero() { + if autoCreate { + scope.Err(scope.NewDB().Save(elem).Error) + } + } else if autoUpdate { + scope.Err(scope.NewDB().Save(elem).Error) } - scope.Err(scope.NewDB().Save(elem).Error) } } } diff --git a/callback_update.go b/callback_update.go index 6948439fb..373bd7266 100644 --- a/callback_update.go +++ b/callback_update.go @@ -3,6 +3,7 @@ package gorm import ( "errors" "fmt" + "sort" "strings" ) @@ -59,7 +60,16 @@ func updateCallback(scope *Scope) { var sqls []string if updateAttrs, ok := scope.InstanceGet("gorm:update_attrs"); ok { - for column, value := range updateAttrs.(map[string]interface{}) { + // Sort the column names so that the generated SQL is the same every time. + updateMap := updateAttrs.(map[string]interface{}) + var columns []string + for c := range updateMap { + columns = append(columns, c) + } + sort.Strings(columns) + + for _, column := range columns { + value := updateMap[column] sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(column), scope.AddToVars(value))) } } else { diff --git a/create_test.go b/create_test.go index 36472914b..925606437 100644 --- a/create_test.go +++ b/create_test.go @@ -27,7 +27,9 @@ func TestCreate(t *testing.T) { } var newUser User - DB.First(&newUser, user.Id) + if err := DB.First(&newUser, user.Id).Error; err != nil { + t.Errorf("No error should happen, but got %v", err) + } if !reflect.DeepEqual(newUser.PasswordHash, []byte{'f', 'a', 'k', '4'}) { t.Errorf("User's PasswordHash should be saved ([]byte)") @@ -38,7 +40,7 @@ func TestCreate(t *testing.T) { } if newUser.UserNum != Num(111) { - t.Errorf("User's UserNum should be saved (custom type)") + t.Errorf("User's UserNum should be saved (custom type), but got %v", newUser.UserNum) } if newUser.Latitude != float { @@ -60,6 +62,17 @@ func TestCreate(t *testing.T) { } } +func TestCreateEmptyStrut(t *testing.T) { + type EmptyStruct struct { + ID uint + } + DB.AutoMigrate(&EmptyStruct{}) + + if err := DB.Create(&EmptyStruct{}).Error; err != nil { + t.Errorf("No error should happen when creating user, but got %v", err) + } +} + func TestCreateWithExistingTimestamp(t *testing.T) { user := User{Name: "CreateUserExistingTimestamp"} diff --git a/customize_column_test.go b/customize_column_test.go index ddb536b87..5e19d6f4c 100644 --- a/customize_column_test.go +++ b/customize_column_test.go @@ -279,3 +279,68 @@ func TestBelongsToWithPartialCustomizedColumn(t *testing.T) { t.Errorf("should preload discount from coupon") } } + +type SelfReferencingUser struct { + gorm.Model + Name string + Friends []*SelfReferencingUser `gorm:"many2many:UserFriends;association_jointable_foreignkey:friend_id"` +} + +func TestSelfReferencingMany2ManyColumn(t *testing.T) { + DB.DropTable(&SelfReferencingUser{}, "UserFriends") + DB.AutoMigrate(&SelfReferencingUser{}) + + friend1 := SelfReferencingUser{Name: "friend1_m2m"} + if err := DB.Create(&friend1).Error; err != nil { + t.Errorf("no error should happen, but got %v", err) + } + + friend2 := SelfReferencingUser{Name: "friend2_m2m"} + if err := DB.Create(&friend2).Error; err != nil { + t.Errorf("no error should happen, but got %v", err) + } + + user := SelfReferencingUser{ + Name: "self_m2m", + Friends: []*SelfReferencingUser{&friend1, &friend2}, + } + + if err := DB.Create(&user).Error; err != nil { + t.Errorf("no error should happen, but got %v", err) + } + + if DB.Model(&user).Association("Friends").Count() != 2 { + t.Errorf("Should find created friends correctly") + } + + var newUser = SelfReferencingUser{} + + if err := DB.Preload("Friends").First(&newUser, "id = ?", user.ID).Error; err != nil { + t.Errorf("no error should happen, but got %v", err) + } + + if len(newUser.Friends) != 2 { + t.Errorf("Should preload created frineds for self reference m2m") + } + + DB.Model(&newUser).Association("Friends").Append(&SelfReferencingUser{Name: "friend3_m2m"}) + if DB.Model(&user).Association("Friends").Count() != 3 { + t.Errorf("Should find created friends correctly") + } + + DB.Model(&newUser).Association("Friends").Replace(&SelfReferencingUser{Name: "friend4_m2m"}) + if DB.Model(&user).Association("Friends").Count() != 1 { + t.Errorf("Should find created friends correctly") + } + + friend := SelfReferencingUser{} + DB.Model(&newUser).Association("Friends").Find(&friend) + if friend.Name != "friend4_m2m" { + t.Errorf("Should find created friends correctly") + } + + DB.Model(&newUser).Association("Friends").Delete(friend) + if DB.Model(&user).Association("Friends").Count() != 0 { + t.Errorf("All friends should be deleted") + } +} diff --git a/dialect.go b/dialect.go index 90b1723f0..5f6439c1e 100644 --- a/dialect.go +++ b/dialect.go @@ -33,6 +33,8 @@ type Dialect interface { HasTable(tableName string) bool // HasColumn check has column or not HasColumn(tableName string, columnName string) bool + // ModifyColumn modify column's type + ModifyColumn(tableName string, columnName string, typ string) error // LimitAndOffsetSQL return generated SQL with Limit and Offset, as mssql has special case LimitAndOffsetSQL(limit, offset interface{}) string @@ -40,6 +42,8 @@ type Dialect interface { SelectFromDummyTable() string // LastInsertIdReturningSuffix most dbs support LastInsertId, but postgres needs to use `RETURNING` LastInsertIDReturningSuffix(tableName, columnName string) string + // DefaultValueStr + DefaultValueStr() string // BuildKeyName returns a valid key name (foreign key, index key) for the given table, field and reference BuildKeyName(kind, tableName string, fields ...string) string @@ -90,14 +94,16 @@ var ParseFieldStructForDialect = func(field *StructField, dialect Dialect) (fiel } // Get scanner's real value - var getScannerValue func(reflect.Value) - getScannerValue = func(value reflect.Value) { - fieldValue = value - if _, isScanner := reflect.New(fieldValue.Type()).Interface().(sql.Scanner); isScanner && fieldValue.Kind() == reflect.Struct { - getScannerValue(fieldValue.Field(0)) + if dataType == "" { + var getScannerValue func(reflect.Value) + getScannerValue = func(value reflect.Value) { + fieldValue = value + if _, isScanner := reflect.New(fieldValue.Type()).Interface().(sql.Scanner); isScanner && fieldValue.Kind() == reflect.Struct { + getScannerValue(fieldValue.Field(0)) + } } + getScannerValue(fieldValue) } - getScannerValue(fieldValue) // Default Size if num, ok := field.TagSettings["SIZE"]; ok { diff --git a/dialect_common.go b/dialect_common.go index 30f035a51..b9f0c7da7 100644 --- a/dialect_common.go +++ b/dialect_common.go @@ -38,6 +38,13 @@ func (commonDialect) Quote(key string) string { return fmt.Sprintf(`"%s"`, key) } +func (s *commonDialect) fieldCanAutoIncrement(field *StructField) bool { + if value, ok := field.TagSettings["AUTO_INCREMENT"]; ok { + return strings.ToLower(value) != "false" + } + return field.IsPrimaryKey +} + func (s *commonDialect) DataTypeOf(field *StructField) string { var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s) @@ -46,13 +53,13 @@ func (s *commonDialect) DataTypeOf(field *StructField) string { case reflect.Bool: sqlType = "BOOLEAN" case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok { + if s.fieldCanAutoIncrement(field) { sqlType = "INTEGER AUTO_INCREMENT" } else { sqlType = "INTEGER" } case reflect.Int64, reflect.Uint64: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok { + if s.fieldCanAutoIncrement(field) { sqlType = "BIGINT AUTO_INCREMENT" } else { sqlType = "BIGINT" @@ -120,6 +127,11 @@ func (s commonDialect) HasColumn(tableName string, columnName string) bool { return count > 0 } +func (s commonDialect) ModifyColumn(tableName string, columnName string, typ string) error { + _, err := s.db.Exec(fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v TYPE %v", tableName, columnName, typ)) + return err +} + func (s commonDialect) CurrentDatabase() (name string) { s.db.QueryRow("SELECT DATABASE()").Scan(&name) return @@ -147,6 +159,10 @@ func (commonDialect) LastInsertIDReturningSuffix(tableName, columnName string) s return "" } +func (commonDialect) DefaultValueStr() string { + return "DEFAULT VALUES" +} + // BuildKeyName returns a valid key name (foreign key, index key) for the given table, field and reference func (DefaultForeignKeyNamer) BuildKeyName(kind, tableName string, fields ...string) string { keyName := fmt.Sprintf("%s_%s_%s", kind, tableName, strings.Join(fields, "_")) diff --git a/dialect_mysql.go b/dialect_mysql.go index f4858e10f..b162badee 100644 --- a/dialect_mysql.go +++ b/dialect_mysql.go @@ -44,42 +44,42 @@ func (s *mysql) DataTypeOf(field *StructField) string { case reflect.Bool: sqlType = "boolean" case reflect.Int8: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { + if s.fieldCanAutoIncrement(field) { field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" sqlType = "tinyint AUTO_INCREMENT" } else { sqlType = "tinyint" } case reflect.Int, reflect.Int16, reflect.Int32: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { + if s.fieldCanAutoIncrement(field) { field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" sqlType = "int AUTO_INCREMENT" } else { sqlType = "int" } case reflect.Uint8: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { + if s.fieldCanAutoIncrement(field) { field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" sqlType = "tinyint unsigned AUTO_INCREMENT" } else { sqlType = "tinyint unsigned" } case reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uintptr: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { + if s.fieldCanAutoIncrement(field) { field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" sqlType = "int unsigned AUTO_INCREMENT" } else { sqlType = "int unsigned" } case reflect.Int64: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { + if s.fieldCanAutoIncrement(field) { field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" sqlType = "bigint AUTO_INCREMENT" } else { sqlType = "bigint" } case reflect.Uint64: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { + if s.fieldCanAutoIncrement(field) { field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" sqlType = "bigint unsigned AUTO_INCREMENT" } else { @@ -95,10 +95,15 @@ func (s *mysql) DataTypeOf(field *StructField) string { } case reflect.Struct: if _, ok := dataValue.Interface().(time.Time); ok { + precision := "" + if p, ok := field.TagSettings["PRECISION"]; ok { + precision = fmt.Sprintf("(%s)", p) + } + if _, ok := field.TagSettings["NOT NULL"]; ok { - sqlType = "timestamp" + sqlType = fmt.Sprintf("timestamp%v", precision) } else { - sqlType = "timestamp NULL" + sqlType = fmt.Sprintf("timestamp%v NULL", precision) } } default: @@ -127,6 +132,11 @@ func (s mysql) RemoveIndex(tableName string, indexName string) error { return err } +func (s mysql) ModifyColumn(tableName string, columnName string, typ string) error { + _, err := s.db.Exec(fmt.Sprintf("ALTER TABLE %v MODIFY COLUMN %v %v", tableName, columnName, typ)) + return err +} + func (s mysql) LimitAndOffsetSQL(limit, offset interface{}) (sql string) { if limit != nil { if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 { @@ -175,3 +185,7 @@ func (s mysql) BuildKeyName(kind, tableName string, fields ...string) string { return fmt.Sprintf("%s%x", string(destRunes), bs) } + +func (mysql) DefaultValueStr() string { + return "VALUES()" +} diff --git a/dialect_postgres.go b/dialect_postgres.go index 6fdf4df10..c44c6a5b5 100644 --- a/dialect_postgres.go +++ b/dialect_postgres.go @@ -1,6 +1,7 @@ package gorm import ( + "encoding/json" "fmt" "reflect" "strings" @@ -32,14 +33,14 @@ func (s *postgres) DataTypeOf(field *StructField) string { case reflect.Bool: sqlType = "boolean" case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uintptr: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { + if s.fieldCanAutoIncrement(field) { field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" sqlType = "serial" } else { sqlType = "integer" } case reflect.Int64, reflect.Uint32, reflect.Uint64: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { + if s.fieldCanAutoIncrement(field) { field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" sqlType = "bigserial" } else { @@ -68,9 +69,14 @@ func (s *postgres) DataTypeOf(field *StructField) string { default: if IsByteArrayOrSlice(dataValue) { sqlType = "bytea" + if isUUID(dataValue) { sqlType = "uuid" } + + if isJSON(dataValue) { + sqlType = "jsonb" + } } } } @@ -130,3 +136,8 @@ func isUUID(value reflect.Value) bool { lower := strings.ToLower(typename) return "uuid" == lower || "guid" == lower } + +func isJSON(value reflect.Value) bool { + _, ok := value.Interface().(json.RawMessage) + return ok +} diff --git a/dialect_sqlite3.go b/dialect_sqlite3.go index de9c05cb7..f26f6be30 100644 --- a/dialect_sqlite3.go +++ b/dialect_sqlite3.go @@ -28,14 +28,14 @@ func (s *sqlite3) DataTypeOf(field *StructField) string { case reflect.Bool: sqlType = "bool" case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: - if field.IsPrimaryKey { + if s.fieldCanAutoIncrement(field) { field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" sqlType = "integer primary key autoincrement" } else { sqlType = "integer" } case reflect.Int64, reflect.Uint64: - if field.IsPrimaryKey { + if s.fieldCanAutoIncrement(field) { field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" sqlType = "integer primary key autoincrement" } else { diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go index a4f8e87c9..e06064652 100644 --- a/dialects/mssql/mssql.go +++ b/dialects/mssql/mssql.go @@ -54,7 +54,7 @@ func (mssql) BindVar(i int) string { } func (mssql) Quote(key string) string { - return fmt.Sprintf(`"%s"`, key) + return fmt.Sprintf(`[%s]`, key) } func (s *mssql) DataTypeOf(field *gorm.StructField) string { @@ -65,14 +65,14 @@ func (s *mssql) DataTypeOf(field *gorm.StructField) string { case reflect.Bool: sqlType = "bit" case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { + if s.fieldCanAutoIncrement(field) { field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" sqlType = "int IDENTITY(1,1)" } else { sqlType = "int" } case reflect.Int64, reflect.Uint64: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { + if s.fieldCanAutoIncrement(field) { field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" sqlType = "bigint IDENTITY(1,1)" } else { @@ -111,6 +111,13 @@ func (s *mssql) DataTypeOf(field *gorm.StructField) string { return fmt.Sprintf("%v %v", sqlType, additionalType) } +func (s mssql) fieldCanAutoIncrement(field *gorm.StructField) bool { + if value, ok := field.TagSettings["AUTO_INCREMENT"]; ok { + return value != "FALSE" + } + return field.IsPrimaryKey +} + func (s mssql) HasIndex(tableName string, indexName string) bool { var count int s.db.QueryRow("SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", indexName, tableName).Scan(&count) @@ -140,6 +147,11 @@ func (s mssql) HasColumn(tableName string, columnName string) bool { return count > 0 } +func (s mssql) ModifyColumn(tableName string, columnName string, typ string) error { + _, err := s.db.Exec(fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v %v", tableName, columnName, typ)) + return err +} + func (s mssql) CurrentDatabase() (name string) { s.db.QueryRow("SELECT DB_NAME() AS [Current Database]").Scan(&name) return @@ -171,6 +183,10 @@ func (mssql) LastInsertIDReturningSuffix(tableName, columnName string) string { return "" } +func (mssql) DefaultValueStr() string { + return "DEFAULT VALUES" +} + func currentDatabaseAndTable(dialect gorm.Dialect, tableName string) (string, string) { if strings.Contains(tableName, ".") { splitStrings := strings.SplitN(tableName, ".", 2) diff --git a/dialects/postgres/postgres.go b/dialects/postgres/postgres.go index b8e768913..1d0dcb60f 100644 --- a/dialects/postgres/postgres.go +++ b/dialects/postgres/postgres.go @@ -63,6 +63,9 @@ type Jsonb struct { // Value get value of Jsonb func (j Jsonb) Value() (driver.Value, error) { + if len(j.RawMessage) == 0 { + return nil, nil + } return j.MarshalJSON() } diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 000000000..79bf5fc39 --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,30 @@ +version: '3' + +services: + mysql: + image: 'mysql:latest' + ports: + - 9910:3306 + environment: + - MYSQL_DATABASE=gorm + - MYSQL_USER=gorm + - MYSQL_PASSWORD=gorm + - MYSQL_RANDOM_ROOT_PASSWORD="yes" + postgres: + image: 'postgres:latest' + ports: + - 9920:5432 + environment: + - POSTGRES_USER=gorm + - POSTGRES_DB=gorm + - POSTGRES_PASSWORD=gorm + mssql: + image: 'mcmoe/mssqldocker:latest' + ports: + - 9930:1433 + environment: + - ACCEPT_EULA=Y + - SA_PASSWORD=LoremIpsum86 + - MSSQL_DB=gorm + - MSSQL_USER=gorm + - MSSQL_PASSWORD=LoremIpsum86 diff --git a/embedded_struct_test.go b/embedded_struct_test.go index 91dd05633..5f8ece573 100644 --- a/embedded_struct_test.go +++ b/embedded_struct_test.go @@ -71,3 +71,21 @@ func TestSaveAndQueryEmbeddedStruct(t *testing.T) { } } } + +func TestEmbeddedPointerTypeStruct(t *testing.T) { + type HNPost struct { + *BasePost + Upvotes int32 + } + + DB.Create(&HNPost{BasePost: &BasePost{Title: "embedded_pointer_type"}}) + + var hnPost HNPost + if err := DB.First(&hnPost, "title = ?", "embedded_pointer_type").Error; err != nil { + t.Errorf("No error should happen when find embedded pointer type, but got %v", err) + } + + if hnPost.Title != "embedded_pointer_type" { + t.Errorf("Should find correct value for embedded pointer type") + } +} diff --git a/errors.go b/errors.go index 6845188e6..da2cf13cb 100644 --- a/errors.go +++ b/errors.go @@ -21,6 +21,18 @@ var ( // Errors contains all happened errors type Errors []error +// IsRecordNotFoundError returns current error has record not found error or not +func IsRecordNotFoundError(err error) bool { + if errs, ok := err.(Errors); ok { + for _, err := range errs { + if err == ErrRecordNotFound { + return true + } + } + } + return err == ErrRecordNotFound +} + // GetErrors gets all happened errors func (errs Errors) GetErrors() []error { return errs diff --git a/join_table_handler.go b/join_table_handler.go index 2d1a50554..a036d46d2 100644 --- a/join_table_handler.go +++ b/join_table_handler.go @@ -79,41 +79,43 @@ func (s *JoinTableHandler) Setup(relationship *Relationship, tableName string, s // Table return join table's table name func (s JoinTableHandler) Table(db *DB) string { - return s.TableName + return DefaultTableNameHandler(db, s.TableName) } -func (s JoinTableHandler) getSearchMap(db *DB, sources ...interface{}) map[string]interface{} { - values := map[string]interface{}{} - +func (s JoinTableHandler) updateConditionMap(conditionMap map[string]interface{}, db *DB, joinTableSources []JoinTableSource, sources ...interface{}) { for _, source := range sources { scope := db.NewScope(source) modelType := scope.GetModelStruct().ModelType - if s.Source.ModelType == modelType { - for _, foreignKey := range s.Source.ForeignKeys { - if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok { - values[foreignKey.DBName] = field.Field.Interface() - } - } - } else if s.Destination.ModelType == modelType { - for _, foreignKey := range s.Destination.ForeignKeys { - if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok { - values[foreignKey.DBName] = field.Field.Interface() + for _, joinTableSource := range joinTableSources { + if joinTableSource.ModelType == modelType { + for _, foreignKey := range joinTableSource.ForeignKeys { + if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok { + conditionMap[foreignKey.DBName] = field.Field.Interface() + } } + break } } } - return values } // Add create relationship in join table for source and destination func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error { - scope := db.NewScope("") - searchMap := s.getSearchMap(db, source, destination) + var ( + scope = db.NewScope("") + conditionMap = map[string]interface{}{} + ) + + // Update condition map for source + s.updateConditionMap(conditionMap, db, []JoinTableSource{s.Source}, source) + + // Update condition map for destination + s.updateConditionMap(conditionMap, db, []JoinTableSource{s.Destination}, destination) var assignColumns, binVars, conditions []string var values []interface{} - for key, value := range searchMap { + for key, value := range conditionMap { assignColumns = append(assignColumns, scope.Quote(key)) binVars = append(binVars, `?`) conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key))) @@ -141,12 +143,15 @@ func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source // Delete delete relationship in join table for sources func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error { var ( - scope = db.NewScope(nil) - conditions []string - values []interface{} + scope = db.NewScope(nil) + conditions []string + values []interface{} + conditionMap = map[string]interface{}{} ) - for key, value := range s.getSearchMap(db, sources...) { + s.updateConditionMap(conditionMap, db, []JoinTableSource{s.Source, s.Destination}, sources...) + + for key, value := range conditionMap { conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key))) values = append(values, value) } diff --git a/main.go b/main.go index 16fa0b799..c26e05c8c 100644 --- a/main.go +++ b/main.go @@ -177,6 +177,15 @@ func (s *DB) QueryExpr() *expr { return Expr(scope.SQL, scope.SQLVars...) } +// SubQuery returns the query as sub query +func (s *DB) SubQuery() *expr { + scope := s.NewScope(s.Value) + scope.InstanceSet("skip_bindvar", true) + scope.prepareQuerySQL() + + return Expr(fmt.Sprintf("(%v)", scope.SQL), scope.SQLVars...) +} + // Where return a new relation, filter records with given conditions, accepts `map`, `struct` or `string` as conditions, refer http://jinzhu.github.io/gorm/crud.html#query func (s *DB) Where(query interface{}, args ...interface{}) *DB { return s.clone().search.Where(query, args...).db @@ -274,15 +283,22 @@ func (s *DB) Assign(attrs ...interface{}) *DB { // First find first record that match given conditions, order by primary key func (s *DB) First(out interface{}, where ...interface{}) *DB { - newScope := s.clone().NewScope(out) + newScope := s.NewScope(out) newScope.Search.Limit(1) return newScope.Set("gorm:order_by_primary_key", "ASC"). inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db } +// Take return a record that match given conditions, the order will depend on the database implementation +func (s *DB) Take(out interface{}, where ...interface{}) *DB { + newScope := s.NewScope(out) + newScope.Search.Limit(1) + return newScope.inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db +} + // Last find last record that match given conditions, order by primary key func (s *DB) Last(out interface{}, where ...interface{}) *DB { - newScope := s.clone().NewScope(out) + newScope := s.NewScope(out) newScope.Search.Limit(1) return newScope.Set("gorm:order_by_primary_key", "DESC"). inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db @@ -290,12 +306,12 @@ func (s *DB) Last(out interface{}, where ...interface{}) *DB { // Find find records that match given conditions func (s *DB) Find(out interface{}, where ...interface{}) *DB { - return s.clone().NewScope(out).inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db + return s.NewScope(out).inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db } // Scan scan value to a struct func (s *DB) Scan(dest interface{}) *DB { - return s.clone().NewScope(s.Value).Set("gorm:query_destination", dest).callCallbacks(s.parent.callbacks.queries).db + return s.NewScope(s.Value).Set("gorm:query_destination", dest).callCallbacks(s.parent.callbacks.queries).db } // Row return `*sql.Row` with given conditions @@ -311,8 +327,8 @@ func (s *DB) Rows() (*sql.Rows, error) { // ScanRows scan `*sql.Rows` to give struct func (s *DB) ScanRows(rows *sql.Rows, result interface{}) error { var ( - clone = s.clone() - scope = clone.NewScope(result) + scope = s.NewScope(result) + clone = scope.db columns, err = rows.Columns() ) @@ -337,7 +353,7 @@ func (s *DB) Count(value interface{}) *DB { // Related get related associations func (s *DB) Related(value interface{}, foreignKeys ...string) *DB { - return s.clone().NewScope(s.Value).related(value, foreignKeys...).db + return s.NewScope(s.Value).related(value, foreignKeys...).db } // FirstOrInit find first matched record or initialize a new one with given conditions (only works with struct, map conditions) @@ -377,7 +393,7 @@ func (s *DB) Update(attrs ...interface{}) *DB { // Updates update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update func (s *DB) Updates(values interface{}, ignoreProtectedAttrs ...bool) *DB { - return s.clone().NewScope(s.Value). + return s.NewScope(s.Value). Set("gorm:ignore_protected_attrs", len(ignoreProtectedAttrs) > 0). InstanceSet("gorm:update_interface", values). callCallbacks(s.parent.callbacks.updates).db @@ -390,7 +406,7 @@ func (s *DB) UpdateColumn(attrs ...interface{}) *DB { // UpdateColumns update attributes without callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update func (s *DB) UpdateColumns(values interface{}) *DB { - return s.clone().NewScope(s.Value). + return s.NewScope(s.Value). Set("gorm:update_column", true). Set("gorm:save_associations", false). InstanceSet("gorm:update_interface", values). @@ -399,7 +415,7 @@ func (s *DB) UpdateColumns(values interface{}) *DB { // Save update value in database, if the value doesn't have primary key, will insert it func (s *DB) Save(value interface{}) *DB { - scope := s.clone().NewScope(value) + scope := s.NewScope(value) if !scope.PrimaryKeyZero() { newDB := scope.callCallbacks(s.parent.callbacks.updates).db if newDB.Error == nil && newDB.RowsAffected == 0 { @@ -412,13 +428,13 @@ func (s *DB) Save(value interface{}) *DB { // Create insert the value into database func (s *DB) Create(value interface{}) *DB { - scope := s.clone().NewScope(value) + scope := s.NewScope(value) return scope.callCallbacks(s.parent.callbacks.creates).db } // Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition func (s *DB) Delete(value interface{}, where ...interface{}) *DB { - return s.clone().NewScope(value).inlineCondition(where...).callCallbacks(s.parent.callbacks.deletes).db + return s.NewScope(value).inlineCondition(where...).callCallbacks(s.parent.callbacks.deletes).db } // Raw use raw sql as conditions, won't run it unless invoked by other methods @@ -429,8 +445,8 @@ func (s *DB) Raw(sql string, values ...interface{}) *DB { // Exec execute raw sql func (s *DB) Exec(sql string, values ...interface{}) *DB { - scope := s.clone().NewScope(nil) - generatedSQL := scope.buildWhereCondition(map[string]interface{}{"query": sql, "args": values}) + scope := s.NewScope(nil) + generatedSQL := scope.buildCondition(map[string]interface{}{"query": sql, "args": values}, true) generatedSQL = strings.TrimSuffix(strings.TrimPrefix(generatedSQL, "("), ")") scope.Raw(generatedSQL) return scope.Exec().db @@ -495,7 +511,7 @@ func (s *DB) Rollback() *DB { // NewRecord check if value's primary key is blank func (s *DB) NewRecord(value interface{}) bool { - return s.clone().NewScope(value).PrimaryKeyZero() + return s.NewScope(value).PrimaryKeyZero() } // RecordNotFound check if returning ErrRecordNotFound error @@ -544,7 +560,7 @@ func (s *DB) DropTableIfExists(values ...interface{}) *DB { // HasTable check has table or not func (s *DB) HasTable(value interface{}) bool { var ( - scope = s.clone().NewScope(value) + scope = s.NewScope(value) tableName string ) @@ -570,14 +586,14 @@ func (s *DB) AutoMigrate(values ...interface{}) *DB { // ModifyColumn modify column to type func (s *DB) ModifyColumn(column string, typ string) *DB { - scope := s.clone().NewScope(s.Value) + scope := s.NewScope(s.Value) scope.modifyColumn(column, typ) return scope.db } // DropColumn drop a column func (s *DB) DropColumn(column string) *DB { - scope := s.clone().NewScope(s.Value) + scope := s.NewScope(s.Value) scope.dropColumn(column) return scope.db } @@ -598,7 +614,7 @@ func (s *DB) AddUniqueIndex(indexName string, columns ...string) *DB { // RemoveIndex remove index with name func (s *DB) RemoveIndex(indexName string) *DB { - scope := s.clone().NewScope(s.Value) + scope := s.NewScope(s.Value) scope.removeIndex(indexName) return scope.db } @@ -606,11 +622,19 @@ func (s *DB) RemoveIndex(indexName string) *DB { // AddForeignKey Add foreign key to the given scope, e.g: // db.Model(&User{}).AddForeignKey("city_id", "cities(id)", "RESTRICT", "RESTRICT") func (s *DB) AddForeignKey(field string, dest string, onDelete string, onUpdate string) *DB { - scope := s.clone().NewScope(s.Value) + scope := s.NewScope(s.Value) scope.addForeignKey(field, dest, onDelete, onUpdate) return scope.db } +// RemoveForeignKey Remove foreign key from the given scope, e.g: +// db.Model(&User{}).RemoveForeignKey("city_id", "cities(id)") +func (s *DB) RemoveForeignKey(field string, dest string) *DB { + scope := s.clone().NewScope(s.Value) + scope.removeForeignKey(field, dest) + return scope.db +} + // Association start `Association Mode` to handler relations things easir in that mode, refer: https://jinzhu.github.io/gorm/associations.html#association-mode func (s *DB) Association(column string) *Association { var err error diff --git a/main_test.go b/main_test.go index 34f96a86f..66c46af05 100644 --- a/main_test.go +++ b/main_test.go @@ -36,27 +36,20 @@ func init() { } func OpenTestConnection() (db *gorm.DB, err error) { + dbDSN := os.Getenv("GORM_DSN") switch os.Getenv("GORM_DIALECT") { case "mysql": - // CREATE USER 'gorm'@'localhost' IDENTIFIED BY 'gorm'; - // CREATE DATABASE gorm; - // GRANT ALL ON gorm.* TO 'gorm'@'localhost'; fmt.Println("testing mysql...") - dbhost := os.Getenv("GORM_DBADDRESS") - if dbhost != "" { - dbhost = fmt.Sprintf("tcp(%v)", dbhost) + if dbDSN == "" { + dbDSN = "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True" } - db, err = gorm.Open("mysql", fmt.Sprintf("gorm:gorm@%v/gorm?charset=utf8&parseTime=True", dbhost)) + db, err = gorm.Open("mysql", dbDSN) case "postgres": fmt.Println("testing postgres...") - dbhost := os.Getenv("GORM_DBHOST") - if dbhost != "" { - dbhost = fmt.Sprintf("host=%v ", dbhost) + if dbDSN == "" { + dbDSN = "user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable" } - db, err = gorm.Open("postgres", fmt.Sprintf("%vuser=gorm password=gorm DB.name=gorm sslmode=disable", dbhost)) - case "foundation": - fmt.Println("testing foundation...") - db, err = gorm.Open("foundation", "dbname=gorm port=15432 sslmode=disable") + db, err = gorm.Open("postgres", dbDSN) case "mssql": // CREATE LOGIN gorm WITH PASSWORD = 'LoremIpsum86'; // CREATE DATABASE gorm; @@ -64,7 +57,10 @@ func OpenTestConnection() (db *gorm.DB, err error) { // CREATE USER gorm FROM LOGIN gorm; // sp_changedbowner 'gorm'; fmt.Println("testing mssql...") - db, err = gorm.Open("mssql", "sqlserver://gorm:LoremIpsum86@localhost:1433?database=gorm") + if dbDSN == "" { + dbDSN = "sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm" + } + db, err = gorm.Open("mssql", dbDSN) default: fmt.Println("testing sqlite3...") db, err = gorm.Open("sqlite3", filepath.Join(os.TempDir(), "gorm.db")) @@ -72,8 +68,10 @@ func OpenTestConnection() (db *gorm.DB, err error) { // db.SetLogger(Logger{log.New(os.Stdout, "\r\n", 0)}) // db.SetLogger(log.New(os.Stdout, "\r\n", 0)) - if os.Getenv("DEBUG") == "true" { + if debug := os.Getenv("DEBUG"); debug == "true" { db.LogMode(true) + } else if debug == "false" { + db.LogMode(false) } db.DB().SetMaxIdleConns(10) @@ -633,6 +631,47 @@ func TestQueryBuilderSubselectInWhere(t *testing.T) { } } +func TestQueryBuilderRawQueryWithSubquery(t *testing.T) { + user := User{Name: "subquery_test_user1", Age: 10} + DB.Save(&user) + user = User{Name: "subquery_test_user2", Age: 11} + DB.Save(&user) + user = User{Name: "subquery_test_user3", Age: 12} + DB.Save(&user) + + var count int + err := DB.Raw("select count(*) from (?) tmp", + DB.Table("users"). + Select("name"). + Where("age >= ? and name in (?)", 10, []string{"subquery_test_user1", "subquery_test_user2"}). + Group("name"). + QueryExpr(), + ).Count(&count).Error + + if err != nil { + t.Errorf("Expected to get no errors, but got %v", err) + } + if count != 2 { + t.Errorf("Row count must be 2, instead got %d", count) + } + + err = DB.Raw("select count(*) from (?) tmp", + DB.Table("users"). + Select("name"). + Where("name LIKE ?", "subquery_test%"). + Not("age <= ?", 10).Not("name in (?)", []string{"subquery_test_user1", "subquery_test_user2"}). + Group("name"). + QueryExpr(), + ).Count(&count).Error + + if err != nil { + t.Errorf("Expected to get no errors, but got %v", err) + } + if count != 1 { + t.Errorf("Row count must be 1, instead got %d", count) + } +} + func TestQueryBuilderSubselectInHaving(t *testing.T) { user := User{Name: "query_expr_having_ruser1", Email: "root@user1.com", Age: 64} DB.Save(&user) diff --git a/migration_test.go b/migration_test.go index 3f3a5c8fa..7c6944856 100644 --- a/migration_test.go +++ b/migration_test.go @@ -7,6 +7,7 @@ import ( "fmt" "os" "reflect" + "strconv" "testing" "time" @@ -33,6 +34,7 @@ type User struct { CompanyID *int Company Company Role Role + Password EncryptedData PasswordHash []byte IgnoreMe int64 `sql:"-"` IgnoreStringSlice []string `sql:"-"` @@ -116,6 +118,31 @@ type Company struct { Owner *User `sql:"-"` } +type EncryptedData []byte + +func (data *EncryptedData) Scan(value interface{}) error { + if b, ok := value.([]byte); ok { + if len(b) < 3 || b[0] != '*' || b[1] != '*' || b[2] != '*' { + return errors.New("Too short") + } + + *data = b[3:] + return nil + } + + return errors.New("Bytes expected") +} + +func (data EncryptedData) Value() (driver.Value, error) { + if len(data) > 0 && data[0] == 'x' { + //needed to test failures + return nil, errors.New("Should not start with 'x'") + } + + //prepend asterisks + return append([]byte("***"), data...), nil +} + type Role struct { Name string `gorm:"size:256"` } @@ -142,6 +169,8 @@ type Num int64 func (i *Num) Scan(src interface{}) error { switch s := src.(type) { case []byte: + n, _ := strconv.Atoi(string(s)) + *i = Num(n) case int64: *i = Num(s) default: @@ -435,10 +464,7 @@ func TestMultipleIndexes(t *testing.T) { } func TestModifyColumnType(t *testing.T) { - dialect := os.Getenv("GORM_DIALECT") - if dialect != "postgres" && - dialect != "mysql" && - dialect != "mssql" { + if dialect := os.Getenv("GORM_DIALECT"); dialect != "postgres" && dialect != "mysql" && dialect != "mssql" { t.Skip("Skipping this because only postgres, mysql and mssql support altering a column type") } diff --git a/model_struct.go b/model_struct.go index 315028c46..f571e2e86 100644 --- a/model_struct.go +++ b/model_struct.go @@ -249,11 +249,13 @@ func (scope *Scope) GetModelStruct() *ModelStruct { ) if foreignKey := field.TagSettings["FOREIGNKEY"]; foreignKey != "" { - foreignKeys = strings.Split(field.TagSettings["FOREIGNKEY"], ",") + foreignKeys = strings.Split(foreignKey, ",") } - if foreignKey := field.TagSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" { - associationForeignKeys = strings.Split(field.TagSettings["ASSOCIATIONFOREIGNKEY"], ",") + if foreignKey := field.TagSettings["ASSOCIATION_FOREIGNKEY"]; foreignKey != "" { + associationForeignKeys = strings.Split(foreignKey, ",") + } else if foreignKey := field.TagSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" { + associationForeignKeys = strings.Split(foreignKey, ",") } for elemType.Kind() == reflect.Slice || elemType.Kind() == reflect.Ptr { @@ -264,37 +266,65 @@ func (scope *Scope) GetModelStruct() *ModelStruct { if many2many := field.TagSettings["MANY2MANY"]; many2many != "" { relationship.Kind = "many_to_many" - // if no foreign keys defined with tag - if len(foreignKeys) == 0 { - for _, field := range modelStruct.PrimaryFields { - foreignKeys = append(foreignKeys, field.DBName) + { // Foreign Keys for Source + joinTableDBNames := []string{} + + if foreignKey := field.TagSettings["JOINTABLE_FOREIGNKEY"]; foreignKey != "" { + joinTableDBNames = strings.Split(foreignKey, ",") } - } - for _, foreignKey := range foreignKeys { - if foreignField := getForeignField(foreignKey, modelStruct.StructFields); foreignField != nil { - // source foreign keys (db names) - relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.DBName) - // join table foreign keys for source - joinTableDBName := ToDBName(reflectType.Name()) + "_" + foreignField.DBName - relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBName) + // if no foreign keys defined with tag + if len(foreignKeys) == 0 { + for _, field := range modelStruct.PrimaryFields { + foreignKeys = append(foreignKeys, field.DBName) + } } - } - // if no association foreign keys defined with tag - if len(associationForeignKeys) == 0 { - for _, field := range toScope.PrimaryFields() { - associationForeignKeys = append(associationForeignKeys, field.DBName) + for idx, foreignKey := range foreignKeys { + if foreignField := getForeignField(foreignKey, modelStruct.StructFields); foreignField != nil { + // source foreign keys (db names) + relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.DBName) + + // setup join table foreign keys for source + if len(joinTableDBNames) > idx { + // if defined join table's foreign key + relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBNames[idx]) + } else { + defaultJointableForeignKey := ToDBName(reflectType.Name()) + "_" + foreignField.DBName + relationship.ForeignDBNames = append(relationship.ForeignDBNames, defaultJointableForeignKey) + } + } } } - for _, name := range associationForeignKeys { - if field, ok := toScope.FieldByName(name); ok { - // association foreign keys (db names) - relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, field.DBName) - // join table foreign keys for association - joinTableDBName := ToDBName(elemType.Name()) + "_" + field.DBName - relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, joinTableDBName) + { // Foreign Keys for Association (Destination) + associationJoinTableDBNames := []string{} + + if foreignKey := field.TagSettings["ASSOCIATION_JOINTABLE_FOREIGNKEY"]; foreignKey != "" { + associationJoinTableDBNames = strings.Split(foreignKey, ",") + } + + // if no association foreign keys defined with tag + if len(associationForeignKeys) == 0 { + for _, field := range toScope.PrimaryFields() { + associationForeignKeys = append(associationForeignKeys, field.DBName) + } + } + + for idx, name := range associationForeignKeys { + if field, ok := toScope.FieldByName(name); ok { + // association foreign keys (db names) + relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, field.DBName) + + // setup join table foreign keys for association + if len(associationJoinTableDBNames) > idx { + relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationJoinTableDBNames[idx]) + } else { + // join table foreign keys for association + joinTableDBName := ToDBName(elemType.Name()) + "_" + field.DBName + relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, joinTableDBName) + } + } } } @@ -399,11 +429,13 @@ func (scope *Scope) GetModelStruct() *ModelStruct { ) if foreignKey := field.TagSettings["FOREIGNKEY"]; foreignKey != "" { - tagForeignKeys = strings.Split(field.TagSettings["FOREIGNKEY"], ",") + tagForeignKeys = strings.Split(foreignKey, ",") } - if foreignKey := field.TagSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" { - tagAssociationForeignKeys = strings.Split(field.TagSettings["ASSOCIATIONFOREIGNKEY"], ",") + if foreignKey := field.TagSettings["ASSOCIATION_FOREIGNKEY"]; foreignKey != "" { + tagAssociationForeignKeys = strings.Split(foreignKey, ",") + } else if foreignKey := field.TagSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" { + tagAssociationForeignKeys = strings.Split(foreignKey, ",") } if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" { diff --git a/preload_test.go b/preload_test.go index 1b89e77be..311ad0be1 100644 --- a/preload_test.go +++ b/preload_test.go @@ -1627,6 +1627,48 @@ func TestPrefixedPreloadDuplication(t *testing.T) { } } +func TestPreloadManyToManyCallbacks(t *testing.T) { + type ( + Level2 struct { + ID uint + Name string + } + Level1 struct { + ID uint + Name string + Level2s []Level2 `gorm:"many2many:level1_level2s;AssociationForeignKey:ID;ForeignKey:ID"` + } + ) + + DB.DropTableIfExists("level1_level2s") + DB.DropTableIfExists(new(Level1)) + DB.DropTableIfExists(new(Level2)) + + if err := DB.AutoMigrate(new(Level1), new(Level2)).Error; err != nil { + t.Error(err) + } + + lvl := Level1{ + Name: "l1", + Level2s: []Level2{ + Level2{Name: "l2-1"}, Level2{Name: "l2-2"}, + }, + } + DB.Save(&lvl) + + called := 0 + + DB.Callback().Query().After("gorm:query").Register("TestPreloadManyToManyCallbacks", func(scope *gorm.Scope) { + called = called + 1 + }) + + DB.Preload("Level2s").First(&Level1{}, "id = ?", lvl.ID) + + if called != 3 { + t.Errorf("Wanted callback to be called 3 times but got %d", called) + } +} + func toJSONString(v interface{}) []byte { r, _ := json.MarshalIndent(v, "", " ") return r diff --git a/query_test.go b/query_test.go index def84e044..fac7d4d89 100644 --- a/query_test.go +++ b/query_test.go @@ -87,6 +87,37 @@ func TestUIntPrimaryKey(t *testing.T) { } } +func TestCustomizedTypePrimaryKey(t *testing.T) { + type ID uint + type CustomizedTypePrimaryKey struct { + ID ID + Name string + } + + DB.AutoMigrate(&CustomizedTypePrimaryKey{}) + + p1 := CustomizedTypePrimaryKey{Name: "p1"} + p2 := CustomizedTypePrimaryKey{Name: "p2"} + p3 := CustomizedTypePrimaryKey{Name: "p3"} + DB.Create(&p1) + DB.Create(&p2) + DB.Create(&p3) + + var p CustomizedTypePrimaryKey + + if err := DB.First(&p, p2.ID).Error; err == nil { + t.Errorf("Should return error for invalid query condition") + } + + if err := DB.First(&p, "id = ?", p2.ID).Error; err != nil { + t.Errorf("No error should happen when querying with customized type for primary key, got err %v", err) + } + + if p.Name != "p2" { + t.Errorf("Should find correct value when querying with customized type for primary key") + } +} + func TestStringPrimaryKeyForNumericValueStartingWithZero(t *testing.T) { type AddressByZipCode struct { ZipCode string `gorm:"primary_key"` @@ -99,7 +130,7 @@ func TestStringPrimaryKeyForNumericValueStartingWithZero(t *testing.T) { var address AddressByZipCode DB.First(&address, "00501") if address.ZipCode != "00501" { - t.Errorf("Fetch a record from with a string primary key for a numeric value starting with zero should work, but failed") + t.Errorf("Fetch a record from with a string primary key for a numeric value starting with zero should work, but failed, zip code is %v", address.ZipCode) } } @@ -191,6 +222,36 @@ func TestSearchWithPlainSQL(t *testing.T) { } } +func TestSearchWithTwoDimensionalArray(t *testing.T) { + var users []User + user1 := User{Name: "2DSearchUser1", Age: 1, Birthday: parseTime("2000-1-1")} + user2 := User{Name: "2DSearchUser2", Age: 10, Birthday: parseTime("2010-1-1")} + user3 := User{Name: "2DSearchUser3", Age: 20, Birthday: parseTime("2020-1-1")} + DB.Create(&user1) + DB.Create(&user2) + DB.Create(&user3) + + if dialect := DB.Dialect().GetName(); dialect == "mysql" || dialect == "postgres" { + if err := DB.Where("(name, age) IN (?)", [][]interface{}{{"2DSearchUser1", 1}, {"2DSearchUser2", 10}}).Find(&users).Error; err != nil { + t.Errorf("No error should happen when query with 2D array, but got %v", err) + + if len(users) != 2 { + t.Errorf("Should find 2 users with 2D array, but got %v", len(users)) + } + } + } + + if dialect := DB.Dialect().GetName(); dialect == "mssql" { + if err := DB.Joins("JOIN (VALUES ?) AS x (col1, col2) ON x.col1 = name AND x.col2 = age", [][]interface{}{{"2DSearchUser1", 1}, {"2DSearchUser2", 10}}).Find(&users).Error; err != nil { + t.Errorf("No error should happen when query with 2D array, but got %v", err) + + if len(users) != 2 { + t.Errorf("Should find 2 users with 2D array, but got %v", len(users)) + } + } + } +} + func TestSearchWithStruct(t *testing.T) { user1 := User{Name: "StructSearchUser1", Age: 1, Birthday: parseTime("2000-1-1")} user2 := User{Name: "StructSearchUser2", Age: 10, Birthday: parseTime("2010-1-1")} @@ -389,7 +450,7 @@ func TestOffset(t *testing.T) { DB.Save(&User{Name: fmt.Sprintf("OffsetUser%v", i)}) } var users1, users2, users3, users4 []User - DB.Limit(100).Order("age desc").Find(&users1).Offset(3).Find(&users2).Offset(5).Find(&users3).Offset(-1).Find(&users4) + DB.Limit(100).Where("name like ?", "OffsetUser%").Order("age desc").Find(&users1).Offset(3).Find(&users2).Offset(5).Find(&users3).Offset(-1).Find(&users4) if (len(users1) != len(users4)) || (len(users1)-len(users2) != 3) || (len(users1)-len(users3) != 5) { t.Errorf("Offset should work") @@ -430,6 +491,15 @@ func TestCount(t *testing.T) { if count1 != 1 || count2 != 3 { t.Errorf("Multiple count in chain") } + + var count3 int + if err := DB.Model(&User{}).Where("name in (?)", []string{user2.Name, user2.Name, user3.Name}).Group("id").Count(&count3).Error; err != nil { + t.Errorf("Not error should happen, but got %v", err) + } + + if count3 != 2 { + t.Errorf("Should get correct count, but got %v", count3) + } } func TestNot(t *testing.T) { @@ -665,3 +735,39 @@ func TestSelectWithArrayInput(t *testing.T) { t.Errorf("Should have selected both age and name") } } + +func TestPluckWithSelect(t *testing.T) { + var ( + user = User{Name: "matematik7_pluck_with_select", Age: 25} + combinedName = fmt.Sprintf("%v%v", user.Name, user.Age) + combineUserAgeSQL = fmt.Sprintf("concat(%v, %v)", DB.Dialect().Quote("name"), DB.Dialect().Quote("age")) + ) + + if dialect := DB.Dialect().GetName(); dialect == "sqlite3" { + combineUserAgeSQL = fmt.Sprintf("(%v || %v)", DB.Dialect().Quote("name"), DB.Dialect().Quote("age")) + } + + DB.Save(&user) + + selectStr := combineUserAgeSQL + " as user_age" + var userAges []string + err := DB.Model(&User{}).Where("age = ?", 25).Select(selectStr).Pluck("user_age", &userAges).Error + if err != nil { + t.Error(err) + } + + if len(userAges) != 1 || userAges[0] != combinedName { + t.Errorf("Should correctly pluck with select, got: %s", userAges) + } + + selectStr = combineUserAgeSQL + fmt.Sprintf(" as %v", DB.Dialect().Quote("user_age")) + userAges = userAges[:0] + err = DB.Model(&User{}).Where("age = ?", 25).Select(selectStr).Pluck("user_age", &userAges).Error + if err != nil { + t.Error(err) + } + + if len(userAges) != 1 || userAges[0] != combinedName { + t.Errorf("Should correctly pluck with select, got: %s", userAges) + } +} diff --git a/scope.go b/scope.go index c447d8a0a..150ac7100 100644 --- a/scope.go +++ b/scope.go @@ -1,16 +1,15 @@ package gorm import ( + "bytes" "database/sql" "database/sql/driver" "errors" "fmt" + "reflect" "regexp" - "strconv" "strings" "time" - - "reflect" ) // Scope contain current operation's information when you perform any operation on the database @@ -116,6 +115,9 @@ func (scope *Scope) Fields() []*Field { if isStruct { fieldValue := indirectScopeValue for _, name := range structField.Names { + if fieldValue.Kind() == reflect.Ptr && fieldValue.IsNil() { + fieldValue.Set(reflect.New(fieldValue.Type().Elem())) + } fieldValue = reflect.Indirect(fieldValue).FieldByName(name) } fields = append(fields, &Field{StructField: structField, Field: fieldValue, IsBlank: isBlank(fieldValue)}) @@ -460,7 +462,7 @@ func (scope *Scope) callMethod(methodName string, reflectValue reflect.Value) { var ( columnRegexp = regexp.MustCompile("^[a-zA-Z\\d]+(\\.[a-zA-Z\\d]+)*$") // only match string like `name`, `users.name` isNumberRegexp = regexp.MustCompile("^\\s*\\d+\\s*$") // match if string is number - comparisonRegexp = regexp.MustCompile("(?i) (=|<>|>|<|LIKE|IS|IN) ") + comparisonRegexp = regexp.MustCompile("(?i) (=|<>|(>|<)(=?)|LIKE|IS|IN) ") countingQueryRegexp = regexp.MustCompile("(?i)^count(.+)$") ) @@ -521,134 +523,143 @@ func (scope *Scope) primaryCondition(value interface{}) string { return fmt.Sprintf("(%v.%v = %v)", scope.QuotedTableName(), scope.Quote(scope.PrimaryKey()), value) } -func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str string) { +func (scope *Scope) buildCondition(clause map[string]interface{}, include bool) (str string) { + var ( + quotedTableName = scope.QuotedTableName() + quotedPrimaryKey = scope.Quote(scope.PrimaryKey()) + equalSQL = "=" + inSQL = "IN" + ) + + // If building not conditions + if !include { + equalSQL = "<>" + inSQL = "NOT IN" + } + switch value := clause["query"].(type) { + case sql.NullInt64: + return fmt.Sprintf("(%v.%v %s %v)", quotedTableName, quotedPrimaryKey, equalSQL, value.Int64) + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: + return fmt.Sprintf("(%v.%v %s %v)", quotedTableName, quotedPrimaryKey, equalSQL, value) + case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string, []interface{}: + if !include && reflect.ValueOf(value).Len() == 0 { + return + } + str = fmt.Sprintf("(%v.%v %s (?))", quotedTableName, quotedPrimaryKey, inSQL) + clause["args"] = []interface{}{value} case string: if isNumberRegexp.MatchString(value) { - return scope.primaryCondition(scope.AddToVars(value)) - } else if value != "" { - str = fmt.Sprintf("(%v)", value) + return fmt.Sprintf("(%v.%v %s %v)", quotedTableName, quotedPrimaryKey, equalSQL, scope.AddToVars(value)) + } + + if value != "" { + if !include { + if comparisonRegexp.MatchString(value) { + str = fmt.Sprintf("NOT (%v)", value) + } else { + str = fmt.Sprintf("(%v.%v NOT IN (?))", quotedTableName, scope.Quote(value)) + } + } else { + str = fmt.Sprintf("(%v)", value) + } } - case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, sql.NullInt64: - return scope.primaryCondition(scope.AddToVars(value)) - case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string, []interface{}: - str = fmt.Sprintf("(%v.%v IN (?))", scope.QuotedTableName(), scope.Quote(scope.PrimaryKey())) - clause["args"] = []interface{}{value} case map[string]interface{}: var sqls []string for key, value := range value { if value != nil { - sqls = append(sqls, fmt.Sprintf("(%v.%v = %v)", scope.QuotedTableName(), scope.Quote(key), scope.AddToVars(value))) + sqls = append(sqls, fmt.Sprintf("(%v.%v %s %v)", quotedTableName, scope.Quote(key), equalSQL, scope.AddToVars(value))) } else { - sqls = append(sqls, fmt.Sprintf("(%v.%v IS NULL)", scope.QuotedTableName(), scope.Quote(key))) + if !include { + sqls = append(sqls, fmt.Sprintf("(%v.%v IS NOT NULL)", quotedTableName, scope.Quote(key))) + } else { + sqls = append(sqls, fmt.Sprintf("(%v.%v IS NULL)", quotedTableName, scope.Quote(key))) + } } } return strings.Join(sqls, " AND ") case interface{}: var sqls []string newScope := scope.New(value) + + if len(newScope.Fields()) == 0 { + scope.Err(fmt.Errorf("invalid query condition: %v", value)) + return + } + for _, field := range newScope.Fields() { if !field.IsIgnored && !field.IsBlank { - sqls = append(sqls, fmt.Sprintf("(%v.%v = %v)", scope.QuotedTableName(), scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) + sqls = append(sqls, fmt.Sprintf("(%v.%v %s %v)", quotedTableName, scope.Quote(field.DBName), equalSQL, scope.AddToVars(field.Field.Interface()))) } } return strings.Join(sqls, " AND ") + default: + scope.Err(fmt.Errorf("invalid query condition: %v", value)) + return } + replacements := []string{} args := clause["args"].([]interface{}) for _, arg := range args { + var err error switch reflect.ValueOf(arg).Kind() { case reflect.Slice: // For where("id in (?)", []int64{1,2}) - if bytes, ok := arg.([]byte); ok { - str = strings.Replace(str, "?", scope.AddToVars(bytes), 1) + if scanner, ok := interface{}(arg).(driver.Valuer); ok { + arg, err = scanner.Value() + replacements = append(replacements, scope.AddToVars(arg)) + } else if b, ok := arg.([]byte); ok { + replacements = append(replacements, scope.AddToVars(b)) + } else if as, ok := arg.([][]interface{}); ok { + var tempMarks []string + for _, a := range as { + var arrayMarks []string + for _, v := range a { + arrayMarks = append(arrayMarks, scope.AddToVars(v)) + } + + if len(arrayMarks) > 0 { + tempMarks = append(tempMarks, fmt.Sprintf("(%v)", strings.Join(arrayMarks, ","))) + } + } + + if len(tempMarks) > 0 { + replacements = append(replacements, strings.Join(tempMarks, ",")) + } } else if values := reflect.ValueOf(arg); values.Len() > 0 { var tempMarks []string for i := 0; i < values.Len(); i++ { tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface())) } - str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1) + replacements = append(replacements, strings.Join(tempMarks, ",")) } else { - str = strings.Replace(str, "?", scope.AddToVars(Expr("NULL")), 1) + replacements = append(replacements, scope.AddToVars(Expr("NULL"))) } default: if valuer, ok := interface{}(arg).(driver.Valuer); ok { - arg, _ = valuer.Value() + arg, err = valuer.Value() } - str = strings.Replace(str, "?", scope.AddToVars(arg), 1) + replacements = append(replacements, scope.AddToVars(arg)) } - } - return -} -func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string) { - var notEqualSQL string - var primaryKey = scope.PrimaryKey() - - switch value := clause["query"].(type) { - case string: - if isNumberRegexp.MatchString(value) { - id, _ := strconv.Atoi(value) - return fmt.Sprintf("(%v <> %v)", scope.Quote(primaryKey), id) - } else if comparisonRegexp.MatchString(value) { - str = fmt.Sprintf(" NOT (%v) ", value) - notEqualSQL = fmt.Sprintf("NOT (%v)", value) - } else { - str = fmt.Sprintf("(%v.%v NOT IN (?))", scope.QuotedTableName(), scope.Quote(value)) - notEqualSQL = fmt.Sprintf("(%v.%v <> ?)", scope.QuotedTableName(), scope.Quote(value)) - } - case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, sql.NullInt64: - return fmt.Sprintf("(%v.%v <> %v)", scope.QuotedTableName(), scope.Quote(primaryKey), value) - case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string: - if reflect.ValueOf(value).Len() > 0 { - str = fmt.Sprintf("(%v.%v NOT IN (?))", scope.QuotedTableName(), scope.Quote(primaryKey)) - clause["args"] = []interface{}{value} - } else { - return "" - } - case map[string]interface{}: - var sqls []string - for key, value := range value { - if value != nil { - sqls = append(sqls, fmt.Sprintf("(%v.%v <> %v)", scope.QuotedTableName(), scope.Quote(key), scope.AddToVars(value))) - } else { - sqls = append(sqls, fmt.Sprintf("(%v.%v IS NOT NULL)", scope.QuotedTableName(), scope.Quote(key))) - } - } - return strings.Join(sqls, " AND ") - case interface{}: - var sqls []string - var newScope = scope.New(value) - for _, field := range newScope.Fields() { - if !field.IsBlank { - sqls = append(sqls, fmt.Sprintf("(%v.%v <> %v)", scope.QuotedTableName(), scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) - } + if err != nil { + scope.Err(err) } - return strings.Join(sqls, " AND ") } - args := clause["args"].([]interface{}) - for _, arg := range args { - switch reflect.ValueOf(arg).Kind() { - case reflect.Slice: // For where("id in (?)", []int64{1,2}) - if bytes, ok := arg.([]byte); ok { - str = strings.Replace(str, "?", scope.AddToVars(bytes), 1) - } else if values := reflect.ValueOf(arg); values.Len() > 0 { - var tempMarks []string - for i := 0; i < values.Len(); i++ { - tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface())) - } - str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1) - } else { - str = strings.Replace(str, "?", scope.AddToVars(Expr("NULL")), 1) - } - default: - if scanner, ok := interface{}(arg).(driver.Valuer); ok { - arg, _ = scanner.Value() - } - str = strings.Replace(notEqualSQL, "?", scope.AddToVars(arg), 1) + buff := bytes.NewBuffer([]byte{}) + i := 0 + for _, s := range str { + if s == '?' && len(replacements) > i { + buff.WriteString(replacements[i]) + i++ + } else { + buff.WriteRune(s) } } + + str = buff.String() + return } @@ -661,6 +672,7 @@ func (scope *Scope) buildSelectQuery(clause map[string]interface{}) (str string) } args := clause["args"].([]interface{}) + replacements := []string{} for _, arg := range args { switch reflect.ValueOf(arg).Kind() { case reflect.Slice: @@ -669,14 +681,28 @@ func (scope *Scope) buildSelectQuery(clause map[string]interface{}) (str string) for i := 0; i < values.Len(); i++ { tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface())) } - str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1) + replacements = append(replacements, strings.Join(tempMarks, ",")) default: if valuer, ok := interface{}(arg).(driver.Valuer); ok { arg, _ = valuer.Value() } - str = strings.Replace(str, "?", scope.AddToVars(arg), 1) + replacements = append(replacements, scope.AddToVars(arg)) + } + } + + buff := bytes.NewBuffer([]byte{}) + i := 0 + for pos := range str { + if str[pos] == '?' { + buff.WriteString(replacements[i]) + i++ + } else { + buff.WriteByte(str[pos]) } } + + str = buff.String() + return } @@ -700,19 +726,19 @@ func (scope *Scope) whereSQL() (sql string) { } for _, clause := range scope.Search.whereConditions { - if sql := scope.buildWhereCondition(clause); sql != "" { + if sql := scope.buildCondition(clause, true); sql != "" { andConditions = append(andConditions, sql) } } for _, clause := range scope.Search.orConditions { - if sql := scope.buildWhereCondition(clause); sql != "" { + if sql := scope.buildCondition(clause, true); sql != "" { orConditions = append(orConditions, sql) } } for _, clause := range scope.Search.notConditions { - if sql := scope.buildNotCondition(clause); sql != "" { + if sql := scope.buildCondition(clause, false); sql != "" { andConditions = append(andConditions, sql) } } @@ -786,7 +812,7 @@ func (scope *Scope) havingSQL() string { var andConditions []string for _, clause := range scope.Search.havingConditions { - if sql := scope.buildWhereCondition(clause); sql != "" { + if sql := scope.buildCondition(clause, true); sql != "" { andConditions = append(andConditions, sql) } } @@ -802,7 +828,7 @@ func (scope *Scope) havingSQL() string { func (scope *Scope) joinsSQL() string { var joinConditions []string for _, clause := range scope.Search.joinConditions { - if sql := scope.buildWhereCondition(clause); sql != "" { + if sql := scope.buildCondition(clause, true); sql != "" { joinConditions = append(joinConditions, strings.TrimSuffix(strings.TrimPrefix(sql, "("), ")")) } } @@ -924,14 +950,34 @@ func (scope *Scope) initialize() *Scope { return scope } +func (scope *Scope) isQueryForColumn(query interface{}, column string) bool { + queryStr := strings.ToLower(fmt.Sprint(query)) + if queryStr == column { + return true + } + + if strings.HasSuffix(queryStr, "as "+column) { + return true + } + + if strings.HasSuffix(queryStr, "as "+scope.Quote(column)) { + return true + } + + return false +} + func (scope *Scope) pluck(column string, value interface{}) *Scope { dest := reflect.Indirect(reflect.ValueOf(value)) - scope.Search.Select(column) if dest.Kind() != reflect.Slice { scope.Err(fmt.Errorf("results should be a slice, not %s", dest.Kind())) return scope } + if query, ok := scope.Search.selects["query"]; !ok || !scope.isQueryForColumn(query, column) { + scope.Search.Select(column) + } + rows, err := scope.rows() if scope.Err(err) == nil { defer rows.Close() @@ -950,7 +996,12 @@ func (scope *Scope) pluck(column string, value interface{}) *Scope { func (scope *Scope) count(value interface{}) *Scope { if query, ok := scope.Search.selects["query"]; !ok || !countingQueryRegexp.MatchString(fmt.Sprint(query)) { - scope.Search.Select("count(*)") + if len(scope.Search.group) != 0 { + scope.Search.Select("count(*) FROM ( SELECT count(*) as name ") + scope.Search.group += " ) AS count_table" + } else { + scope.Search.Select("count(*)") + } } scope.Search.ignoreOrderQuery = true scope.Err(scope.row().Scan(value)) @@ -993,18 +1044,6 @@ func (scope *Scope) changeableField(field *Field) bool { return true } -func (scope *Scope) shouldSaveAssociations() bool { - if saveAssociations, ok := scope.Get("gorm:save_associations"); ok { - if v, ok := saveAssociations.(bool); ok && !v { - return false - } - if v, ok := saveAssociations.(string); ok && (v != "skip") { - return false - } - } - return true && !scope.HasError() -} - func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope { toScope := scope.db.NewScope(value) tx := scope.db.Set("gorm:association:source", scope.Value) @@ -1059,7 +1098,7 @@ func (scope *Scope) getTableOptions() string { if !ok { return "" } - return tableOptions.(string) + return " " + tableOptions.(string) } func (scope *Scope) createJoinTable(field *StructField) { @@ -1092,7 +1131,7 @@ func (scope *Scope) createJoinTable(field *StructField) { } } - scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v, PRIMARY KEY (%v)) %s", scope.Quote(joinTable), strings.Join(sqlTypes, ","), strings.Join(primaryKeys, ","), scope.getTableOptions())).Error) + scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v, PRIMARY KEY (%v))%s", scope.Quote(joinTable), strings.Join(sqlTypes, ","), strings.Join(primaryKeys, ","), scope.getTableOptions())).Error) } scope.NewDB().Table(joinTable).AutoMigrate(joinTableHandler) } @@ -1127,19 +1166,19 @@ func (scope *Scope) createTable() *Scope { primaryKeyStr = fmt.Sprintf(", PRIMARY KEY (%v)", strings.Join(primaryKeys, ",")) } - scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v %v) %s", scope.QuotedTableName(), strings.Join(tags, ","), primaryKeyStr, scope.getTableOptions())).Exec() + scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v %v)%s", scope.QuotedTableName(), strings.Join(tags, ","), primaryKeyStr, scope.getTableOptions())).Exec() scope.autoIndex() return scope } func (scope *Scope) dropTable() *Scope { - scope.Raw(fmt.Sprintf("DROP TABLE %v", scope.QuotedTableName())).Exec() + scope.Raw(fmt.Sprintf("DROP TABLE %v%s", scope.QuotedTableName(), scope.getTableOptions())).Exec() return scope } func (scope *Scope) modifyColumn(column string, typ string) { - scope.Raw(fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v TYPE %v", scope.QuotedTableName(), scope.Quote(column), typ)).Exec() + scope.db.AddError(scope.Dialect().ModifyColumn(scope.QuotedTableName(), scope.Quote(column), typ)) } func (scope *Scope) dropColumn(column string) { @@ -1175,6 +1214,16 @@ func (scope *Scope) addForeignKey(field string, dest string, onDelete string, on scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.quoteIfPossible(keyName), scope.quoteIfPossible(field), dest, onDelete, onUpdate)).Exec() } +func (scope *Scope) removeForeignKey(field string, dest string) { + keyName := scope.Dialect().BuildKeyName(scope.TableName(), field, dest) + + if !scope.Dialect().HasForeignKey(scope.TableName(), keyName) { + return + } + var query = `ALTER TABLE %s DROP CONSTRAINT %s;` + scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.quoteIfPossible(keyName))).Exec() +} + func (scope *Scope) removeIndex(indexName string) { scope.Dialect().RemoveIndex(scope.TableName(), indexName) } @@ -1229,13 +1278,13 @@ func (scope *Scope) autoIndex() *Scope { } for name, columns := range indexes { - if db := scope.NewDB().Model(scope.Value).AddIndex(name, columns...); db.Error != nil { + if db := scope.NewDB().Table(scope.TableName()).Model(scope.Value).AddIndex(name, columns...); db.Error != nil { scope.db.AddError(db.Error) } } for name, columns := range uniqueIndexes { - if db := scope.NewDB().Model(scope.Value).AddUniqueIndex(name, columns...); db.Error != nil { + if db := scope.NewDB().Table(scope.TableName()).Model(scope.Value).AddUniqueIndex(name, columns...); db.Error != nil { scope.db.AddError(db.Error) } } diff --git a/scope_test.go b/scope_test.go index 42458995d..3018f3509 100644 --- a/scope_test.go +++ b/scope_test.go @@ -1,8 +1,12 @@ package gorm_test import ( - "github.com/jinzhu/gorm" + "encoding/hex" + "math/rand" + "strings" "testing" + + "github.com/jinzhu/gorm" ) func NameIn1And2(d *gorm.DB) *gorm.DB { @@ -41,3 +45,36 @@ func TestScopes(t *testing.T) { t.Errorf("Should found two users's name in 1, 3") } } + +func randName() string { + data := make([]byte, 8) + rand.Read(data) + + return "n-" + hex.EncodeToString(data) +} + +func TestValuer(t *testing.T) { + name := randName() + + origUser := User{Name: name, Age: 1, Password: EncryptedData("pass1"), PasswordHash: []byte("abc")} + if err := DB.Save(&origUser).Error; err != nil { + t.Errorf("No error should happen when saving user, but got %v", err) + } + + var user2 User + if err := DB.Where("name = ? AND password = ? AND password_hash = ?", name, EncryptedData("pass1"), []byte("abc")).First(&user2).Error; err != nil { + t.Errorf("No error should happen when querying user with valuer, but got %v", err) + } +} + +func TestFailedValuer(t *testing.T) { + name := randName() + + err := DB.Exec("INSERT INTO users(name, password) VALUES(?, ?)", name, EncryptedData("xpass1")).Error + + if err == nil { + t.Errorf("There should be an error should happen when insert data") + } else if !strings.HasPrefix(err.Error(), "Should not start with") { + t.Errorf("The error should be returned from Valuer, but get %v", err) + } +} diff --git a/test_all.sh b/test_all.sh index 80b319bf0..5cfb3321a 100755 --- a/test_all.sh +++ b/test_all.sh @@ -1,5 +1,5 @@ dialects=("postgres" "mysql" "mssql" "sqlite") for dialect in "${dialects[@]}" ; do - GORM_DIALECT=${dialect} go test + DEBUG=false GORM_DIALECT=${dialect} go test done diff --git a/wercker.yml b/wercker.yml index ff6fb17cc..0c3e73ef1 100644 --- a/wercker.yml +++ b/wercker.yml @@ -2,17 +2,72 @@ box: golang services: - - id: mariadb:10.0 + - name: mariadb + id: mariadb:latest env: MYSQL_DATABASE: gorm MYSQL_USER: gorm MYSQL_PASSWORD: gorm MYSQL_RANDOM_ROOT_PASSWORD: "yes" - - id: postgres + - name: mysql57 + id: mysql:5.7 + env: + MYSQL_DATABASE: gorm + MYSQL_USER: gorm + MYSQL_PASSWORD: gorm + MYSQL_RANDOM_ROOT_PASSWORD: "yes" + - name: mysql56 + id: mysql:5.6 + env: + MYSQL_DATABASE: gorm + MYSQL_USER: gorm + MYSQL_PASSWORD: gorm + MYSQL_RANDOM_ROOT_PASSWORD: "yes" + - name: mysql55 + id: mysql:5.5 + env: + MYSQL_DATABASE: gorm + MYSQL_USER: gorm + MYSQL_PASSWORD: gorm + MYSQL_RANDOM_ROOT_PASSWORD: "yes" + - name: postgres + id: postgres:latest + env: + POSTGRES_USER: gorm + POSTGRES_PASSWORD: gorm + POSTGRES_DB: gorm + - name: postgres96 + id: postgres:9.6 + env: + POSTGRES_USER: gorm + POSTGRES_PASSWORD: gorm + POSTGRES_DB: gorm + - name: postgres95 + id: postgres:9.5 + env: + POSTGRES_USER: gorm + POSTGRES_PASSWORD: gorm + POSTGRES_DB: gorm + - name: postgres94 + id: postgres:9.4 + env: + POSTGRES_USER: gorm + POSTGRES_PASSWORD: gorm + POSTGRES_DB: gorm + - name: postgres93 + id: postgres:9.3 env: POSTGRES_USER: gorm POSTGRES_PASSWORD: gorm POSTGRES_DB: gorm + - name: mssql + id: mcmoe/mssqldocker:latest + env: + ACCEPT_EULA: Y + SA_PASSWORD: LoremIpsum86 + MSSQL_DB: gorm + MSSQL_USER: gorm + MSSQL_PASSWORD: LoremIpsum86 # The steps that will be executed in the build pipeline build: @@ -43,11 +98,51 @@ build: go test ./... - script: - name: test mysql + name: test mariadb + code: | + GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mariadb:3306)/gorm?charset=utf8&parseTime=True" go test ./... + + - script: + name: test mysql5.7 code: | - GORM_DIALECT=mysql GORM_DBADDRESS=mariadb:3306 go test ./... + GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql57:3306)/gorm?charset=utf8&parseTime=True" go test ./... + + - script: + name: test mysql5.6 + code: | + GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql56:3306)/gorm?charset=utf8&parseTime=True" go test ./... + + - script: + name: test mysql5.5 + code: | + GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql55:3306)/gorm?charset=utf8&parseTime=True" go test ./... - script: name: test postgres code: | - GORM_DIALECT=postgres GORM_DBHOST=postgres go test ./... + GORM_DIALECT=postgres GORM_DSN="host=postgres user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test ./... + + - script: + name: test postgres96 + code: | + GORM_DIALECT=postgres GORM_DSN="host=postgres96 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test ./... + + - script: + name: test postgres95 + code: | + GORM_DIALECT=postgres GORM_DSN="host=postgres95 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test ./... + + - script: + name: test postgres94 + code: | + GORM_DIALECT=postgres GORM_DSN="host=postgres94 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test ./... + + - script: + name: test postgres93 + code: | + GORM_DIALECT=postgres GORM_DSN="host=postgres93 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test ./... + + - script: + name: test mssql + code: | + GORM_DIALECT=mssql GORM_DSN="sqlserver://gorm:LoremIpsum86@mssql:1433?database=gorm" go test ./...