Skip to content

Commit

Permalink
fix: check enum type recursively (#64)
Browse files Browse the repository at this point in the history
  • Loading branch information
xiaost authored Oct 23, 2024
1 parent 4378967 commit 5dc674c
Show file tree
Hide file tree
Showing 4 changed files with 139 additions and 2 deletions.
11 changes: 11 additions & 0 deletions internal/defs/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,17 @@ func (self *Type) Tag() Tag {
}
}

func (self *Type) IsEnum() bool {
switch self.T {
case T_enum:
return true
case T_pointer:
return self.V.IsEnum()
default:
return false
}
}

func (self *Type) Free() {
typePool.Put(self)
}
Expand Down
22 changes: 22 additions & 0 deletions internal/defs/types_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,25 @@ func TestTypes_MapKeyType(t *testing.T) {
require.NoError(t, err)
fmt.Println(tt)
}

func TestTypes_Enum(t *testing.T) {
type EnumType int64
type Int32 int32
type StructWithEnum struct {
A EnumType `frugal:"1,optional,EnumType"`
B *EnumType `frugal:"2,optional,EnumType"`
C Int32 `frugal:"3,optional,Int32"`
D int64 `frugal:"4,optional,i64"`
}
ff, err := DoResolveFields(reflect.TypeOf(StructWithEnum{}))
require.NoError(t, err)
require.Len(t, ff, 4)
require.True(t, ff[0].Type.IsEnum())
require.Equal(t, ff[0].Type.T, T_enum)
require.True(t, ff[1].Type.IsEnum())
require.Equal(t, ff[1].Type.T, T_pointer)
require.Equal(t, ff[1].Type.V.T, T_enum)
require.False(t, ff[2].Type.IsEnum())
require.False(t, ff[3].Type.IsEnum())

}
106 changes: 105 additions & 1 deletion internal/reflect/decoder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,11 @@ import (
"github.com/stretchr/testify/require"
)

func TestDecode(t *testing.T) {
func init() {
rand.Seed(time.Now().Unix())
}

func TestDecode(t *testing.T) {
type testcase struct {
name string
update func(p *TestTypes)
Expand Down Expand Up @@ -197,6 +199,108 @@ func TestDecode(t *testing.T) {
}
}

func TestDecodeOptional(t *testing.T) {
type testcase struct {
name string
update func(p *TestTypesOptional)
test func(t *testing.T, p1 *TestTypesOptional)
}

var (
vInt16 = int16(rand.Uint32() & 0xffff)
vInt32 = int32(rand.Uint32())
vInt64 = int64(rand.Uint64())
vFloat64 = math.Float64frombits(rand.Uint64())
vTrue = true
vString = "hello"
vByte = int8(0x55)
vEnum = Numberz(int32(rand.Uint32()))
)

for math.IsNaN(vFloat64) { // fix test failure
vFloat64 = math.Float64frombits(rand.Uint64())
}

testcases := []testcase{
{
name: "case_bool",
update: func(p0 *TestTypesOptional) { p0.FBool = &vTrue },
test: func(t *testing.T, p1 *TestTypesOptional) { assert.Equal(t, vTrue, *p1.FBool) },
},
{
name: "case_string",
update: func(p0 *TestTypesOptional) { p0.String_ = &vString },
test: func(t *testing.T, p1 *TestTypesOptional) { assert.Equal(t, vString, *p1.String_) },
},
{
name: "case_byte",
update: func(p0 *TestTypesOptional) { p0.FByte = &vByte },
test: func(t *testing.T, p1 *TestTypesOptional) { assert.Equal(t, vByte, *p1.FByte) },
},
{
name: "case_int8",
update: func(p0 *TestTypesOptional) { p0.I8 = &vByte },
test: func(t *testing.T, p1 *TestTypesOptional) { assert.Equal(t, vByte, *p1.I8) },
},
{
name: "case_int16",
update: func(p0 *TestTypesOptional) { p0.I16 = &vInt16 },
test: func(t *testing.T, p1 *TestTypesOptional) { assert.Equal(t, vInt16, *p1.I16) },
},
{
name: "case_int32",
update: func(p0 *TestTypesOptional) { p0.I32 = &vInt32 },
test: func(t *testing.T, p1 *TestTypesOptional) { assert.Equal(t, vInt32, *p1.I32) },
},
{
name: "case_int64",
update: func(p0 *TestTypesOptional) { p0.I64 = &vInt64 },
test: func(t *testing.T, p1 *TestTypesOptional) { assert.Equal(t, vInt64, *p1.I64) },
},
{
name: "case_float64",
update: func(p0 *TestTypesOptional) { p0.Double = &vFloat64 },
test: func(t *testing.T, p1 *TestTypesOptional) { assert.Equal(t, vFloat64, *p1.Double) },
},
{
name: "case_enum",
update: func(p0 *TestTypesOptional) { p0.Enum = &vEnum },
test: func(t *testing.T, p1 *TestTypesOptional) { assert.Equal(t, vEnum, *p1.Enum) },
},
{
name: "case_typedef",
update: func(p0 *TestTypesOptional) { p0.UID = &vInt64 },
test: func(t *testing.T, p1 *TestTypesOptional) { assert.Equal(t, vInt64, *p1.UID) },
},
}
for _, tc := range testcases {
name := tc.name
updatef := tc.update
testf := tc.test
t.Run(name, func(t *testing.T) {
p0 := NewTestTypesOptional()
updatef(p0) // update by testcase func

b := make([]byte, EncodedSize(p0))
n, err := Encode(b, p0)
require.NoError(t, err)
require.Equal(t, len(b), n)

// verify by gopkg thrift
n, err = thrift.Binary.Skip(b, thrift.TType(tSTRUCT))
require.NoError(t, err)
require.Equal(t, n, len(b))

p1 := &TestTypesOptional{}
n, err = Decode(b, p1)
require.NoError(t, err)
require.Equal(t, len(b), n)

testf(t, p1) // test by testcase func
})
}
}

func TestDecodeRequired(t *testing.T) {
type S0 struct {
V *bool `frugal:"1,optional,bool"`
Expand Down
2 changes: 1 addition & 1 deletion internal/reflect/ttype.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ func newTType(x *defs.Type) *tType {
t.T = ttype(x.Tag())
t.WT = t.T
t.Tag = x.T
if t.Tag == defs.T_enum {
if x.IsEnum() {
t.T = tENUM
}
t.RT = x.S
Expand Down

0 comments on commit 5dc674c

Please sign in to comment.