From 3eef848d7f7b01cad9cf27e3ffca456686d03704 Mon Sep 17 00:00:00 2001 From: tfrench-uber Date: Wed, 29 Dec 2021 16:44:44 -0500 Subject: [PATCH] validate input and output before unmarshal --- unmarshal.go | 10 ++++++- unmarshal_test.go | 75 +++++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 82 insertions(+), 3 deletions(-) diff --git a/unmarshal.go b/unmarshal.go index e66090d..957347c 100644 --- a/unmarshal.go +++ b/unmarshal.go @@ -3,6 +3,7 @@ package njson import ( "encoding/json" "errors" + "fmt" "reflect" "github.com/tidwall/gjson" @@ -14,6 +15,9 @@ var jsonNumberType = reflect.TypeOf(json.Number("")) // Unmarshal used to unmarshal nested json using "njson" tag func Unmarshal(data []byte, v interface{}) (err error) { + if !gjson.ValidBytes(data) { + return fmt.Errorf("invalid json: %v", string(data)) + } // catch code panic and return error message defer func() { @@ -29,7 +33,11 @@ func Unmarshal(data []byte, v interface{}) (err error) { } }() - elem := reflect.ValueOf(v).Elem() + rv := reflect.ValueOf(v) + if rv.Kind() != reflect.Ptr || rv.IsNil() { + return fmt.Errorf("can't unmarshal to invalid type %v", reflect.TypeOf(v)) + } + elem := rv.Elem() typeOfT := elem.Type() for i := 0; i < elem.NumField(); i++ { field := elem.Field(i) diff --git a/unmarshal_test.go b/unmarshal_test.go index ea7a3d7..20b0d5a 100644 --- a/unmarshal_test.go +++ b/unmarshal_test.go @@ -8,6 +8,77 @@ import ( "github.com/google/go-cmp/cmp" ) +func TestUnmarshalInvalidJson(t *testing.T) { + json := ` + BAD JSON %% @ + ## + }` + + type Name struct { + First string `njson:"first"` + Last string `njson:"last"` + } + + type User struct { + Name Name `njson:"name"` + Age int `njson:"age"` + Friends []Name `njson:"friends"` + } + + actual := User{} + + if err := Unmarshal([]byte(json), &actual); err == nil { + t.Error("error should not be nil") + } +} + +func TestUnmarshalError(t *testing.T) { + json := ` + { + "name": {"first": "Mohamed", "last": "Shapan"}, + "age": 26, + "friends": [ + {"first": "Asma", "age": 26}, + {"first": "Ahmed", "age": 25}, + {"first": "Mahmoud", "age": 30} + ] + }` + + if err := Unmarshal([]byte(json), nil); err == nil { + t.Error("error should not be nil") + } +} + +func TestUnmarshalByValueError(t *testing.T) { + json := ` + { + "name": {"first": "Mohamed", "last": "Shapan"}, + "age": 26, + "friends": [ + {"first": "Asma", "age": 26}, + {"first": "Ahmed", "age": 25}, + {"first": "Mahmoud", "age": 30} + ] + }` + + type Name struct { + First string `njson:"first"` + Last string `njson:"last"` + } + + type User struct { + Name Name `njson:"name"` + Age int `njson:"age"` + Friends []Name `njson:"friends"` + } + + actual := User{} + + if err := Unmarshal([]byte(json), actual); err == nil { + t.Error("error should not be nil") + } +} + func TestUnmarshalSmall(t *testing.T) { json := ` { @@ -260,7 +331,7 @@ func TestUnmarshalSlices(t *testing.T) { [ [ 601, 602, 603 - ], + ] ] ], @@ -429,7 +500,7 @@ func TestUnmarshalComplex(t *testing.T) { ], "time_1": "2021-01-11T23:56:51.141Z", "time_2": "2021-01-11T23:56:51.141+01:00", - "time_3": "2021-01-11T23:56:51.141-01:00", + "time_3": "2021-01-11T23:56:51.141-01:00" } `