Skip to content

Commit

Permalink
Ensured that dtype subclasses are hashable
Browse files Browse the repository at this point in the history
Python data model requires a class to implement both `__eq__` *and* `__hash__`
to be considered hashable. If a class only implements `__eq__`, it gets an
auto-generated `__hash__ = None`, which makes it non-hashable.

So, as things stand, `tl.pointer_type` and `tl.block_type` instances fail
to hash. The fix is to define __hash__ explicitly in the class body, as
suggested in [*]:

> If a class that overrides __eq__() needs to retain the implementation of
> __hash__() from a parent class, the interpreter must be told this explicitly
> by setting __hash__ = <ParentClass>.__hash__.

[*]: https://docs.python.org/3/reference/datamodel.html#object.__hash__
  • Loading branch information
superbobry committed Jan 20, 2025
1 parent 41ecd1c commit 09eb0eb
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 0 deletions.
8 changes: 8 additions & 0 deletions python/test/unit/language/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,14 @@ def test_dtype_codegen():
assert repr(eval(full_name)) == full_name


@pytest.mark.parameterize("dtype", [
tl.pointer_type(tl.int8),
tl.block_type(tl.int8, [42])
])
def test_dtype_is_hashable(dtype):
hash(dtype)


# ---------------
# test binary ops
# ---------------
Expand Down
4 changes: 4 additions & 0 deletions python/triton/language/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,6 +594,8 @@ def __eq__(self, other: pointer_type) -> bool:
def __ne__(self, other: pointer_type) -> bool:
return not self.__eq__(other)

__hash__ = dtype.__hash__

@property
def scalar(self):
return self
Expand Down Expand Up @@ -646,6 +648,8 @@ def __eq__(self, other: block_type) -> bool:
def __ne__(self, other: block_type) -> bool:
return not self.__eq__(other)

__hash__ = dtype.__hash__

@property
def scalar(self):
return self.element_ty
Expand Down

0 comments on commit 09eb0eb

Please sign in to comment.