diff --git a/gomock/matchers.go b/gomock/matchers.go index d0590d0..d52495f 100644 --- a/gomock/matchers.go +++ b/gomock/matchers.go @@ -98,15 +98,19 @@ func (anyMatcher) String() string { return "is anything" } -type condMatcher struct { - fn func(x any) bool +type condMatcher[T any] struct { + fn func(x T) bool } -func (c condMatcher) Matches(x any) bool { - return c.fn(x) +func (c condMatcher[T]) Matches(x any) bool { + typed, ok := x.(T) + if !ok { + return false + } + return c.fn(typed) } -func (condMatcher) String() string { +func (c condMatcher[T]) String() string { return "adheres to a custom condition" } @@ -339,9 +343,9 @@ func Any() Matcher { return anyMatcher{} } // // Example usage: // -// Cond(func(x any){return x.(int) == 1}).Matches(1) // returns true -// Cond(func(x any){return x.(int) == 2}).Matches(1) // returns false -func Cond(fn func(x any) bool) Matcher { return condMatcher{fn} } +// Cond(func(x int){return x == 1}).Matches(1) // returns true +// Cond(func(x int){return x == 2}).Matches(1) // returns false +func Cond[T any](fn func(x T) bool) Matcher { return condMatcher[T]{fn} } // AnyOf returns a composite Matcher that returns true if at least one of the // matchers returns true. diff --git a/gomock/matchers_test.go b/gomock/matchers_test.go index efb03ee..1f5d58c 100644 --- a/gomock/matchers_test.go +++ b/gomock/matchers_test.go @@ -57,7 +57,9 @@ func TestMatchers(t *testing.T) { []e{[]string{"a", "b"}, A{"a", "b"}}, []e{[]string{"a"}, A{"b"}}, }, - {"test Cond", gomock.Cond(func(x any) bool { return x.(B).Name == "Dam" }), []e{B{Name: "Dam"}}, []e{B{Name: "Dave"}}}, + {"test Cond", gomock.Cond(func(x B) bool { return x.Name == "Dam" }), []e{B{Name: "Dam"}}, []e{B{Name: "Dave"}}}, + {"test Cond wrong type", gomock.Cond(func(x B) bool { return x.Name == "Dam" }), []e{B{Name: "Dam"}}, []e{"Dave"}}, + {"test Cond any type", gomock.Cond(func(x any) bool { return x.(B).Name == "Dam" }), []e{B{Name: "Dam"}}, []e{B{Name: "Dave"}}}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/sample/user_test.go b/sample/user_test.go index f963651..ebad72d 100644 --- a/sample/user_test.go +++ b/sample/user_test.go @@ -193,12 +193,8 @@ func TestExpectCondForeignFour(t *testing.T) { defer ctrl.Finish() mockIndex := NewMockIndex(ctrl) - mockIndex.EXPECT().ForeignFour(gomock.Cond(func(x any) bool { - four, ok := x.(imp_four.Imp4) - if !ok { - return false - } - return four.Field == "Cool" + mockIndex.EXPECT().ForeignFour(gomock.Cond(func(x imp_four.Imp4) bool { + return x.Field == "Cool" })) mockIndex.ForeignFour(imp_four.Imp4{Field: "Cool"})