From 039bf038019d4cc7c4ff1b49eed3eeedb94eddb8 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Mon, 20 Jan 2025 22:25:14 +0000 Subject: [PATCH] Ensured that dtype subclasses are hashable Python data model requires a class to implement both `__eq__` **and** `__hash__` to be considered hashable. If a class only implements `__eq__`, it gets an auto-generated `__hash__ = None`, which makes it non-hashable. So, as things stand, `tl.pointer_type` and `tl.block_type` instances fail to hash. The fix is to define `__hash__` explicitly in the class body, as suggested in [*]: > If a class that overrides `__eq__` needs to retain the implementation of > `__hash__` from a parent class, the interpreter must be told this explicitly > by setting `__hash__ = .__hash__`. [*]: https://docs.python.org/3/reference/datamodel.html#object.__hash__ --- python/test/unit/language/test_core.py | 5 +++++ python/triton/language/core.py | 4 ++++ 2 files changed, 9 insertions(+) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 20a1f8887843..16a5cac29e54 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -513,6 +513,11 @@ def test_dtype_codegen(): assert repr(eval(full_name)) == full_name +@pytest.mark.parameterize("dtype", [tl.pointer_type(tl.int8), tl.block_type(tl.int8, [42])]) +def test_dtype_is_hashable(dtype): + hash(dtype) + + # --------------- # test binary ops # --------------- diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 90c361424036..9d97092bfe64 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -594,6 +594,8 @@ def __eq__(self, other: pointer_type) -> bool: def __ne__(self, other: pointer_type) -> bool: return not self.__eq__(other) + __hash__ = dtype.__hash__ + @property def scalar(self): return self @@ -646,6 +648,8 @@ def __eq__(self, other: block_type) -> bool: def __ne__(self, other: block_type) -> bool: return not self.__eq__(other) + __hash__ = dtype.__hash__ + @property def scalar(self): return self.element_ty