Skip to content

Commit

Permalink
Merge pull request #89 from vimeo/flatten_mangler_nil_intermediate_types
Browse files Browse the repository at this point in the history
transform: flatten mangler: handle nil subfields
  • Loading branch information
dfinkel authored May 16, 2024
2 parents 6d7d1f9 + b91e19a commit 473d614
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 19 deletions.
56 changes: 39 additions & 17 deletions transform/flatten_mangler.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"strings"

"github.com/fatih/structtag"

"github.com/vimeo/dials/common"
"github.com/vimeo/dials/tagformat/caseconversion"
)
Expand Down Expand Up @@ -206,7 +207,8 @@ func (f *FlattenMangler) getTag(sf *reflect.StructField, tags, flattenedPath []s
func (f *FlattenMangler) Unmangle(sf reflect.StructField, vs []FieldValueTuple) (reflect.Value, error) {

val := reflect.New(sf.Type).Elem()
output, err := populateStruct(val, vs, 0)

output, _, err := populateStruct(val, vs, 0)
if err != nil {
return val, err
}
Expand All @@ -218,18 +220,29 @@ func (f *FlattenMangler) Unmangle(sf reflect.StructField, vs []FieldValueTuple)
return val, nil
}

func isNil(val reflect.Value) bool {
switch val.Kind() {
case reflect.Pointer, reflect.Chan, reflect.Slice, reflect.Map, reflect.Func, reflect.Interface:
return val.IsNil()
default:
return false
}
}

// populateStruct populates the original value with values from the flattend values
func populateStruct(originalVal reflect.Value, vs []FieldValueTuple, inputIndex int) (int, error) {
// bool return value indicates whether any inner fields were non-nil (ignore if error-set)
func populateStruct(originalVal reflect.Value, vs []FieldValueTuple, inputIndex int) (int, bool, error) {
if !originalVal.CanSet() {
return inputIndex, fmt.Errorf("error unmangling %s. Need addressable type, actual %q", originalVal, originalVal.Type().Kind())
return inputIndex, false, fmt.Errorf("error unmangling %s. Need addressable type, actual %q", originalVal, originalVal.Type().Kind())
}

kind, vt := getUnderlyingKindType(originalVal.Type())

anyChildSet := false
switch kind {
case reflect.Struct:
// go through each field if the struct doesn't implement TextUnmarshaler
if vt.Implements(textMReflectType) || reflect.PtrTo(vt).Implements(textMReflectType) {
if vt.Implements(textMReflectType) || reflect.PointerTo(vt).Implements(textMReflectType) {
break
}
// the originalVal is a pointer and to go through the fields, we need
Expand All @@ -247,36 +260,45 @@ func populateStruct(originalVal reflect.Value, vs []FieldValueTuple, inputIndex
switch kind {
case reflect.Struct:
// don't flatten if the struct implements TextUnmarshaler
if t.Implements(textMReflectType) || reflect.PtrTo(t).Implements(textMReflectType) {
if t.Implements(textMReflectType) || reflect.PointerTo(t).Implements(textMReflectType) {
break // break out of the case, still stays within the for loop
}
var err error
inputIndex, err = populateStruct(nestedVal, vs, inputIndex)
var nestedAnySet bool
inputIndex, nestedAnySet, err = populateStruct(nestedVal, vs, inputIndex)
if err != nil {
return inputIndex, err
return inputIndex, false, err
}
anyChildSet = anyChildSet || nestedAnySet
continue
default:
}
if !nestedVal.CanSet() {
return inputIndex, fmt.Errorf("nested value %s under %s cannot be set", nestedVal, originalVal)
return inputIndex, false, fmt.Errorf("nested value %s under %s cannot be set", nestedVal, originalVal)
}

if !vs[inputIndex].Value.Type().AssignableTo(nestedVal.Type()) {
return inputIndex, fmt.Errorf("error unmangling. Expected type %s. Actual type %s", vs[inputIndex].Value.Type(), nestedVal.Type())
return inputIndex, false, fmt.Errorf("error unmangling. Expected type %s. Actual type %s", vs[inputIndex].Value.Type(), nestedVal.Type())
}
if !isNil(vs[inputIndex].Value) {
nestedVal.Set(vs[inputIndex].Value)
anyChildSet = true
}
nestedVal.Set(vs[inputIndex].Value)
inputIndex++
}
setVal.Elem().Set(val)
originalVal.Set(setVal)
return inputIndex, nil
default:
if anyChildSet {
setVal.Elem().Set(val)
originalVal.Set(setVal)
}
return inputIndex, anyChildSet, nil
}
val := vs[inputIndex].Value
if !isNil(val) {
originalVal.Set(val)
anyChildSet = true
}
originalVal.Set(vs[inputIndex].Value)
inputIndex++

return inputIndex, nil
return inputIndex, anyChildSet, nil
}

// ShouldRecurse returns false because Mangle walks through nested structs and doesn't need Transform's recursion
Expand Down
73 changes: 71 additions & 2 deletions transform/flatten_mangler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,17 @@ func TestFlattenMangler(t *testing.T) {
},
modify: func(t testing.TB, val reflect.Value) {},
assertion: func(t testing.TB, i interface{}) {
// should be empty struct since none of the fields are exposed
assert.Equal(t, struct{}{}, *i.(*struct{}))
if i == nil {
t.Error("nil Unmangle output")
}
s, ok := i.(*struct{})
if !ok {
t.Errorf("unexpected type %T; expected *struct{}", i)
return
}
if s != nil {
t.Errorf("non-nil Unmangle output for empty struct (with type %T) %+[1]v", s)
}
},
},
{
Expand Down Expand Up @@ -207,6 +216,66 @@ func TestFlattenMangler(t *testing.T) {
assert.Equal(t, st, i)
},
},
{
name: "nil nested struct",
testStruct: b,
modify: func(t testing.TB, val reflect.Value) {

expectedDialsTags := []string{
"config_field_Name",
"config_field_Foobar_Location",
"config_field_Foobar_Coordinates",
"config_field_Foobar_some_time",
"config_field_AnotherField",
}

expectedFieldTags := []string{
"ConfigField,Name",
"ConfigField,Foobar,Location",
"ConfigField,Foobar,Coordinates",
"ConfigField,Foobar,SomeTime",
"ConfigField,AnotherField",
}

for i := 0; i < val.Type().NumField(); i++ {
f := val.Type().Field(i)
assert.EqualValues(t, expectedDialsTags[i], f.Tag.Get(common.DialsTagName))
assert.EqualValues(t, expectedFieldTags[i], f.Tag.Get(dialsFieldPathTag))
if f.Type.Kind() != reflect.Pointer {
t.Errorf("field %d has kind %s, not %s", i, f.Type.Kind(), reflect.Pointer)
}
}

s1 := "test"
i2 := 42

val.Field(0).Set(reflect.ValueOf(&s1))
val.Field(1).Set(reflect.Zero(reflect.TypeOf((*string)(nil))))
val.Field(2).Set(reflect.Zero(reflect.TypeOf((*int)(nil))))
val.Field(3).Set(reflect.Zero(reflect.TypeOf((*time.Duration)(nil))))
val.Field(4).Set(reflect.ValueOf(&i2))
},
assertion: func(t testing.TB, i interface{}) {
// all the fields are pointerified because of call to Pointerify
s1 := "test"
i2 := 42
b := struct {
Name *string `dials:"Name"`
Foobar *struct {
Location *string `dials:"Location"`
Coordinates *int `dials:"Coordinates"`
SomeTime *time.Duration
} `dials:"Foobar"`
AnotherField *int `dials:"AnotherField"`
}{
Name: &s1,
Foobar: nil,
AnotherField: &i2,
}

assert.EqualValues(t, &b, i)
},
},
{
name: "multilevel nested struct",
testStruct: b,
Expand Down

0 comments on commit 473d614

Please sign in to comment.