diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 20a1f8887843..16a5cac29e54 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -513,6 +513,11 @@ def test_dtype_codegen(): assert repr(eval(full_name)) == full_name +@pytest.mark.parameterize("dtype", [tl.pointer_type(tl.int8), tl.block_type(tl.int8, [42])]) +def test_dtype_is_hashable(dtype): + hash(dtype) + + # --------------- # test binary ops # --------------- diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 90c361424036..9d97092bfe64 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -594,6 +594,8 @@ def __eq__(self, other: pointer_type) -> bool: def __ne__(self, other: pointer_type) -> bool: return not self.__eq__(other) + __hash__ = dtype.__hash__ + @property def scalar(self): return self @@ -646,6 +648,8 @@ def __eq__(self, other: block_type) -> bool: def __ne__(self, other: block_type) -> bool: return not self.__eq__(other) + __hash__ = dtype.__hash__ + @property def scalar(self): return self.element_ty