Skip to content

Commit

Permalink
fix: check json.unmarshaler when deref (#532)
Browse files Browse the repository at this point in the history
  • Loading branch information
AsterDY authored Sep 26, 2023
1 parent a2d3aa9 commit 805ad67
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 15 deletions.
38 changes: 26 additions & 12 deletions internal/decoder/compiler.go
Original file line number Diff line number Diff line change
Expand Up @@ -527,40 +527,47 @@ func (self *_Compiler) compile(vt reflect.Type) (ret _Program, err error) {
return
}

func (self *_Compiler) compileOne(p *_Program, sp int, vt reflect.Type) {
/* check for recursive nesting */
ok := self.tab[vt]
if ok {
p.rtt(_OP_recurse, vt)
return
}

func (self *_Compiler) checkMarshaler(p *_Program, vt reflect.Type) bool {
pt := reflect.PtrTo(vt)

/* check for `json.Unmarshaler` with pointer receiver */
if pt.Implements(jsonUnmarshalerType) {
p.rtt(_OP_unmarshal_p, pt)
return
return true
}

/* check for `json.Unmarshaler` */
if vt.Implements(jsonUnmarshalerType) {
p.add(_OP_lspace)
self.compileUnmarshalJson(p, vt)
return
return true
}

/* check for `encoding.TextMarshaler` with pointer receiver */
if pt.Implements(encodingTextUnmarshalerType) {
p.add(_OP_lspace)
self.compileUnmarshalTextPtr(p, pt)
return
return true
}

/* check for `encoding.TextUnmarshaler` */
if vt.Implements(encodingTextUnmarshalerType) {
p.add(_OP_lspace)
self.compileUnmarshalText(p, vt)
return true
}
return false
}

func (self *_Compiler) compileOne(p *_Program, sp int, vt reflect.Type) {
/* check for recursive nesting */
ok := self.tab[vt]
if ok {
p.rtt(_OP_recurse, vt)
return
}

if self.checkMarshaler(p, vt) {
return
}

Expand Down Expand Up @@ -683,6 +690,9 @@ func (self *_Compiler) compilePtr(p *_Program, sp int, et reflect.Type) {

/* dereference all the way down */
for et.Kind() == reflect.Ptr {
if self.checkMarshaler(p, et) {
return
}
et = et.Elem()
p.rtt(_OP_deref, et)
}
Expand All @@ -695,7 +705,7 @@ func (self *_Compiler) compilePtr(p *_Program, sp int, et reflect.Type) {
/* enter the recursion */
p.add(_OP_lspace)
self.tab[et] = true

/* not inline the pointer type
* recursing the defined pointer type's elem will casue issue379.
*/
Expand All @@ -705,8 +715,12 @@ func (self *_Compiler) compilePtr(p *_Program, sp int, et reflect.Type) {

j := p.pc()
p.add(_OP_goto)

// set val pointer as nil
p.pin(i)
p.add(_OP_nil_1)

// nothing todo
p.pin(j)
}

Expand Down
39 changes: 36 additions & 3 deletions issue_test/issue379_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@
package issue_test

import (
`testing`
`encoding/json`
`testing`

`github.com/bytedance/sonic`
`github.com/stretchr/testify/assert`
`github.com/stretchr/testify/require`
)

Expand All @@ -28,6 +30,7 @@ type Foo struct {
}

func (f *Foo) UnmarshalJSON(data []byte) error {
println("UnmarshalJSON called!!!")
f.Name = "Unmarshaler"
return nil
}
Expand All @@ -38,11 +41,16 @@ func TestIssue379(t *testing.T) {
tests := []struct{
data string
newf func() interface{}
equal func(exp, act interface{}) bool
} {
{
data: `{"Name":"MyPtr"}`,
newf: func() interface{} { return &Foo{} },
},
{
data: `{"Name":"MyPtr"}`,
newf: func() interface{} { ptr := &Foo{}; return &ptr },
},
{
data: `{"Name":"MyPtr"}`,
newf: func() interface{} { return MyPtr(&Foo{}) },
Expand All @@ -55,13 +63,27 @@ func TestIssue379(t *testing.T) {
data: `null`,
newf: func() interface{} { return MyPtr(&Foo{}) },
},
{
data: `null`,
newf: func() interface{} { ptr := MyPtr(&Foo{}); return &ptr },
equal: func(exp, act interface{}) bool {
isExpNil := exp == nil || *(exp.(*MyPtr)) == nil
isActNil := act == nil || *(act.(*MyPtr)) == nil
return isActNil == isExpNil
},
},
{
data: `null`,
newf: func() interface{} { return &Foo{} },
},
{
data: `null`,
newf: func() interface{} { ptr := MyPtr(&Foo{}); return &ptr },
newf: func() interface{} { ptr := &Foo{}; return &ptr },
equal: func(exp, act interface{}) bool {
isExpNil := exp == nil || *(exp.(**Foo)) == nil
isActNil := act == nil || *(act.(**Foo)) == nil
return isActNil == isExpNil
},
},
{
data: `{"map":{"Name":"MyPtr"}}`,
Expand Down Expand Up @@ -89,11 +111,22 @@ func TestIssue379(t *testing.T) {
},
}

for _, tt := range tests {
for i, tt := range tests {
println(i)
jv, sv := tt.newf(), tt.newf()
jerr := json.Unmarshal([]byte(tt.data), jv)
serr := sonic.Unmarshal([]byte(tt.data), sv)
require.Equal(t, jv, sv)
require.Equal(t, jerr, serr)

jv, sv = tt.newf(), tt.newf()
jerr = json.Unmarshal([]byte(tt.data), &jv)
serr = sonic.Unmarshal([]byte(tt.data), &sv)
if !assert.ObjectsAreEqual(jv, sv) {
if tt.equal == nil || !tt.equal(jv, sv) {
t.Fatal()
}
}
require.Equal(t, jerr, serr)
}
}

0 comments on commit 805ad67

Please sign in to comment.