Skip to content

Commit

Permalink
Adopt cmp.Diff for showing unmatched arguments.
Browse files Browse the repository at this point in the history
  • Loading branch information
SpencerC committed Feb 22, 2024
1 parent e649d89 commit 6e34dff
Show file tree
Hide file tree
Showing 7 changed files with 159 additions and 16 deletions.
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module go.uber.org/mock
go 1.20

require (
github.com/google/go-cmp v0.6.0
golang.org/x/mod v0.11.0
golang.org/x/tools v0.2.0
)
Expand Down
4 changes: 4 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg=
github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/yuin/goldmark v1.4.13 h1:fVcFKWvrslecOb/tg+Cc05dkeYx540o0FuFt3nUVDoE=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
golang.org/x/mod v0.11.0 h1:bUO06HqtnRcc/7l71XBe4WcqTZ+3AH1J59zWDDwLKgU=
Expand Down
31 changes: 27 additions & 4 deletions gomock/call.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ import (
"reflect"
"strconv"
"strings"

"github.com/google/go-cmp/cmp"
)

// Call represents an expected call to a mock.
Expand All @@ -42,11 +44,13 @@ type Call struct {
// can set the return values by returning a non-nil slice. Actions run in the
// order they are created.
actions []func([]any) []any

cmpOpts cmp.Options // comparison options
}

// newCall creates a *Call. It requires the method type in order to support
// unexported methods.
func newCall(t TestHelper, receiver any, method string, methodType reflect.Type, args ...any) *Call {
func newCall(t TestHelper, receiver any, method string, methodType reflect.Type, cmpOpts cmp.Options, args ...any) *Call {
t.Helper()

// TODO: check arity, types.
Expand Down Expand Up @@ -76,7 +80,8 @@ func newCall(t TestHelper, receiver any, method string, methodType reflect.Type,
return rets
}}
return &Call{t: t, receiver: receiver, method: method, methodType: methodType,
args: mArgs, origin: origin, minCalls: 1, maxCalls: 1, actions: actions}
args: mArgs, origin: origin, minCalls: 1, maxCalls: 1, actions: actions,
cmpOpts: cmpOpts}
}

// AnyTimes allows the expectation to be called 0 or more times
Expand Down Expand Up @@ -331,10 +336,28 @@ func (c *Call) matches(args []any) error {
}

for i, m := range c.args {
if !m.Matches(args[i]) {
arg := args[i]
if !m.Matches(arg) {
if g, ok := m.(GotFormatter); ok {
return fmt.Errorf(
"expected call at %s doesn't match the argument at index %d.\nGot: %v\nWant: %v",
c.origin, i, g.Got(arg), m,
)
}
if d, ok := m.(Differ); ok {
diff := d.Diff(arg, c.cmpOpts...)
// Recover if the diff is empty, implying the match failed on ignored fields.
if diff == "" {
return nil
}
return fmt.Errorf(
"expected call at %s doesn't match the argument at index %d.\nDiff (-want +got): %s",
c.origin, i, diff,
)
}
return fmt.Errorf(
"expected call at %s doesn't match the argument at index %d.\nGot: %v\nWant: %v",
c.origin, i, formatGottenArg(m, args[i]), m,
c.origin, i, formatGottenArg(m, arg), m,
)
}
}
Expand Down
8 changes: 4 additions & 4 deletions gomock/callset_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func TestCallSetAdd(t *testing.T) {

numCalls := 10
for i := 0; i < numCalls; i++ {
cs.Add(newCall(t, receiver, method, reflect.TypeOf(receiverType{}.Func)))
cs.Add(newCall(t, receiver, method, reflect.TypeOf(receiverType{}.Func), nil))
}

call, err := cs.FindMatch(receiver, method, []any{})
Expand All @@ -47,13 +47,13 @@ func TestCallSetAdd_WhenOverridable_ClearsPreviousExpectedAndExhausted(t *testin
var receiver any = "TestReceiver"
cs := newOverridableCallSet()

cs.Add(newCall(t, receiver, method, reflect.TypeOf(receiverType{}.Func)))
cs.Add(newCall(t, receiver, method, reflect.TypeOf(receiverType{}.Func), nil))
numExpectedCalls := len(cs.expected[callSetKey{receiver, method}])
if numExpectedCalls != 1 {
t.Fatalf("Expected 1 expected call in callset, got %d", numExpectedCalls)
}

cs.Add(newCall(t, receiver, method, reflect.TypeOf(receiverType{}.Func)))
cs.Add(newCall(t, receiver, method, reflect.TypeOf(receiverType{}.Func), nil))
newNumExpectedCalls := len(cs.expected[callSetKey{receiver, method}])
if newNumExpectedCalls != 1 {
t.Fatalf("Expected 1 expected call in callset, got %d", newNumExpectedCalls)
Expand Down Expand Up @@ -100,7 +100,7 @@ func TestCallSetFindMatch(t *testing.T) {
method := "TestMethod"
args := []any{}

c1 := newCall(t, receiver, method, reflect.TypeOf(receiverType{}.Func))
c1 := newCall(t, receiver, method, reflect.TypeOf(receiverType{}.Func), nil)
cs.exhausted = map[callSetKey][]*Call{
{receiver: receiver, fname: method}: {c1},
}
Expand Down
19 changes: 18 additions & 1 deletion gomock/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ import (
"reflect"
"runtime"
"sync"

"github.com/google/go-cmp/cmp"
)

// A TestReporter is something that can be used to report test failures. It
Expand Down Expand Up @@ -76,6 +78,7 @@ type Controller struct {
mu sync.Mutex
expectedCalls *callSet
finished bool
cmpOpts cmp.Options
}

// NewController returns a new Controller. It is the preferred way to create a Controller.
Expand Down Expand Up @@ -121,6 +124,20 @@ func (o overridableExpectationsOption) apply(ctrl *Controller) {
ctrl.expectedCalls = newOverridableCallSet()
}

type cmpOptions struct {
opts []cmp.Option
}

func (o cmpOptions) apply(ctrl *Controller) {
ctrl.cmpOpts = o.opts
}

// WithCmpOpts is a ControllerOption that configures the options to pass to
// cmp.Diff.
func WithCmpOpts(opts ...cmp.Option) cmpOptions {
return cmpOptions{opts: opts}
}

type cancelReporter struct {
t TestHelper
cancel func()
Expand Down Expand Up @@ -181,7 +198,7 @@ func (ctrl *Controller) RecordCall(receiver any, method string, args ...any) *Ca
func (ctrl *Controller) RecordCallWithMethodType(receiver any, method string, methodType reflect.Type, args ...any) *Call {
ctrl.T.Helper()

call := newCall(ctrl.T, receiver, method, methodType, args...)
call := newCall(ctrl.T, receiver, method, methodType, ctrl.cmpOpts, args...)

ctrl.mu.Lock()
defer ctrl.mu.Unlock()
Expand Down
21 changes: 14 additions & 7 deletions gomock/controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"testing"

"go.uber.org/mock/gomock"
"github.com/google/go-cmp/cmp/cmpopts"
)

type ErrorReporter struct {
Expand Down Expand Up @@ -74,8 +75,11 @@ func (e *ErrorReporter) assertFatal(fn func(), expectedErrMsgs ...string) {
// check the last actualErrMsg, because the previous messages come from previous errors
actualErrMsg := e.log[len(e.log)-1]
for _, expectedErrMsg := range expectedErrMsgs {
if !strings.Contains(actualErrMsg, expectedErrMsg) {
i := strings.Index(actualErrMsg, expectedErrMsg)
if i == -1 {
e.t.Errorf("Error message:\ngot: %q\nwant to contain: %q\n", actualErrMsg, expectedErrMsg)
} else {
actualErrMsg = actualErrMsg[i+len(expectedErrMsg):]
}
}
}
Expand Down Expand Up @@ -149,8 +153,9 @@ func (s *Subject) VariadicMethod(arg int, vararg ...string) {}

// A type purely for ActOnTestStructMethod
type TestStruct struct {
Number int
Message string
Number int
Message string
secretMessage string
}

func (s *Subject) ActOnTestStructMethod(arg TestStruct, arg1 int) int {
Expand All @@ -171,7 +176,9 @@ func createFixtures(t *testing.T) (reporter *ErrorReporter, ctrl *gomock.Control
// Controller. We use it to test that the mock considered tests
// successful or failed.
reporter = NewErrorReporter(t)
ctrl = gomock.NewController(reporter)
ctrl = gomock.NewController(
reporter, gomock.WithCmpOpts(cmpopts.IgnoreUnexported(TestStruct{})),
)
return
}

Expand Down Expand Up @@ -298,13 +305,13 @@ func TestUnexpectedArgValue_FirstArg(t *testing.T) {
// the method argument (of TestStruct type) has 1 unexpected value (for the Message field)
ctrl.Call(subject, "ActOnTestStructMethod", TestStruct{Number: 123, Message: "no message"}, 15)
}, "Unexpected call to", "doesn't match the argument at index 0",
"Got: {123 no message} (gomock_test.TestStruct)\nWant: is equal to {123 hello %s} (gomock_test.TestStruct)")
"Diff (-want +got):", "gomock_test.TestStruct{", "Number: 123", "-", "Message: \"hello %s\",", "+", "Message: \"no message\",", "}")

reporter.assertFatal(func() {
// the method argument (of TestStruct type) has 2 unexpected values (for both fields)
ctrl.Call(subject, "ActOnTestStructMethod", TestStruct{Number: 11, Message: "no message"}, 15)
}, "Unexpected call to", "doesn't match the argument at index 0",
"Got: {11 no message} (gomock_test.TestStruct)\nWant: is equal to {123 hello %s} (gomock_test.TestStruct)")
"Diff (-want +got):", "gomock_test.TestStruct{", "-", "Number: 123,", "+", "Number: 11,", "-", "Message: \"hello %s\",", "+", "Message: \"no message\",", "}")

reporter.assertFatal(func() {
// The expected call wasn't made.
Expand All @@ -323,7 +330,7 @@ func TestUnexpectedArgValue_SecondArg(t *testing.T) {
reporter.assertFatal(func() {
ctrl.Call(subject, "ActOnTestStructMethod", TestStruct{Number: 123, Message: "hello"}, 3)
}, "Unexpected call to", "doesn't match the argument at index 1",
"Got: 3 (int)\nWant: is equal to 15 (int)")
"Diff (-want +got):", "int(", "-", "15,", "+", "3,", ")")

reporter.assertFatal(func() {
// The expected call wasn't made.
Expand Down
91 changes: 91 additions & 0 deletions gomock/matchers.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ import (
"reflect"
"regexp"
"strings"

"github.com/google/go-cmp/cmp"
)

// A Matcher is a representation of a class of values.
Expand All @@ -31,6 +33,11 @@ type Matcher interface {
String() string
}

type Differ interface {
// Diff shows the difference between the value and x.
Diff(x interface{}, opts ...cmp.Option) string
}

// WantFormatter modifies the given Matcher's String() method to the given
// Stringer. This allows for control on how the "Want" is formatted when
// printing .
Expand Down Expand Up @@ -94,6 +101,10 @@ func (anyMatcher) Matches(any) bool {
return true
}

func (anyMatcher) Diff(interface{}) string {
return ""
}

func (anyMatcher) String() string {
return "is anything"
}
Expand Down Expand Up @@ -132,6 +143,10 @@ func (e eqMatcher) Matches(x any) bool {
return false
}

func (e eqMatcher) Diff(x interface{}, opts ...cmp.Option) string {
return cmp.Diff(e.x, x, opts...)
}

func (e eqMatcher) String() string {
return fmt.Sprintf("is equal to %s (%T)", getString(e.x), e.x)
}
Expand All @@ -153,6 +168,10 @@ func (nilMatcher) Matches(x any) bool {
return false
}

func (nilMatcher) Diff(x interface{}, opts ...cmp.Option) string {
return cmp.Diff(nil, x, opts...)
}

func (nilMatcher) String() string {
return "is nil"
}
Expand Down Expand Up @@ -196,6 +215,10 @@ func (m assignableToTypeOfMatcher) Matches(x any) bool {
return reflect.TypeOf(x).AssignableTo(m.targetType)
}

func (m assignableToTypeOfMatcher) Diff(x interface{}, opts ...cmp.Option) string {
return cmp.Diff(m.targetType, reflect.TypeOf(x), opts...)
}

func (m assignableToTypeOfMatcher) String() string {
return "is assignable to " + m.targetType.Name()
}
Expand Down Expand Up @@ -234,6 +257,18 @@ func (am allMatcher) Matches(x any) bool {
return true
}

func (am allMatcher) Diff(x interface{}, opts ...cmp.Option) string {
ss := make([]string, 0, len(am.matchers))
for _, matcher := range am.matchers {
if d, ok := matcher.(Differ); ok {
ss = append(ss, d.Diff(x))
} else {
ss = append(ss, matcher.String())
}
}
return strings.Join(ss, "; ")
}

func (am allMatcher) String() string {
ss := make([]string, 0, len(am.matchers))
for _, matcher := range am.matchers {
Expand All @@ -256,6 +291,16 @@ func (m lenMatcher) Matches(x any) bool {
}
}

func (m lenMatcher) Diff(x interface{}, opts ...cmp.Option) string {
v := reflect.ValueOf(x)
switch v.Kind() {
case reflect.Array, reflect.Chan, reflect.Map, reflect.Slice, reflect.String:
return cmp.Diff(m.i, v.Len(), opts...)
default:
return cmp.Diff(m.i, fmt.Sprintf("invalid: len(%T)", x), opts...)
}
}

func (m lenMatcher) String() string {
return fmt.Sprintf("has length %d", m.i)
}
Expand Down Expand Up @@ -310,6 +355,52 @@ func (m inAnyOrderMatcher) Matches(x any) bool {
return extraInGiven == 0 && missingFromWanted == 0
}

func (m inAnyOrderMatcher) Diff(x interface{}, opts ...cmp.Option) string {
given, ok := m.prepareValue(x)
if !ok {
return cmp.Diff(m.x, x, opts...)
}
wanted, ok := m.prepareValue(m.x)
if !ok {
return cmp.Diff(m.x, x, opts...)
}

if given.Len() != wanted.Len() {
return cmp.Diff(m.x, x, opts...)
}

usedFromGiven := make([]bool, given.Len())
foundFromWanted := make([]bool, wanted.Len())
for i := 0; i < wanted.Len(); i++ {
wantedMatcher := Eq(wanted.Index(i).Interface())
for j := 0; j < given.Len(); j++ {
if usedFromGiven[j] {
continue
}
if wantedMatcher.Matches(given.Index(j).Interface()) {
foundFromWanted[i] = true
usedFromGiven[j] = true
break
}
}
}

missingFromWanted := 0
for _, found := range foundFromWanted {
if !found {
missingFromWanted++
}
}
extraInGiven := 0
for _, used := range usedFromGiven {
if !used {
extraInGiven++
}
}

return cmp.Diff(m.x, x, opts...)
}

func (m inAnyOrderMatcher) prepareValue(x any) (reflect.Value, bool) {
xValue := reflect.ValueOf(x)
switch xValue.Kind() {
Expand Down

0 comments on commit 6e34dff

Please sign in to comment.