diff --git a/gomock/callset.go b/gomock/callset.go index 5649c37..a8eb9b5 100644 --- a/gomock/callset.go +++ b/gomock/callset.go @@ -18,13 +18,15 @@ import ( "bytes" "errors" "fmt" + "sync" ) // callSet represents a set of expected calls, indexed by receiver and method // name. type callSet struct { // Calls that are still expected. - expected map[callSetKey][]*Call + expected map[callSetKey][]*Call + expectedMu *sync.Mutex // Calls that have been exhausted. exhausted map[callSetKey][]*Call } @@ -36,12 +38,20 @@ type callSetKey struct { } func newCallSet() *callSet { - return &callSet{make(map[callSetKey][]*Call), make(map[callSetKey][]*Call)} + return &callSet{ + expected: make(map[callSetKey][]*Call), + expectedMu: &sync.Mutex{}, + exhausted: make(map[callSetKey][]*Call), + } } // Add adds a new expected call. func (cs callSet) Add(call *Call) { key := callSetKey{call.receiver, call.method} + + cs.expectedMu.Lock() + defer cs.expectedMu.Unlock() + m := cs.expected if call.exhausted() { m = cs.exhausted @@ -52,6 +62,10 @@ func (cs callSet) Add(call *Call) { // Remove removes an expected call. func (cs callSet) Remove(call *Call) { key := callSetKey{call.receiver, call.method} + + cs.expectedMu.Lock() + defer cs.expectedMu.Unlock() + calls := cs.expected[key] for i, c := range calls { if c == call { @@ -67,6 +81,9 @@ func (cs callSet) Remove(call *Call) { func (cs callSet) FindMatch(receiver interface{}, method string, args []interface{}) (*Call, error) { key := callSetKey{receiver, method} + cs.expectedMu.Lock() + defer cs.expectedMu.Unlock() + // Search through the expected calls. expected := cs.expected[key] var callsErrors bytes.Buffer @@ -101,6 +118,9 @@ func (cs callSet) FindMatch(receiver interface{}, method string, args []interfac // Failures returns the calls that are not satisfied. func (cs callSet) Failures() []*Call { + cs.expectedMu.Lock() + defer cs.expectedMu.Unlock() + failures := make([]*Call, 0, len(cs.expected)) for _, calls := range cs.expected { for _, call := range calls { @@ -114,6 +134,9 @@ func (cs callSet) Failures() []*Call { // AllExpectedCallsSatisfied returns true in case all expected calls in this callSet are satisfied. func (cs callSet) AllExpectedCallsSatisfied() bool { + cs.expectedMu.Lock() + defer cs.expectedMu.Unlock() + for _, calls := range cs.expected { for _, call := range calls { if !call.satisfied() { diff --git a/gomock/callset_test.go b/gomock/callset_test.go index fe053af..c69c86a 100644 --- a/gomock/callset_test.go +++ b/gomock/callset_test.go @@ -77,7 +77,7 @@ func TestCallSetRemove(t *testing.T) { func TestCallSetFindMatch(t *testing.T) { t.Run("call is exhausted", func(t *testing.T) { - cs := callSet{} + cs := newCallSet() var receiver interface{} = "TestReceiver" method := "TestMethod" args := []interface{}{}