diff --git a/internal/defs/types.go b/internal/defs/types.go index 7e230ef..5002e99 100644 --- a/internal/defs/types.go +++ b/internal/defs/types.go @@ -127,6 +127,17 @@ func (self *Type) Tag() Tag { } } +func (self *Type) IsEnum() bool { + switch self.T { + case T_enum: + return true + case T_pointer: + return self.V.IsEnum() + default: + return false + } +} + func (self *Type) Free() { typePool.Put(self) } diff --git a/internal/defs/types_test.go b/internal/defs/types_test.go index 7ac67d2..a2c4a2d 100644 --- a/internal/defs/types_test.go +++ b/internal/defs/types_test.go @@ -37,3 +37,25 @@ func TestTypes_MapKeyType(t *testing.T) { require.NoError(t, err) fmt.Println(tt) } + +func TestTypes_Enum(t *testing.T) { + type EnumType int64 + type Int32 int32 + type StructWithEnum struct { + A EnumType `frugal:"1,optional,EnumType"` + B *EnumType `frugal:"2,optional,EnumType"` + C Int32 `frugal:"3,optional,Int32"` + D int64 `frugal:"4,optional,i64"` + } + ff, err := DoResolveFields(reflect.TypeOf(StructWithEnum{})) + require.NoError(t, err) + require.Len(t, ff, 4) + require.True(t, ff[0].Type.IsEnum()) + require.Equal(t, ff[0].Type.T, T_enum) + require.True(t, ff[1].Type.IsEnum()) + require.Equal(t, ff[1].Type.T, T_pointer) + require.Equal(t, ff[1].Type.V.T, T_enum) + require.False(t, ff[2].Type.IsEnum()) + require.False(t, ff[3].Type.IsEnum()) + +} diff --git a/internal/reflect/decoder_test.go b/internal/reflect/decoder_test.go index 070b962..12ebb63 100644 --- a/internal/reflect/decoder_test.go +++ b/internal/reflect/decoder_test.go @@ -27,9 +27,11 @@ import ( "github.com/stretchr/testify/require" ) -func TestDecode(t *testing.T) { +func init() { rand.Seed(time.Now().Unix()) +} +func TestDecode(t *testing.T) { type testcase struct { name string update func(p *TestTypes) @@ -197,6 +199,108 @@ func TestDecode(t *testing.T) { } } +func TestDecodeOptional(t *testing.T) { + type testcase struct { + name string + update func(p *TestTypesOptional) + test func(t *testing.T, p1 *TestTypesOptional) + } + + var ( + vInt16 = int16(rand.Uint32() & 0xffff) + vInt32 = int32(rand.Uint32()) + vInt64 = int64(rand.Uint64()) + vFloat64 = math.Float64frombits(rand.Uint64()) + vTrue = true + vString = "hello" + vByte = int8(0x55) + vEnum = Numberz(int32(rand.Uint32())) + ) + + for math.IsNaN(vFloat64) { // fix test failure + vFloat64 = math.Float64frombits(rand.Uint64()) + } + + testcases := []testcase{ + { + name: "case_bool", + update: func(p0 *TestTypesOptional) { p0.FBool = &vTrue }, + test: func(t *testing.T, p1 *TestTypesOptional) { assert.Equal(t, vTrue, *p1.FBool) }, + }, + { + name: "case_string", + update: func(p0 *TestTypesOptional) { p0.String_ = &vString }, + test: func(t *testing.T, p1 *TestTypesOptional) { assert.Equal(t, vString, *p1.String_) }, + }, + { + name: "case_byte", + update: func(p0 *TestTypesOptional) { p0.FByte = &vByte }, + test: func(t *testing.T, p1 *TestTypesOptional) { assert.Equal(t, vByte, *p1.FByte) }, + }, + { + name: "case_int8", + update: func(p0 *TestTypesOptional) { p0.I8 = &vByte }, + test: func(t *testing.T, p1 *TestTypesOptional) { assert.Equal(t, vByte, *p1.I8) }, + }, + { + name: "case_int16", + update: func(p0 *TestTypesOptional) { p0.I16 = &vInt16 }, + test: func(t *testing.T, p1 *TestTypesOptional) { assert.Equal(t, vInt16, *p1.I16) }, + }, + { + name: "case_int32", + update: func(p0 *TestTypesOptional) { p0.I32 = &vInt32 }, + test: func(t *testing.T, p1 *TestTypesOptional) { assert.Equal(t, vInt32, *p1.I32) }, + }, + { + name: "case_int64", + update: func(p0 *TestTypesOptional) { p0.I64 = &vInt64 }, + test: func(t *testing.T, p1 *TestTypesOptional) { assert.Equal(t, vInt64, *p1.I64) }, + }, + { + name: "case_float64", + update: func(p0 *TestTypesOptional) { p0.Double = &vFloat64 }, + test: func(t *testing.T, p1 *TestTypesOptional) { assert.Equal(t, vFloat64, *p1.Double) }, + }, + { + name: "case_enum", + update: func(p0 *TestTypesOptional) { p0.Enum = &vEnum }, + test: func(t *testing.T, p1 *TestTypesOptional) { assert.Equal(t, vEnum, *p1.Enum) }, + }, + { + name: "case_typedef", + update: func(p0 *TestTypesOptional) { p0.UID = &vInt64 }, + test: func(t *testing.T, p1 *TestTypesOptional) { assert.Equal(t, vInt64, *p1.UID) }, + }, + } + for _, tc := range testcases { + name := tc.name + updatef := tc.update + testf := tc.test + t.Run(name, func(t *testing.T) { + p0 := NewTestTypesOptional() + updatef(p0) // update by testcase func + + b := make([]byte, EncodedSize(p0)) + n, err := Encode(b, p0) + require.NoError(t, err) + require.Equal(t, len(b), n) + + // verify by gopkg thrift + n, err = thrift.Binary.Skip(b, thrift.TType(tSTRUCT)) + require.NoError(t, err) + require.Equal(t, n, len(b)) + + p1 := &TestTypesOptional{} + n, err = Decode(b, p1) + require.NoError(t, err) + require.Equal(t, len(b), n) + + testf(t, p1) // test by testcase func + }) + } +} + func TestDecodeRequired(t *testing.T) { type S0 struct { V *bool `frugal:"1,optional,bool"` diff --git a/internal/reflect/ttype.go b/internal/reflect/ttype.go index 05ba667..b8abeeb 100644 --- a/internal/reflect/ttype.go +++ b/internal/reflect/ttype.go @@ -160,7 +160,7 @@ func newTType(x *defs.Type) *tType { t.T = ttype(x.Tag()) t.WT = t.T t.Tag = x.T - if t.Tag == defs.T_enum { + if x.IsEnum() { t.T = tENUM } t.RT = x.S