diff --git a/pkg/conf/conf.go b/pkg/conf/conf.go index dc3d3eaf2e..0194812bbe 100644 --- a/pkg/conf/conf.go +++ b/pkg/conf/conf.go @@ -25,10 +25,11 @@ import ( "sync" "time" + "github.com/douyu/jupiter/pkg/util/xcast" "github.com/douyu/jupiter/pkg/util/xmap" "github.com/mitchellh/mapstructure" "github.com/pkg/errors" - xcast "github.com/spf13/cast" + "github.com/spf13/cast" ) // Configuration provides configuration for application. @@ -252,7 +253,7 @@ func GetString(key string) string { // GetString returns the value associated with the key as a string. func (c *Configuration) GetString(key string) string { - return xcast.ToString(c.Get(key)) + return cast.ToString(c.Get(key)) } // GetBool returns the value associated with the key as a boolean with default defaultConfiguration. @@ -262,7 +263,7 @@ func GetBool(key string) bool { // GetBool returns the value associated with the key as a boolean. func (c *Configuration) GetBool(key string) bool { - return xcast.ToBool(c.Get(key)) + return cast.ToBool(c.Get(key)) } // GetInt returns the value associated with the key as an integer with default defaultConfiguration. @@ -272,7 +273,7 @@ func GetInt(key string) int { // GetInt returns the value associated with the key as an integer. func (c *Configuration) GetInt(key string) int { - return xcast.ToInt(c.Get(key)) + return cast.ToInt(c.Get(key)) } // GetInt64 returns the value associated with the key as an integer with default defaultConfiguration. @@ -282,7 +283,17 @@ func GetInt64(key string) int64 { // GetInt64 returns the value associated with the key as an integer. func (c *Configuration) GetInt64(key string) int64 { - return xcast.ToInt64(c.Get(key)) + return cast.ToInt64(c.Get(key)) +} + +// GetInt64Slice returns the value associated with the key as an integer slice with default defaultConfiguration. +func GetInt64Slice(key string) []int64 { + return defaultConfiguration.GetInt64Slice(key) +} + +// GetInt64Slice returns the value associated with the key as an integer slice. +func (c *Configuration) GetInt64Slice(key string) []int64 { + return xcast.ToInt64Slice(key) } // GetFloat64 returns the value associated with the key as a float64 with default defaultConfiguration. @@ -292,7 +303,7 @@ func GetFloat64(key string) float64 { // GetFloat64 returns the value associated with the key as a float64. func (c *Configuration) GetFloat64(key string) float64 { - return xcast.ToFloat64(c.Get(key)) + return cast.ToFloat64(c.Get(key)) } // GetTime returns the value associated with the key as time with default defaultConfiguration. @@ -302,7 +313,7 @@ func GetTime(key string) time.Time { // GetTime returns the value associated with the key as time. func (c *Configuration) GetTime(key string) time.Time { - return xcast.ToTime(c.Get(key)) + return cast.ToTime(c.Get(key)) } // GetDuration returns the value associated with the key as a duration with default defaultConfiguration. @@ -312,7 +323,7 @@ func GetDuration(key string) time.Duration { // GetDuration returns the value associated with the key as a duration. func (c *Configuration) GetDuration(key string) time.Duration { - return xcast.ToDuration(c.Get(key)) + return cast.ToDuration(c.Get(key)) } // GetStringSlice returns the value associated with the key as a slice of strings with default defaultConfiguration. @@ -322,7 +333,7 @@ func GetStringSlice(key string) []string { // GetStringSlice returns the value associated with the key as a slice of strings. func (c *Configuration) GetStringSlice(key string) []string { - return xcast.ToStringSlice(c.Get(key)) + return cast.ToStringSlice(c.Get(key)) } // GetSlice returns the value associated with the key as a slice of strings with default defaultConfiguration. @@ -332,7 +343,7 @@ func GetSlice(key string) []interface{} { // GetSlice returns the value associated with the key as a slice of strings. func (c *Configuration) GetSlice(key string) []interface{} { - return xcast.ToSlice(c.Get(key)) + return cast.ToSlice(c.Get(key)) } // GetStringMap returns the value associated with the key as a map of interfaces with default defaultConfiguration. @@ -342,7 +353,7 @@ func GetStringMap(key string) map[string]interface{} { // GetStringMap returns the value associated with the key as a map of interfaces. func (c *Configuration) GetStringMap(key string) map[string]interface{} { - return xcast.ToStringMap(c.Get(key)) + return cast.ToStringMap(c.Get(key)) } // GetStringMapString returns the value associated with the key as a map of strings with default defaultConfiguration. @@ -352,7 +363,7 @@ func GetStringMapString(key string) map[string]string { // GetStringMapString returns the value associated with the key as a map of strings. func (c *Configuration) GetStringMapString(key string) map[string]string { - return xcast.ToStringMapString(c.Get(key)) + return cast.ToStringMapString(c.Get(key)) } // GetStringMapStringSlice returns the value associated with the key as a map to a slice of strings with default defaultConfiguration. @@ -362,7 +373,7 @@ func GetStringMapStringSlice(key string) map[string][]string { // GetStringMapStringSlice returns the value associated with the key as a map to a slice of strings. func (c *Configuration) GetStringMapStringSlice(key string) map[string][]string { - return xcast.ToStringMapStringSlice(c.Get(key)) + return cast.ToStringMapStringSlice(c.Get(key)) } // UnmarshalWithExpect unmarshal key, returns expect if failed @@ -438,7 +449,7 @@ func lookup(prefix string, target map[string]interface{}, data map[string]interf if prefix == "" { pp = k } - if dd, err := xcast.ToStringMapE(v); err == nil { + if dd, err := cast.ToStringMapE(v); err == nil { lookup(pp, dd, data, sep) } else { data[pp] = v diff --git a/pkg/util/xcast/cast.go b/pkg/util/xcast/cast.go new file mode 100644 index 0000000000..2742e442de --- /dev/null +++ b/pkg/util/xcast/cast.go @@ -0,0 +1,43 @@ +package xcast + +import ( + "fmt" + "reflect" + + "github.com/spf13/cast" +) + +// ToInt64Slice casts an interface to a []int64 type. +func ToInt64Slice(i interface{}) []int64 { + v, _ := ToInt64SliceE(i) + return v +} + +// ToIntSliceE casts an interface to a []int64 type. +func ToInt64SliceE(i interface{}) ([]int64, error) { + if i == nil { + return []int64{}, fmt.Errorf("unable to cast %#v of type %T to []int", i, i) + } + + switch v := i.(type) { + case []int64: + return v, nil + } + + kind := reflect.TypeOf(i).Kind() + switch kind { + case reflect.Slice, reflect.Array: + s := reflect.ValueOf(i) + a := make([]int64, s.Len()) + for j := 0; j < s.Len(); j++ { + val, err := cast.ToInt64E(s.Index(j).Interface()) + if err != nil { + return []int64{}, fmt.Errorf("unable to cast %#v of type %T to []int", i, i) + } + a[j] = val + } + return a, nil + default: + return []int64{}, fmt.Errorf("unable to cast %#v of type %T to []int", i, i) + } +} diff --git a/pkg/util/xcast/cast_test.go b/pkg/util/xcast/cast_test.go new file mode 100644 index 0000000000..807bee5b22 --- /dev/null +++ b/pkg/util/xcast/cast_test.go @@ -0,0 +1,42 @@ +package xcast + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestToInt64SliceE(t *testing.T) { + tests := []struct { + input interface{} + expect []int64 + iserr bool + }{ + {[]int{1, 3}, []int64{1, 3}, false}, + {[]interface{}{1.2, 3.2}, []int64{1, 3}, false}, + {[]string{"2", "3"}, []int64{2, 3}, false}, + {[2]string{"2", "3"}, []int64{2, 3}, false}, + // errors + {nil, nil, true}, + {testing.T{}, nil, true}, + {[]string{"foo", "bar"}, nil, true}, + } + + for i, test := range tests { + errmsg := fmt.Sprintf("i = %d", i) // assert helper message + + v, err := ToInt64SliceE(test.input) + if test.iserr { + assert.Error(t, err, errmsg) + continue + } + + assert.NoError(t, err, errmsg) + assert.Equal(t, test.expect, v, errmsg) + + // Non-E test + v = ToInt64Slice(test.input) + assert.Equal(t, test.expect, v, errmsg) + } +}