diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 20a1f88878437..8a3dc3c3a2a58 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -513,6 +513,14 @@ def test_dtype_codegen(): assert repr(eval(full_name)) == full_name +@pytest.mark.parameterize("dtype", [ + tl.pointer_type(tl.int8), + tl.block_type(tl.int8, [42]) +]) +def test_dtype_is_hashable(dtype): + hash(dtype) + + # --------------- # test binary ops # --------------- diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 90c361424036d..9d97092bfe647 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -594,6 +594,8 @@ def __eq__(self, other: pointer_type) -> bool: def __ne__(self, other: pointer_type) -> bool: return not self.__eq__(other) + __hash__ = dtype.__hash__ + @property def scalar(self): return self @@ -646,6 +648,8 @@ def __eq__(self, other: block_type) -> bool: def __ne__(self, other: block_type) -> bool: return not self.__eq__(other) + __hash__ = dtype.__hash__ + @property def scalar(self): return self.element_ty