diff --git a/go.mod b/go.mod index 4826710..426d096 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module go.uber.org/mock go 1.20 require ( + github.com/google/go-cmp v0.6.0 golang.org/x/mod v0.15.0 golang.org/x/tools v0.18.0 ) diff --git a/go.sum b/go.sum index b9d67d8..3694c1b 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,7 @@ +github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg= +github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/yuin/goldmark v1.4.13 h1:fVcFKWvrslecOb/tg+Cc05dkeYx540o0FuFt3nUVDoE= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= golang.org/x/mod v0.15.0 h1:SernR4v+D55NyBH2QiEQrlBAnj1ECL6AGrA5+dPaMY8= diff --git a/gomock/call.go b/gomock/call.go index ef76fd3..6d9af09 100644 --- a/gomock/call.go +++ b/gomock/call.go @@ -19,6 +19,8 @@ import ( "reflect" "strconv" "strings" + + "github.com/google/go-cmp/cmp" ) // Call represents an expected call to a mock. @@ -42,11 +44,13 @@ type Call struct { // can set the return values by returning a non-nil slice. Actions run in the // order they are created. actions []func([]any) []any + + cmpOpts cmp.Options // comparison options } // newCall creates a *Call. It requires the method type in order to support // unexported methods. -func newCall(t TestHelper, receiver any, method string, methodType reflect.Type, args ...any) *Call { +func newCall(t TestHelper, receiver any, method string, methodType reflect.Type, cmpOpts cmp.Options, args ...any) *Call { t.Helper() // TODO: check arity, types. @@ -76,7 +80,8 @@ func newCall(t TestHelper, receiver any, method string, methodType reflect.Type, return rets }} return &Call{t: t, receiver: receiver, method: method, methodType: methodType, - args: mArgs, origin: origin, minCalls: 1, maxCalls: 1, actions: actions} + args: mArgs, origin: origin, minCalls: 1, maxCalls: 1, actions: actions, + cmpOpts: cmpOpts} } // AnyTimes allows the expectation to be called 0 or more times @@ -317,6 +322,30 @@ func (c *Call) String() string { return fmt.Sprintf("%T.%v(%s) %s", c.receiver, c.method, arguments, c.origin) } +func (c *Call) matchError(m Matcher, arg any) error { + if g, ok := m.(GotFormatter); ok { + return fmt.Errorf( + "\nGot: %v\nWant: %v", + g.Got(arg), m, + ) + } + if d, ok := m.(Differ); ok { + diff := d.Diff(arg, c.cmpOpts...) + // Recover if the diff is empty, implying the match failed on ignored fields. + if diff == "" { + return nil + } + return fmt.Errorf( + "\nDiff (-want +got): %s", + diff, + ) + } + return fmt.Errorf( + "\nGot: %v\nWant: %v", + formatGottenArg(m, arg), m, + ) +} + // Tests if the given call matches the expected call. // If yes, returns nil. If no, returns error with message explaining why it does not match. func (c *Call) matches(args []any) error { @@ -327,11 +356,9 @@ func (c *Call) matches(args []any) error { } for i, m := range c.args { - if !m.Matches(args[i]) { - return fmt.Errorf( - "expected call at %s doesn't match the argument at index %d.\nGot: %v\nWant: %v", - c.origin, i, formatGottenArg(m, args[i]), m, - ) + arg := args[i] + if !m.Matches(arg) { + return fmt.Errorf("expected call at %s doesn't match the argument at index %d: %w", c.origin, i, c.matchError(m, arg)) } } } else { @@ -349,11 +376,12 @@ func (c *Call) matches(args []any) error { } for i, m := range c.args { + arg := args[i] if i < c.methodType.NumIn()-1 { // Non-variadic args - if !m.Matches(args[i]) { - return fmt.Errorf("expected call at %s doesn't match the argument at index %s.\nGot: %v\nWant: %v", - c.origin, strconv.Itoa(i), formatGottenArg(m, args[i]), m) + if !m.Matches(arg) { + return fmt.Errorf("expected call at %s doesn't match the argument at index %d: %w", + c.origin, i, c.matchError(m, args[i])) } continue } diff --git a/gomock/callset_test.go b/gomock/callset_test.go index d8150c5..6f07518 100644 --- a/gomock/callset_test.go +++ b/gomock/callset_test.go @@ -30,7 +30,7 @@ func TestCallSetAdd(t *testing.T) { numCalls := 10 for i := 0; i < numCalls; i++ { - cs.Add(newCall(t, receiver, method, reflect.TypeOf(receiverType{}.Func))) + cs.Add(newCall(t, receiver, method, reflect.TypeOf(receiverType{}.Func), nil)) } call, err := cs.FindMatch(receiver, method, []any{}) @@ -47,13 +47,13 @@ func TestCallSetAdd_WhenOverridable_ClearsPreviousExpectedAndExhausted(t *testin var receiver any = "TestReceiver" cs := newOverridableCallSet() - cs.Add(newCall(t, receiver, method, reflect.TypeOf(receiverType{}.Func))) + cs.Add(newCall(t, receiver, method, reflect.TypeOf(receiverType{}.Func), nil)) numExpectedCalls := len(cs.expected[callSetKey{receiver, method}]) if numExpectedCalls != 1 { t.Fatalf("Expected 1 expected call in callset, got %d", numExpectedCalls) } - cs.Add(newCall(t, receiver, method, reflect.TypeOf(receiverType{}.Func))) + cs.Add(newCall(t, receiver, method, reflect.TypeOf(receiverType{}.Func), nil)) newNumExpectedCalls := len(cs.expected[callSetKey{receiver, method}]) if newNumExpectedCalls != 1 { t.Fatalf("Expected 1 expected call in callset, got %d", newNumExpectedCalls) @@ -100,7 +100,7 @@ func TestCallSetFindMatch(t *testing.T) { method := "TestMethod" args := []any{} - c1 := newCall(t, receiver, method, reflect.TypeOf(receiverType{}.Func)) + c1 := newCall(t, receiver, method, reflect.TypeOf(receiverType{}.Func), nil) cs.exhausted = map[callSetKey][]*Call{ {receiver: receiver, fname: method}: {c1}, } diff --git a/gomock/controller.go b/gomock/controller.go index 40bcdf8..b785f98 100644 --- a/gomock/controller.go +++ b/gomock/controller.go @@ -20,6 +20,8 @@ import ( "reflect" "runtime" "sync" + + "github.com/google/go-cmp/cmp" ) // A TestReporter is something that can be used to report test failures. It @@ -76,6 +78,7 @@ type Controller struct { mu sync.Mutex expectedCalls *callSet finished bool + cmpOpts cmp.Options } // NewController returns a new Controller. It is the preferred way to create a Controller. @@ -121,6 +124,20 @@ func (o overridableExpectationsOption) apply(ctrl *Controller) { ctrl.expectedCalls = newOverridableCallSet() } +type cmpOptions struct { + opts []cmp.Option +} + +func (o cmpOptions) apply(ctrl *Controller) { + ctrl.cmpOpts = o.opts +} + +// WithCmpOpts is a ControllerOption that configures the options to pass to +// cmp.Diff. +func WithCmpOpts(opts ...cmp.Option) cmpOptions { + return cmpOptions{opts: opts} +} + type cancelReporter struct { t TestHelper cancel func() @@ -181,7 +198,7 @@ func (ctrl *Controller) RecordCall(receiver any, method string, args ...any) *Ca func (ctrl *Controller) RecordCallWithMethodType(receiver any, method string, methodType reflect.Type, args ...any) *Call { ctrl.T.Helper() - call := newCall(ctrl.T, receiver, method, methodType, args...) + call := newCall(ctrl.T, receiver, method, methodType, ctrl.cmpOpts, args...) ctrl.mu.Lock() defer ctrl.mu.Unlock() diff --git a/gomock/controller_test.go b/gomock/controller_test.go index 03d9e64..d2d4e2c 100644 --- a/gomock/controller_test.go +++ b/gomock/controller_test.go @@ -20,6 +20,8 @@ import ( "strings" "testing" + "github.com/google/go-cmp/cmp/cmpopts" + "go.uber.org/mock/gomock" ) @@ -74,8 +76,11 @@ func (e *ErrorReporter) assertFatal(fn func(), expectedErrMsgs ...string) { // check the last actualErrMsg, because the previous messages come from previous errors actualErrMsg := e.log[len(e.log)-1] for _, expectedErrMsg := range expectedErrMsgs { - if !strings.Contains(actualErrMsg, expectedErrMsg) { + i := strings.Index(actualErrMsg, expectedErrMsg) + if i == -1 { e.t.Errorf("Error message:\ngot: %q\nwant to contain: %q\n", actualErrMsg, expectedErrMsg) + } else { + actualErrMsg = actualErrMsg[i+len(expectedErrMsg):] } } } @@ -149,8 +154,9 @@ func (s *Subject) VariadicMethod(arg int, vararg ...string) {} // A type purely for ActOnTestStructMethod type TestStruct struct { - Number int - Message string + Number int + Message string + secretMessage string } func (s *Subject) ActOnTestStructMethod(arg TestStruct, arg1 int) int { @@ -171,7 +177,9 @@ func createFixtures(t *testing.T) (reporter *ErrorReporter, ctrl *gomock.Control // Controller. We use it to test that the mock considered tests // successful or failed. reporter = NewErrorReporter(t) - ctrl = gomock.NewController(reporter) + ctrl = gomock.NewController( + reporter, gomock.WithCmpOpts(cmpopts.IgnoreUnexported(TestStruct{})), + ) return } @@ -298,13 +306,13 @@ func TestUnexpectedArgValue_FirstArg(t *testing.T) { // the method argument (of TestStruct type) has 1 unexpected value (for the Message field) ctrl.Call(subject, "ActOnTestStructMethod", TestStruct{Number: 123, Message: "no message"}, 15) }, "Unexpected call to", "doesn't match the argument at index 0", - "Got: {123 no message} (gomock_test.TestStruct)\nWant: is equal to {123 hello %s} (gomock_test.TestStruct)") + "Diff (-want +got):", "gomock_test.TestStruct{", "Number: 123", "-", "Message: \"hello %s\",", "+", "Message: \"no message\",", "}") reporter.assertFatal(func() { // the method argument (of TestStruct type) has 2 unexpected values (for both fields) ctrl.Call(subject, "ActOnTestStructMethod", TestStruct{Number: 11, Message: "no message"}, 15) }, "Unexpected call to", "doesn't match the argument at index 0", - "Got: {11 no message} (gomock_test.TestStruct)\nWant: is equal to {123 hello %s} (gomock_test.TestStruct)") + "Diff (-want +got):", "gomock_test.TestStruct{", "-", "Number: 123,", "+", "Number: 11,", "-", "Message: \"hello %s\",", "+", "Message: \"no message\",", "}") reporter.assertFatal(func() { // The expected call wasn't made. @@ -323,7 +331,7 @@ func TestUnexpectedArgValue_SecondArg(t *testing.T) { reporter.assertFatal(func() { ctrl.Call(subject, "ActOnTestStructMethod", TestStruct{Number: 123, Message: "hello"}, 3) }, "Unexpected call to", "doesn't match the argument at index 1", - "Got: 3 (int)\nWant: is equal to 15 (int)") + "Diff (-want +got):", "int(", "-", "15,", "+", "3,", ")") reporter.assertFatal(func() { // The expected call wasn't made. @@ -742,8 +750,8 @@ func TestVariadicNoMatch(t *testing.T) { ctrl.RecordCall(s, "VariadicMethod", 0) rep.assertFatal(func() { ctrl.Call(s, "VariadicMethod", 1) - }, "expected call at", "doesn't match the argument at index 0", - "Got: 1 (int)\nWant: is equal to 0 (int)") + }, "expected call at", "doesn't match the argument at index 0:", + "Diff (-want +got):", "int(", "-", "0,", "+", "1,", ")") ctrl.Call(s, "VariadicMethod", 0) } diff --git a/gomock/matchers.go b/gomock/matchers.go index d0590d0..c03e2cc 100644 --- a/gomock/matchers.go +++ b/gomock/matchers.go @@ -19,6 +19,8 @@ import ( "reflect" "regexp" "strings" + + "github.com/google/go-cmp/cmp" ) // A Matcher is a representation of a class of values. @@ -31,6 +33,11 @@ type Matcher interface { String() string } +type Differ interface { + // Diff shows the difference between the value and x. + Diff(x interface{}, opts ...cmp.Option) string +} + // WantFormatter modifies the given Matcher's String() method to the given // Stringer. This allows for control on how the "Want" is formatted when // printing . @@ -94,6 +101,10 @@ func (anyMatcher) Matches(any) bool { return true } +func (anyMatcher) Diff(interface{}) string { + return "" +} + func (anyMatcher) String() string { return "is anything" } @@ -132,6 +143,10 @@ func (e eqMatcher) Matches(x any) bool { return false } +func (e eqMatcher) Diff(x interface{}, opts ...cmp.Option) string { + return cmp.Diff(e.x, x, opts...) +} + func (e eqMatcher) String() string { return fmt.Sprintf("is equal to %s (%T)", getString(e.x), e.x) } @@ -153,6 +168,10 @@ func (nilMatcher) Matches(x any) bool { return false } +func (nilMatcher) Diff(x interface{}, opts ...cmp.Option) string { + return cmp.Diff(nil, x, opts...) +} + func (nilMatcher) String() string { return "is nil" } @@ -196,6 +215,10 @@ func (m assignableToTypeOfMatcher) Matches(x any) bool { return reflect.TypeOf(x).AssignableTo(m.targetType) } +func (m assignableToTypeOfMatcher) Diff(x interface{}, opts ...cmp.Option) string { + return cmp.Diff(m.targetType, reflect.TypeOf(x), opts...) +} + func (m assignableToTypeOfMatcher) String() string { return "is assignable to " + m.targetType.Name() } @@ -234,6 +257,18 @@ func (am allMatcher) Matches(x any) bool { return true } +func (am allMatcher) Diff(x interface{}, opts ...cmp.Option) string { + ss := make([]string, 0, len(am.matchers)) + for _, matcher := range am.matchers { + if d, ok := matcher.(Differ); ok { + ss = append(ss, d.Diff(x)) + } else { + ss = append(ss, matcher.String()) + } + } + return strings.Join(ss, "; ") +} + func (am allMatcher) String() string { ss := make([]string, 0, len(am.matchers)) for _, matcher := range am.matchers { @@ -256,6 +291,16 @@ func (m lenMatcher) Matches(x any) bool { } } +func (m lenMatcher) Diff(x interface{}, opts ...cmp.Option) string { + v := reflect.ValueOf(x) + switch v.Kind() { + case reflect.Array, reflect.Chan, reflect.Map, reflect.Slice, reflect.String: + return cmp.Diff(m.i, v.Len(), opts...) + default: + return cmp.Diff(m.i, fmt.Sprintf("invalid: len(%T)", x), opts...) + } +} + func (m lenMatcher) String() string { return fmt.Sprintf("has length %d", m.i) } @@ -310,6 +355,52 @@ func (m inAnyOrderMatcher) Matches(x any) bool { return extraInGiven == 0 && missingFromWanted == 0 } +func (m inAnyOrderMatcher) Diff(x interface{}, opts ...cmp.Option) string { + given, ok := m.prepareValue(x) + if !ok { + return cmp.Diff(m.x, x, opts...) + } + wanted, ok := m.prepareValue(m.x) + if !ok { + return cmp.Diff(m.x, x, opts...) + } + + if given.Len() != wanted.Len() { + return cmp.Diff(m.x, x, opts...) + } + + usedFromGiven := make([]bool, given.Len()) + foundFromWanted := make([]bool, wanted.Len()) + for i := 0; i < wanted.Len(); i++ { + wantedMatcher := Eq(wanted.Index(i).Interface()) + for j := 0; j < given.Len(); j++ { + if usedFromGiven[j] { + continue + } + if wantedMatcher.Matches(given.Index(j).Interface()) { + foundFromWanted[i] = true + usedFromGiven[j] = true + break + } + } + } + + missingFromWanted := 0 + for _, found := range foundFromWanted { + if !found { + missingFromWanted++ + } + } + extraInGiven := 0 + for _, used := range usedFromGiven { + if !used { + extraInGiven++ + } + } + + return cmp.Diff(m.x, x, opts...) +} + func (m inAnyOrderMatcher) prepareValue(x any) (reflect.Value, bool) { xValue := reflect.ValueOf(x) switch xValue.Kind() {