Skip to content

Commit

Permalink
Merge branch 'jinzhu/master'
Browse files Browse the repository at this point in the history
  • Loading branch information
hb-chen committed Nov 21, 2018
2 parents 3683d88 + 472c70c commit 3f51bb1
Show file tree
Hide file tree
Showing 27 changed files with 730 additions and 252 deletions.
8 changes: 5 additions & 3 deletions association.go
Original file line number Diff line number Diff line change
Expand Up @@ -267,15 +267,16 @@ func (association *Association) Count() int {
query = scope.DB()
)

if relationship.Kind == "many_to_many" {
switch relationship.Kind {
case "many_to_many":
query = relationship.JoinTableHandler.JoinWith(relationship.JoinTableHandler, query, scope.Value)
} else if relationship.Kind == "has_many" || relationship.Kind == "has_one" {
case "has_many", "has_one":
primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, scope.Value)
query = query.Where(
fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)),
toQueryValues(primaryKeys)...,
)
} else if relationship.Kind == "belongs_to" {
case "belongs_to":
primaryKeys := scope.getColumnAsArray(relationship.ForeignFieldNames, scope.Value)
query = query.Where(
fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(primaryKeys)),
Expand Down Expand Up @@ -367,6 +368,7 @@ func (association *Association) saveAssociations(values ...interface{}) *Associa
return association
}

// setErr set error when the error is not nil. And return Association.
func (association *Association) setErr(err error) *Association {
if err != nil {
association.Error = err
Expand Down
2 changes: 1 addition & 1 deletion callback_create.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func createCallback(scope *Scope) {

for _, field := range scope.Fields() {
if scope.changeableField(field) {
if field.IsNormal {
if field.IsNormal && !field.IsIgnored {
if field.IsBlank && field.HasDefaultValue {
blankColumnsWithDefaultValue = append(blankColumnsWithDefaultValue, scope.Quote(field.DBName))
scope.InstanceSet("gorm:blank_columns_with_default_value", blankColumnsWithDefaultValue)
Expand Down
5 changes: 5 additions & 0 deletions callback_query.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@ func queryCallback(scope *Scope) {
if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip {
return
}

//we are only preloading relations, dont touch base model
if _, skip := scope.InstanceGet("gorm:only_preload"); skip {
return
}

defer scope.trace(NowFunc())

Expand Down
45 changes: 31 additions & 14 deletions callback_query_preload.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,14 @@ func preloadCallback(scope *Scope) {
return
}

if _, ok := scope.Get("gorm:auto_preload"); ok {
autoPreload(scope)
if ap, ok := scope.Get("gorm:auto_preload"); ok {
// If gorm:auto_preload IS NOT a bool then auto preload.
// Else if it IS a bool, use the value
if apb, ok := ap.(bool); !ok {
autoPreload(scope)
} else if apb {
autoPreload(scope)
}
}

if scope.Search.preload == nil || scope.HasError() {
Expand Down Expand Up @@ -94,7 +100,7 @@ func autoPreload(scope *Scope) {
continue
}

if val, ok := field.TagSettings["PRELOAD"]; ok {
if val, ok := field.TagSettingsGet("PRELOAD"); ok {
if preload, err := strconv.ParseBool(val); err != nil {
scope.Err(errors.New("invalid preload option"))
return
Expand Down Expand Up @@ -155,14 +161,17 @@ func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{})
)

if indirectScopeValue.Kind() == reflect.Slice {
foreignValuesToResults := make(map[string]reflect.Value)
for i := 0; i < resultsValue.Len(); i++ {
result := resultsValue.Index(i)
foreignValues := toString(getValueFromFields(result, relation.ForeignFieldNames))
foreignValuesToResults[foreignValues] = result
}
for j := 0; j < indirectScopeValue.Len(); j++ {
for i := 0; i < resultsValue.Len(); i++ {
result := resultsValue.Index(i)
foreignValues := getValueFromFields(result, relation.ForeignFieldNames)
if indirectValue := indirect(indirectScopeValue.Index(j)); equalAsString(getValueFromFields(indirectValue, relation.AssociationForeignFieldNames), foreignValues) {
indirectValue.FieldByName(field.Name).Set(result)
break
}
indirectValue := indirect(indirectScopeValue.Index(j))
valueString := toString(getValueFromFields(indirectValue, relation.AssociationForeignFieldNames))
if result, found := foreignValuesToResults[valueString]; found {
indirectValue.FieldByName(field.Name).Set(result)
}
}
} else {
Expand Down Expand Up @@ -249,13 +258,21 @@ func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{
indirectScopeValue = scope.IndirectValue()
)

foreignFieldToObjects := make(map[string][]*reflect.Value)
if indirectScopeValue.Kind() == reflect.Slice {
for j := 0; j < indirectScopeValue.Len(); j++ {
object := indirect(indirectScopeValue.Index(j))
valueString := toString(getValueFromFields(object, relation.ForeignFieldNames))
foreignFieldToObjects[valueString] = append(foreignFieldToObjects[valueString], &object)
}
}

for i := 0; i < resultsValue.Len(); i++ {
result := resultsValue.Index(i)
if indirectScopeValue.Kind() == reflect.Slice {
value := getValueFromFields(result, relation.AssociationForeignFieldNames)
for j := 0; j < indirectScopeValue.Len(); j++ {
object := indirect(indirectScopeValue.Index(j))
if equalAsString(getValueFromFields(object, relation.ForeignFieldNames), value) {
valueString := toString(getValueFromFields(result, relation.AssociationForeignFieldNames))
if objects, found := foreignFieldToObjects[valueString]; found {
for _, object := range objects {
object.FieldByName(field.Name).Set(result)
}
}
Expand Down
14 changes: 7 additions & 7 deletions callback_save.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,7 @@ func saveAssociationCheck(scope *Scope, field *Field) (autoUpdate bool, autoCrea

if v, ok := value.(string); ok {
v = strings.ToLower(v)
if v == "false" || v != "skip" {
return false
}
return v == "true"
}

return true
Expand All @@ -36,26 +34,28 @@ func saveAssociationCheck(scope *Scope, field *Field) (autoUpdate bool, autoCrea
if value, ok := scope.Get("gorm:save_associations"); ok {
autoUpdate = checkTruth(value)
autoCreate = autoUpdate
} else if value, ok := field.TagSettings["SAVE_ASSOCIATIONS"]; ok {
saveReference = autoUpdate
} else if value, ok := field.TagSettingsGet("SAVE_ASSOCIATIONS"); ok {
autoUpdate = checkTruth(value)
autoCreate = autoUpdate
saveReference = autoUpdate
}

if value, ok := scope.Get("gorm:association_autoupdate"); ok {
autoUpdate = checkTruth(value)
} else if value, ok := field.TagSettings["ASSOCIATION_AUTOUPDATE"]; ok {
} else if value, ok := field.TagSettingsGet("ASSOCIATION_AUTOUPDATE"); ok {
autoUpdate = checkTruth(value)
}

if value, ok := scope.Get("gorm:association_autocreate"); ok {
autoCreate = checkTruth(value)
} else if value, ok := field.TagSettings["ASSOCIATION_AUTOCREATE"]; ok {
} else if value, ok := field.TagSettingsGet("ASSOCIATION_AUTOCREATE"); ok {
autoCreate = checkTruth(value)
}

if value, ok := scope.Get("gorm:association_save_reference"); ok {
saveReference = checkTruth(value)
} else if value, ok := field.TagSettings["ASSOCIATION_SAVE_REFERENCE"]; ok {
} else if value, ok := field.TagSettingsGet("ASSOCIATION_SAVE_REFERENCE"); ok {
saveReference = checkTruth(value)
}
}
Expand Down
4 changes: 3 additions & 1 deletion callback_update.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,9 @@ func updateCallback(scope *Scope) {
for _, field := range scope.Fields() {
if scope.changeableField(field) {
if !field.IsPrimaryKey && field.IsNormal {
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())))
if !field.IsForeignKey || !field.IsBlank || !field.HasDefaultValue {
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())))
}
} else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" {
for _, foreignKey := range relationship.ForeignDBNames {
if foreignField, ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) {
Expand Down
16 changes: 12 additions & 4 deletions dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,18 @@ func RegisterDialect(name string, dialect Dialect) {
dialectsMap[name] = dialect
}

// GetDialect gets the dialect for the specified dialect name
func GetDialect(name string) (dialect Dialect, ok bool) {
dialect, ok = dialectsMap[name]
return
}

// ParseFieldStructForDialect get field's sql data type
var ParseFieldStructForDialect = func(field *StructField, dialect Dialect) (fieldValue reflect.Value, sqlType string, size int, additionalType string) {
// Get redirected field type
var (
reflectType = field.Struct.Type
dataType = field.TagSettings["TYPE"]
dataType, _ = field.TagSettingsGet("TYPE")
)

for reflectType.Kind() == reflect.Ptr {
Expand Down Expand Up @@ -106,15 +112,17 @@ var ParseFieldStructForDialect = func(field *StructField, dialect Dialect) (fiel
}

// Default Size
if num, ok := field.TagSettings["SIZE"]; ok {
if num, ok := field.TagSettingsGet("SIZE"); ok {
size, _ = strconv.Atoi(num)
} else {
size = 255
}

// Default type from tag setting
additionalType = field.TagSettings["NOT NULL"] + " " + field.TagSettings["UNIQUE"]
if value, ok := field.TagSettings["DEFAULT"]; ok {
notNull, _ := field.TagSettingsGet("NOT NULL")
unique, _ := field.TagSettingsGet("UNIQUE")
additionalType = notNull + " " + unique
if value, ok := field.TagSettingsGet("DEFAULT"); ok {
additionalType = additionalType + " DEFAULT " + value
}

Expand Down
2 changes: 1 addition & 1 deletion dialect_common.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ func (commonDialect) Quote(key string) string {
}

func (s *commonDialect) fieldCanAutoIncrement(field *StructField) bool {
if value, ok := field.TagSettings["AUTO_INCREMENT"]; ok {
if value, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok {
return strings.ToLower(value) != "false"
}
return field.IsPrimaryKey
Expand Down
22 changes: 11 additions & 11 deletions dialect_mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ func (s *mysql) DataTypeOf(field *StructField) string {

// MySQL allows only one auto increment column per table, and it must
// be a KEY column.
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok {
if _, ok = field.TagSettings["INDEX"]; !ok && !field.IsPrimaryKey {
delete(field.TagSettings, "AUTO_INCREMENT")
if _, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok {
if _, ok = field.TagSettingsGet("INDEX"); !ok && !field.IsPrimaryKey {
field.TagSettingsDelete("AUTO_INCREMENT")
}
}

Expand All @@ -45,42 +45,42 @@ func (s *mysql) DataTypeOf(field *StructField) string {
sqlType = "boolean"
case reflect.Int8:
if s.fieldCanAutoIncrement(field) {
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
sqlType = "tinyint AUTO_INCREMENT"
} else {
sqlType = "tinyint"
}
case reflect.Int, reflect.Int16, reflect.Int32:
if s.fieldCanAutoIncrement(field) {
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
sqlType = "int AUTO_INCREMENT"
} else {
sqlType = "int"
}
case reflect.Uint8:
if s.fieldCanAutoIncrement(field) {
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
sqlType = "tinyint unsigned AUTO_INCREMENT"
} else {
sqlType = "tinyint unsigned"
}
case reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
if s.fieldCanAutoIncrement(field) {
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
sqlType = "int unsigned AUTO_INCREMENT"
} else {
sqlType = "int unsigned"
}
case reflect.Int64:
if s.fieldCanAutoIncrement(field) {
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
sqlType = "bigint AUTO_INCREMENT"
} else {
sqlType = "bigint"
}
case reflect.Uint64:
if s.fieldCanAutoIncrement(field) {
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
sqlType = "bigint unsigned AUTO_INCREMENT"
} else {
sqlType = "bigint unsigned"
Expand All @@ -96,11 +96,11 @@ func (s *mysql) DataTypeOf(field *StructField) string {
case reflect.Struct:
if _, ok := dataValue.Interface().(time.Time); ok {
precision := ""
if p, ok := field.TagSettings["PRECISION"]; ok {
if p, ok := field.TagSettingsGet("PRECISION"); ok {
precision = fmt.Sprintf("(%s)", p)
}

if _, ok := field.TagSettings["NOT NULL"]; ok {
if _, ok := field.TagSettingsGet("NOT NULL"); ok {
sqlType = fmt.Sprintf("timestamp%v", precision)
} else {
sqlType = fmt.Sprintf("timestamp%v NULL", precision)
Expand Down
6 changes: 3 additions & 3 deletions dialect_postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,22 +34,22 @@ func (s *postgres) DataTypeOf(field *StructField) string {
sqlType = "boolean"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uintptr:
if s.fieldCanAutoIncrement(field) {
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
sqlType = "serial"
} else {
sqlType = "integer"
}
case reflect.Int64, reflect.Uint32, reflect.Uint64:
if s.fieldCanAutoIncrement(field) {
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
sqlType = "bigserial"
} else {
sqlType = "bigint"
}
case reflect.Float32, reflect.Float64:
sqlType = "numeric"
case reflect.String:
if _, ok := field.TagSettings["SIZE"]; !ok {
if _, ok := field.TagSettingsGet("SIZE"); !ok {
size = 0 // if SIZE haven't been set, use `text` as the default type, as there are no performance different
}

Expand Down
4 changes: 2 additions & 2 deletions dialect_sqlite3.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,14 @@ func (s *sqlite3) DataTypeOf(field *StructField) string {
sqlType = "bool"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
if s.fieldCanAutoIncrement(field) {
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
sqlType = "integer primary key autoincrement"
} else {
sqlType = "integer"
}
case reflect.Int64, reflect.Uint64:
if s.fieldCanAutoIncrement(field) {
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
sqlType = "integer primary key autoincrement"
} else {
sqlType = "bigint"
Expand Down
Loading

0 comments on commit 3f51bb1

Please sign in to comment.