diff --git a/mock/mock.go b/mock/mock.go index c95eeeca8..aabc025fe 100644 --- a/mock/mock.go +++ b/mock/mock.go @@ -70,6 +70,9 @@ type Call struct { // decoders. RunFn func(Arguments) + // Holds a handler to a function that will be called before returning. + returnFn func(Arguments) Arguments + // PanicMsg holds msg to be used to mock panic on the function call // if the PanicMsg is set to a non nil string the function call will panic // irrespective of other settings @@ -109,6 +112,7 @@ func (c *Call) Return(returnArguments ...interface{}) *Call { defer c.unlock() c.ReturnArguments = returnArguments + c.returnFn = nil return c } @@ -186,6 +190,19 @@ func (c *Call) Run(fn func(args Arguments)) *Call { return c } +// ReturnFn sets a handler to be called before returning. +// +// Mock.On("MyMethod", arg1, arg2).ReturnFn(func(args Arguments) Arguments { +// return Arguments{args.Get(0) + args.Get(1)} +// }) +func (c *Call) ReturnFn(fn func(args Arguments) Arguments) *Call { + c.lock() + defer c.unlock() + c.returnFn = fn + c.ReturnArguments = nil + return c +} + // Maybe allows the method call to be optional. Not calling an optional method // will not cause an error while asserting expectations func (c *Call) Maybe() *Call { @@ -583,6 +600,14 @@ func (m *Mock) MethodCalled(methodName string, arguments ...interface{}) Argumen returnArgs := call.ReturnArguments m.mutex.Unlock() + m.mutex.Lock() + returnFn := call.returnFn + m.mutex.Unlock() + + if returnFn != nil { + returnArgs = returnFn(arguments) + } + return returnArgs } diff --git a/mock/mock_test.go b/mock/mock_test.go index 3dc9e0b1e..0358ec8a2 100644 --- a/mock/mock_test.go +++ b/mock/mock_test.go @@ -31,7 +31,7 @@ type TestExampleImplementation struct { func (i *TestExampleImplementation) TheExampleMethod(a, b, c int) (int, error) { args := i.Called(a, b, c) - return args.Int(0), errors.New("Whoops") + return args.Int(0), args.Error(1) } type options struct { @@ -889,6 +889,148 @@ func Test_Mock_Return_Run_Out_Of_Order(t *testing.T) { assert.NotNil(t, call.Run) } +func Test_Mock_ReturnFn(t *testing.T) { + + // make a test impl object + var mockedService = new(TestExampleImplementation) + + t.Run("can dynamically set the return values", func(t *testing.T) { + counter := 0 + mockedService.On("TheExampleMethod", Anything, Anything, Anything). + ReturnFn(func(args Arguments) Arguments { + counter++ + a, b, c := args[0].(int), args[1].(int), args[2].(int) + assert.IsType(t, 1, a) + assert.IsType(t, 1, b) + assert.IsType(t, 1, c) + return Arguments{a + b + c, nil} + }). + Twice() + + answer, err := mockedService.TheExampleMethod(2, 4, 5) + assert.NoError(t, err) + assert.Equal(t, 11, answer) + assert.Equal(t, 1, counter) + + answer, err = mockedService.TheExampleMethod(44, 4, 5) + assert.NoError(t, err) + assert.Equal(t, 53, answer) + assert.Equal(t, 2, counter) + }) + + t.Run("handles func(Args) Args style", func(t *testing.T) { + mockedService.On("TheExampleMethod", Anything, Anything, Anything). + ReturnFn(func(args Arguments) Arguments { + return []interface{}{args[0].(int) + 40, fmt.Errorf("hmm")} + }). + Twice() + + answer, err := mockedService.TheExampleMethod(2, 4, 5) + assert.Error(t, err, "hmm") + assert.Equal(t, 42, answer) + + answer, err = mockedService.TheExampleMethod(44, 4, 5) + assert.Error(t, err, "hmm") + assert.Equal(t, 84, answer) + }) + + t.Run("handles pointer input args", func(t *testing.T) { + mockedService.On("TheExampleMethod3", Anything).ReturnFn(func(arguments Arguments) Arguments { + et := arguments[0].(*ExampleType) + if et == nil { + return Arguments{errors.New("error")} + } + return Arguments{nil} + }).Twice() + + err := mockedService.TheExampleMethod3(nil) + assert.Error(t, err) + + err = mockedService.TheExampleMethod3(&ExampleType{}) + assert.NoError(t, err) + }) + + t.Run("handles variadic input args", func(t *testing.T) { + mockedService. + On("TheExampleMethodMixedVariadic", Anything, Anything). + ReturnFn(func(args Arguments) Arguments { + a, b := args[0].(int), args[1].([]int) + var sum = a + for _, v := range b { + sum += v + } + return Arguments{fmt.Errorf("%v", sum)} + }) + + assert.Equal(t, "42", mockedService.TheExampleMethodMixedVariadic(40, 1, 1).Error()) + assert.Equal(t, "40", mockedService.TheExampleMethodMixedVariadic(40).Error()) + }) + + t.Run("allows all of Run and RunWithReturn and Return to be used", func(t *testing.T) { + mockedService.On("TheExampleMethod", Anything, Anything, Anything). + Run(func(args Arguments) { + a := args[0].(int) + assert.IsType(t, 1, a) + }). + ReturnFn(func(args Arguments) Arguments { + a := args[0].(int) + return Arguments{a + 40, fmt.Errorf("hmm")} + }). + Return(80, nil) + + answer, err := mockedService.TheExampleMethod(2, 4, 5) + assert.Equal(t, 80, answer) + assert.NoError(t, err) + }) +} + +func Test_Mock_Return_RespectOrder(t *testing.T) { + tests := []struct { + name string + arrange func() *TestExampleImplementation + expected int + }{ + { + name: "should take the last return value", + arrange: func() *TestExampleImplementation { + m := new(TestExampleImplementation) + m.On("TheExampleMethod", Anything, Anything, Anything).Return(1, nil).Return(2, nil) + return m + }, + expected: 2, + }, + { + name: "should take the last return value with returnFn", + arrange: func() *TestExampleImplementation { + m := new(TestExampleImplementation) + m.On("TheExampleMethod", Anything, Anything, Anything).Return(1, nil).ReturnFn(func(args Arguments) Arguments { return Arguments{2, nil} }) + return m + }, + expected: 2, + }, + { + name: "should take the last return value with returnFn and return", + arrange: func() *TestExampleImplementation { + m := new(TestExampleImplementation) + m.On("TheExampleMethod", Anything, Anything, Anything).ReturnFn(func(args Arguments) Arguments { + return Arguments{1, nil} + }).Return(2, nil) + return m + }, + expected: 2, + }, + } + // run the tests + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + m := test.arrange() + actualResult, actualError := m.TheExampleMethod(0, 0, 0) + assert.NoError(t, actualError) + assert.Equal(t, test.expected, actualResult) + }) + } +} + func Test_Mock_Return_Once(t *testing.T) { t.Parallel() @@ -1426,7 +1568,7 @@ func Test_Mock_Called_For_SetTime_Expectation(t *testing.T) { var mockedService = new(TestExampleImplementation) - mockedService.On("TheExampleMethod", 1, 2, 3).Return(5, "6", true).Times(4) + mockedService.On("TheExampleMethod", 1, 2, 3).Return(5, nil).Times(4) mockedService.TheExampleMethod(1, 2, 3) mockedService.TheExampleMethod(1, 2, 3)