Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] feat: support dig out anonymous field's member. #387

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 58 additions & 6 deletions result.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,16 @@ package dig
import (
"fmt"
"reflect"
"strings"

"go.uber.org/dig/internal/digerror"
"go.uber.org/dig/internal/dot"
)

const (
_extraAnonymous = "extra-anonymous"
)

// The result interface represents a result produced by a constructor.
//
// The following implementations exist:
Expand Down Expand Up @@ -74,7 +79,7 @@ func newResult(t reflect.Type, opts resultOptions) (result, error) {
case isError(t):
return nil, newErrInvalidInput("cannot return an error here, return it from the constructor instead", nil)
case IsOut(t):
return newResultObject(t, opts)
return newResultObject(t, opts, false)
case embedsType(t, _outPtrType):
return nil, newErrInvalidInput(fmt.Sprintf(
"cannot build a result object by embedding *dig.Out, embed dig.Out instead: %v embeds *dig.Out", t), nil)
Expand Down Expand Up @@ -353,7 +358,7 @@ func (ro resultObject) DotResult() []*dot.Result {
return types
}

func newResultObject(t reflect.Type, opts resultOptions) (resultObject, error) {
func newResultObject(t reflect.Type, opts resultOptions, anonymous bool) (resultObject, error) {
ro := resultObject{Type: t}
if len(opts.Name) > 0 {
return ro, newErrInvalidInput(fmt.Sprintf(
Expand All @@ -372,19 +377,66 @@ func newResultObject(t reflect.Type, opts resultOptions) (resultObject, error) {
continue
}

if anonymous && !f.IsExported() {
continue
}

rof, err := newResultObjectField(i, f, opts)
if err != nil {
return ro, newErrInvalidInput(fmt.Sprintf("bad field %q of %v", f.Name, t), err)
}

ro.Fields = append(ro.Fields, rof)

if !f.Anonymous || f.Tag.Get(_extraAnonymous) != "true" {
continue
}
if err = extraAnonymous(&ro, &f, &rof, opts); err != nil {
return ro, err
}
}
return ro, nil
}

func extraAnonymous(ro *resultObject, f *reflect.StructField, rof *resultObjectField, opts resultOptions) error {
ft := f.Type
if ft.Kind() == reflect.Pointer {
ft = ft.Elem()
}
subRo, err := newResultObject(ft, opts, true)
if err != nil {
return err
}

for _, subField := range subRo.Fields {
subField.FieldIndices = append(rof.FieldIndices, subField.FieldIndices...)
switch rofResult := rof.Result.(type) {
case resultGrouped:
switch subResult := subField.Result.(type) {
case resultGrouped:
subResult.Group = strings.Join([]string{rofResult.Group, subResult.Group}, ",")
case resultSingle:
subField.Result = resultGrouped{
Group: rofResult.Group,
Flatten: rofResult.Flatten,
Type: subResult.Type,
}
}
}

ro.Fields = append(ro.Fields, subField)
}

return nil
}

func (ro resultObject) Extract(cw containerWriter, decorated bool, v reflect.Value) {
for _, f := range ro.Fields {
f.Result.Extract(cw, decorated, v.Field(f.FieldIndex))
var rv reflect.Value = v
for _, fieldIndex := range f.FieldIndices {
rv = rv.Field(fieldIndex)
}
f.Result.Extract(cw, decorated, rv)
}
}

Expand All @@ -397,7 +449,7 @@ type resultObjectField struct {
//
// We need to track this separately because not all fields of the struct
// map to results.
FieldIndex int
FieldIndices []int

// Result produced by this field.
Result result
Expand All @@ -411,8 +463,8 @@ func (rof resultObjectField) DotResult() []*dot.Result {
// f at index i.
func newResultObjectField(idx int, f reflect.StructField, opts resultOptions) (resultObjectField, error) {
rof := resultObjectField{
FieldName: f.Name,
FieldIndex: idx,
FieldName: f.Name,
FieldIndices: []int{idx},
}

var r result
Expand Down
96 changes: 78 additions & 18 deletions result_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,11 @@ func TestNewResultErrors(t *testing.T) {
}

func TestNewResultObject(t *testing.T) {
type Embed struct {
Writer io.Writer
}

typeOfEmbed := reflect.TypeOf(&Embed{}).Elem()
typeOfReader := reflect.TypeOf((*io.Reader)(nil)).Elem()
typeOfWriter := reflect.TypeOf((*io.Writer)(nil)).Elem()

Expand All @@ -137,14 +142,14 @@ func TestNewResultObject(t *testing.T) {
}{},
wantFields: []resultObjectField{
{
FieldName: "Reader",
FieldIndex: 1,
Result: resultSingle{Type: typeOfReader},
FieldName: "Reader",
FieldIndices: []int{1},
Result: resultSingle{Type: typeOfReader},
},
{
FieldName: "Writer",
FieldIndex: 2,
Result: resultSingle{Type: typeOfWriter},
FieldName: "Writer",
FieldIndices: []int{2},
Result: resultSingle{Type: typeOfWriter},
},
},
},
Expand All @@ -158,14 +163,14 @@ func TestNewResultObject(t *testing.T) {
}{},
wantFields: []resultObjectField{
{
FieldName: "A",
FieldIndex: 1,
Result: resultSingle{Name: "stream-a", Type: typeOfWriter},
FieldName: "A",
FieldIndices: []int{1},
Result: resultSingle{Name: "stream-a", Type: typeOfWriter},
},
{
FieldName: "B",
FieldIndex: 2,
Result: resultSingle{Name: "stream-b", Type: typeOfWriter},
FieldName: "B",
FieldIndices: []int{2},
Result: resultSingle{Name: "stream-b", Type: typeOfWriter},
},
},
},
Expand All @@ -178,17 +183,72 @@ func TestNewResultObject(t *testing.T) {
}{},
wantFields: []resultObjectField{
{
FieldName: "Writer",
FieldIndex: 1,
Result: resultGrouped{Group: "writers", Type: typeOfWriter},
FieldName: "Writer",
FieldIndices: []int{1},
Result: resultGrouped{Group: "writers", Type: typeOfWriter},
},
},
},
{
desc: "anonymous",
give: struct {
Out

Embed
}{},
wantFields: []resultObjectField{
{
FieldName: "Embed",
FieldIndices: []int{1},
Result: resultSingle{Name: "", Type: typeOfEmbed},
},
},
},
{
desc: "anonymous",
give: struct {
Out

Embed `extra-anonymous:"true"`
}{},
wantFields: []resultObjectField{
{
FieldName: "Embed",
FieldIndices: []int{1},
Result: resultSingle{Name: "", Type: typeOfEmbed},
},
{
FieldName: "Writer",
FieldIndices: []int{1, 0},
Result: resultSingle{Name: "", Type: typeOfWriter},
},
},
},
{
desc: "anonymous group",
give: struct {
Out

Embed `extra-anonymous:"true" group:"embed_group"`
}{},
wantFields: []resultObjectField{
{
FieldName: "Embed",
FieldIndices: []int{1},
Result: resultGrouped{Group: "embed_group", Type: typeOfEmbed},
},
{
FieldName: "Writer",
FieldIndices: []int{1, 0},
Result: resultGrouped{Group: "embed_group", Type: typeOfWriter},
},
},
},
}

for _, tt := range tests {
t.Run(tt.desc, func(t *testing.T) {
got, err := newResultObject(reflect.TypeOf(tt.give), tt.opts)
got, err := newResultObject(reflect.TypeOf(tt.give), tt.opts, false)
require.NoError(t, err)
assert.Equal(t, tt.wantFields, got.Fields)
})
Expand Down Expand Up @@ -302,7 +362,7 @@ func TestNewResultObjectErrors(t *testing.T) {

for _, tt := range tests {
t.Run(tt.desc, func(t *testing.T) {
_, err := newResultObject(reflect.TypeOf(tt.give), tt.opts)
_, err := newResultObject(reflect.TypeOf(tt.give), tt.opts, false)
require.Error(t, err)
assert.Contains(t, err.Error(), tt.err)
})
Expand Down Expand Up @@ -404,7 +464,7 @@ func TestWalkResult(t *testing.T) {
}
}{})

ro, err := newResultObject(typ, resultOptions{})
ro, err := newResultObject(typ, resultOptions{}, false)
require.NoError(t, err)

v := fakeResultVisits{
Expand Down