From 09eb0eb0156e3234e9cdf9df44071fcede8d05da Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Mon, 20 Jan 2025 22:25:14 +0000 Subject: [PATCH] Ensured that dtype subclasses are hashable Python data model requires a class to implement both `__eq__` *and* `__hash__` to be considered hashable. If a class only implements `__eq__`, it gets an auto-generated `__hash__ = None`, which makes it non-hashable. So, as things stand, `tl.pointer_type` and `tl.block_type` instances fail to hash. The fix is to define __hash__ explicitly in the class body, as suggested in [*]: > If a class that overrides __eq__() needs to retain the implementation of > __hash__() from a parent class, the interpreter must be told this explicitly > by setting __hash__ = .__hash__. [*]: https://docs.python.org/3/reference/datamodel.html#object.__hash__ --- python/test/unit/language/test_core.py | 8 ++++++++ python/triton/language/core.py | 4 ++++ 2 files changed, 12 insertions(+) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 20a1f88878437..8a3dc3c3a2a58 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -513,6 +513,14 @@ def test_dtype_codegen(): assert repr(eval(full_name)) == full_name +@pytest.mark.parameterize("dtype", [ + tl.pointer_type(tl.int8), + tl.block_type(tl.int8, [42]) +]) +def test_dtype_is_hashable(dtype): + hash(dtype) + + # --------------- # test binary ops # --------------- diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 90c361424036d..9d97092bfe647 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -594,6 +594,8 @@ def __eq__(self, other: pointer_type) -> bool: def __ne__(self, other: pointer_type) -> bool: return not self.__eq__(other) + __hash__ = dtype.__hash__ + @property def scalar(self): return self @@ -646,6 +648,8 @@ def __eq__(self, other: block_type) -> bool: def __ne__(self, other: block_type) -> bool: return not self.__eq__(other) + __hash__ = dtype.__hash__ + @property def scalar(self): return self.element_ty