Skip to content

Commit

Permalink
refactor: make expected map safe for concurrent usage
Browse files Browse the repository at this point in the history
  • Loading branch information
FlorianLoch committed Jun 29, 2023
1 parent ecf493f commit 017150d
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 3 deletions.
27 changes: 25 additions & 2 deletions gomock/callset.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,15 @@ import (
"bytes"
"errors"
"fmt"
"sync"
)

// callSet represents a set of expected calls, indexed by receiver and method
// name.
type callSet struct {
// Calls that are still expected.
expected map[callSetKey][]*Call
expected map[callSetKey][]*Call
expectedMu *sync.Mutex
// Calls that have been exhausted.
exhausted map[callSetKey][]*Call
}
Expand All @@ -36,12 +38,20 @@ type callSetKey struct {
}

func newCallSet() *callSet {
return &callSet{make(map[callSetKey][]*Call), make(map[callSetKey][]*Call)}
return &callSet{
expected: make(map[callSetKey][]*Call),
expectedMu: &sync.Mutex{},
exhausted: make(map[callSetKey][]*Call),
}
}

// Add adds a new expected call.
func (cs callSet) Add(call *Call) {
key := callSetKey{call.receiver, call.method}

cs.expectedMu.Lock()
defer cs.expectedMu.Unlock()

m := cs.expected
if call.exhausted() {
m = cs.exhausted
Expand All @@ -52,6 +62,10 @@ func (cs callSet) Add(call *Call) {
// Remove removes an expected call.
func (cs callSet) Remove(call *Call) {
key := callSetKey{call.receiver, call.method}

cs.expectedMu.Lock()
defer cs.expectedMu.Unlock()

calls := cs.expected[key]
for i, c := range calls {
if c == call {
Expand All @@ -67,6 +81,9 @@ func (cs callSet) Remove(call *Call) {
func (cs callSet) FindMatch(receiver interface{}, method string, args []interface{}) (*Call, error) {
key := callSetKey{receiver, method}

cs.expectedMu.Lock()
defer cs.expectedMu.Unlock()

// Search through the expected calls.
expected := cs.expected[key]
var callsErrors bytes.Buffer
Expand Down Expand Up @@ -101,6 +118,9 @@ func (cs callSet) FindMatch(receiver interface{}, method string, args []interfac

// Failures returns the calls that are not satisfied.
func (cs callSet) Failures() []*Call {
cs.expectedMu.Lock()
defer cs.expectedMu.Unlock()

failures := make([]*Call, 0, len(cs.expected))
for _, calls := range cs.expected {
for _, call := range calls {
Expand All @@ -114,6 +134,9 @@ func (cs callSet) Failures() []*Call {

// AllExpectedCallsSatisfied returns true in case all expected calls in this callSet are satisfied.
func (cs callSet) AllExpectedCallsSatisfied() bool {
cs.expectedMu.Lock()
defer cs.expectedMu.Unlock()

for _, calls := range cs.expected {
for _, call := range calls {
if !call.satisfied() {
Expand Down
2 changes: 1 addition & 1 deletion gomock/callset_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ func TestCallSetRemove(t *testing.T) {

func TestCallSetFindMatch(t *testing.T) {
t.Run("call is exhausted", func(t *testing.T) {
cs := callSet{}
cs := newCallSet()
var receiver interface{} = "TestReceiver"
method := "TestMethod"
args := []interface{}{}
Expand Down

0 comments on commit 017150d

Please sign in to comment.