From 8f75d2abd15eac9859cb7ee570088a8d9cdd3720 Mon Sep 17 00:00:00 2001 From: Mike Naquin Date: Wed, 7 Feb 2024 12:32:15 -0800 Subject: [PATCH] Allow for custom struct tags such as `json`. This is useful when using structs from 3rd-party packages which almost certainly only have `json` tags and not `yaml` tags. This implementation does not change the default behavior for only looking for `yaml` tags, but allows the user to specify a list of tags to consider such as []string{"yaml", "json"} in an encoder or decoder. --- decode.go | 13 +++++------ decode_test.go | 27 +++++++++++++++++++++++ encode.go | 15 +++++++------ encode_test.go | 33 ++++++++++++++++++++++++++++ yaml.go | 58 +++++++++++++++++++++++++++++++++++++++++++++----- 5 files changed, 128 insertions(+), 18 deletions(-) diff --git a/decode.go b/decode.go index 0173b698..f168f63b 100644 --- a/decode.go +++ b/decode.go @@ -318,11 +318,12 @@ type decoder struct { stringMapType reflect.Type generalMapType reflect.Type - knownFields bool - uniqueKeys bool - decodeCount int - aliasCount int - aliasDepth int + knownFields bool + uniqueKeys bool + decodeCount int + aliasCount int + aliasDepth int + structTagKeys []string mergedFields map[interface{}]bool } @@ -876,7 +877,7 @@ func isStringMap(n *Node) bool { } func (d *decoder) mappingStruct(n *Node, out reflect.Value) (good bool) { - sinfo, err := getStructInfo(out.Type()) + sinfo, err := getStructInfo(out.Type(), d.structTagKeys) if err != nil { panic(err) } diff --git a/decode_test.go b/decode_test.go index 0364b0bb..a0803430 100644 --- a/decode_test.go +++ b/decode_test.go @@ -863,6 +863,33 @@ func (s *S) TestDecoderSingleDocument(c *C) { } } +func (s *S) TestDecoderUnmarshalJSONTags(c *C) { + type T struct { + Z int `json:"a"` + Y int `json:"b"` + } + var v T + decoder := yaml.NewDecoder(strings.NewReader("a: 1\nb: 2")) + decoder.SetStructTagKeys([]string{"json"}) + err := decoder.Decode(&v) + c.Assert(err, IsNil) + c.Assert(v, DeepEquals, T{1, 2}) +} + +func (s *S) TestDecoderUnmarshalYAMLAndJSONTags(c *C) { + type T struct { + Z int `yaml:"a" json:"b"` + Y int `yaml:"b" json:"a"` + X int `json:"c"` + } + var v T + decoder := yaml.NewDecoder(strings.NewReader("a: 1\nb: 2\nc: 3")) + decoder.SetStructTagKeys([]string{"yaml", "json"}) + err := decoder.Decode(&v) + c.Assert(err, IsNil) + c.Assert(v, DeepEquals, T{1, 2, 3}) +} + var decoderTests = []struct { data string values []interface{} diff --git a/encode.go b/encode.go index de9e72a3..46dd7244 100644 --- a/encode.go +++ b/encode.go @@ -29,12 +29,13 @@ import ( ) type encoder struct { - emitter yaml_emitter_t - event yaml_event_t - out []byte - flow bool - indent int - doneInit bool + emitter yaml_emitter_t + event yaml_event_t + out []byte + flow bool + indent int + doneInit bool + structTagKeys []string } func newEncoder() *encoder { @@ -212,7 +213,7 @@ func (e *encoder) fieldByIndex(v reflect.Value, index []int) (field reflect.Valu } func (e *encoder) structv(tag string, in reflect.Value) { - sinfo, err := getStructInfo(in.Type()) + sinfo, err := getStructInfo(in.Type(), e.structTagKeys) if err != nil { panic(err) } diff --git a/encode_test.go b/encode_test.go index 4a8bf2e2..ec9be594 100644 --- a/encode_test.go +++ b/encode_test.go @@ -538,6 +538,39 @@ func (s *S) TestEncoderWriteError(c *C) { c.Assert(err, ErrorMatches, `yaml: write error: some write error`) // Data not flushed yet } +func (s *S) TestEncoderJSONTags(c *C) { + type T struct { + Z int `json:"a"` + Y int `json:"b,omitempty"` + X int `json:",omitempty"` + W int `json:"-"` + } + var buf bytes.Buffer + enc := yaml.NewEncoder(&buf) + enc.SetStructTagKeys([]string{"json"}) + err := enc.Encode(T{1, 0, 0, 0}) + c.Assert(err, Equals, nil) + err = enc.Close() + c.Assert(err, Equals, nil) + c.Assert(buf.String(), Equals, "a: 1\n") +} + +func (s *S) TestEncoderYAMLAndJSONTags(c *C) { + type T struct { + Z int `yaml:"a" json:"b"` + Y int `yaml:"b" json:"a"` + X int `json:"c"` + } + var buf bytes.Buffer + enc := yaml.NewEncoder(&buf) + enc.SetStructTagKeys([]string{"yaml", "json"}) + err := enc.Encode(T{1, 2, 3}) + c.Assert(err, Equals, nil) + err = enc.Close() + c.Assert(err, Equals, nil) + c.Assert(buf.String(), Equals, "a: 1\nb: 2\nc: 3\n") +} + type errorWriter struct{} func (errorWriter) Write([]byte) (int, error) { diff --git a/yaml.go b/yaml.go index 8cec6da4..33e8f3dc 100644 --- a/yaml.go +++ b/yaml.go @@ -91,8 +91,9 @@ func Unmarshal(in []byte, out interface{}) (err error) { // A Decoder reads and decodes YAML values from an input stream. type Decoder struct { - parser *parser - knownFields bool + parser *parser + knownFields bool + structTagKeys []string } // NewDecoder returns a new decoder that reads from r. @@ -111,6 +112,24 @@ func (dec *Decoder) KnownFields(enable bool) { dec.knownFields = enable } +// SetStructTagKeys changes the tag keys used when decoding. By default +// the "yaml" tag key is used. +// +// The keys are tried in order and the first one that is found in the struct +// tag will be used. +// +// For example: +// +// dec := yaml.NewDecoder(r) +// dec.SetStructTagKeys([]string{"yaml", "json"}) +// dec.Decode(v) +// +// This will use the "yaml" tag key first, and if it's not found it will +// use the "json" tag key. +func (dec *Decoder) SetStructTagKeys(keys []string) { + dec.structTagKeys = keys +} + // Decode reads the next YAML-encoded value from its input // and stores it in the value pointed to by v. // @@ -119,6 +138,7 @@ func (dec *Decoder) KnownFields(enable bool) { func (dec *Decoder) Decode(v interface{}) (err error) { d := newDecoder() d.knownFields = dec.knownFields + d.structTagKeys = dec.structTagKeys defer handleErr(&err) node := dec.parser.parse() if node == nil { @@ -278,6 +298,24 @@ func (e *Encoder) SetIndent(spaces int) { e.encoder.indent = spaces } +// SetStructTagKeys changes the tag keys used when encoding. By default +// the "yaml" tag key is used, but this can be changed to other values. +// +// The keys are tried in order and the first one that is found in the struct +// tag will be used. +// +// For example: +// +// enc := yaml.NewEncoder(w) +// enc.SetStructTagKeys([]string{"yaml", "json"}) +// enc.Encode(v) +// +// This will use the "yaml" tag key first, and if it's not found it will +// use the "json" tag key. +func (e *Encoder) SetStructTagKeys(keys []string) { + e.encoder.structTagKeys = keys +} + // Close closes the encoder by writing any remaining data. // It does not write a stream terminating string "...". func (e *Encoder) Close() (err error) { @@ -518,13 +556,14 @@ type fieldInfo struct { var structMap = make(map[reflect.Type]*structInfo) var fieldMapMutex sync.RWMutex var unmarshalerType reflect.Type +var defaultStructTagKeys = []string{"yaml"} func init() { var v Unmarshaler unmarshalerType = reflect.ValueOf(&v).Elem().Type() } -func getStructInfo(st reflect.Type) (*structInfo, error) { +func getStructInfo(st reflect.Type, structTagKeys []string) (*structInfo, error) { fieldMapMutex.RLock() sinfo, found := structMap[st] fieldMapMutex.RUnlock() @@ -545,7 +584,16 @@ func getStructInfo(st reflect.Type) (*structInfo, error) { info := fieldInfo{Num: i} - tag := field.Tag.Get("yaml") + tag := "" + if structTagKeys == nil { + structTagKeys = defaultStructTagKeys + } + for _, structTagKey := range structTagKeys { + tag = field.Tag.Get(structTagKey) + if tag != "" { + break + } + } if tag == "" && strings.Index(string(field.Tag), ":") < 0 { tag = string(field.Tag) } @@ -592,7 +640,7 @@ func getStructInfo(st reflect.Type) (*structInfo, error) { if reflect.PtrTo(ftype).Implements(unmarshalerType) { inlineUnmarshalers = append(inlineUnmarshalers, []int{i}) } else { - sinfo, err := getStructInfo(ftype) + sinfo, err := getStructInfo(ftype, structTagKeys) if err != nil { return nil, err }