diff --git a/assert/assertions.go b/assert/assertions.go index 00d97a51c..29fe3956c 100644 --- a/assert/assertions.go +++ b/assert/assertions.go @@ -1793,3 +1793,15 @@ func buildErrorChainString(err error) string { } return chain } + +// Kind check if the given object is of the given type +func Kind(t TestingT, expected reflect.Kind, object interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + + if reflect.TypeOf(object).Kind() == expected { + return true + } + return false +} diff --git a/assert/assertions_test.go b/assert/assertions_test.go index 3bd418d91..72bbee6e0 100644 --- a/assert/assertions_test.go +++ b/assert/assertions_test.go @@ -198,10 +198,14 @@ func TestEqual(t *testing.T) { {uint64(123), uint64(123), true, ""}, {myType("1"), myType("1"), true, ""}, {&struct{}{}, &struct{}{}, true, "pointer equality is based on equality of underlying value"}, + {time.Date(2020, 3, 1, 12, 23, 14, 0, time.UTC), time.Date(2020, 3, 1, 12, 23, 14, 0, time.UTC), true, ""}, + {time.Date(2020, 6, 3, 12, 23, 14, 0, time.UTC), time.Date(2020, 6, 3, 12, 23, 14, 0, time.UTC), true, ""}, // Not expected to be equal {m["bar"], "something", false, ""}, {myType("1"), myType("2"), false, ""}, + {time.Date(2020, 3, 1, 12, 23, 14, 0, time.UTC), time.Date(2020, 5, 1, 12, 23, 14, 0, time.UTC), true, ""}, + {time.Date(2020, 6, 3, 12, 23, 14, 0, time.UTC), time.Date(2020, 8, 3, 12, 23, 14, 0, time.UTC), true, ""}, // A case that might be confusing, especially with numeric literals {10, uint(10), false, ""}, @@ -2466,3 +2470,42 @@ func TestErrorAs(t *testing.T) { }) } } + +func TestKind(t *testing.T) { + type myType string + + mockT := new(testing.T) + a := 12 + + cases := []struct { + expected reflect.Kind + object interface{} + result bool + remark string + }{ + {reflect.String, "Hello World", true, "1"}, + {reflect.Int, 123, true, "2"}, + {reflect.Array, [6]int{2, 3, 5, 7, 11, 13}, true, "3"}, + {reflect.Func, Kind, true, "4"}, + {reflect.Float64, 0.0345, true, "5"}, + {reflect.Map, make(map[string]int), true, "6"}, + {reflect.Bool, true, true, "7"}, + {reflect.Ptr, &a, true, "8"}, + + // Not expected to be equal + {reflect.String, 13, false, "9"}, + {reflect.Int, [6]int{2, 3, 5, 7, 11, 13}, false, "10"}, + {reflect.Float64, 12, false, "11"}, + {reflect.Bool, make(map[string]int), false, "12"}, + } + + for _, c := range cases { + t.Run(fmt.Sprintf("Kind(%#v, %#v)", c.expected, c.object), func(t *testing.T) { + res := Kind(mockT, c.expected, c.object) + + if res != c.result { + t.Errorf("Kind(%#v, %#v) should return %#v: %s", c.expected, c.object, c.result, c.remark) + } + }) + } +}