diff --git a/openff/interchange/models.py b/openff/interchange/models.py index 67e9a3c7..cd68240b 100644 --- a/openff/interchange/models.py +++ b/openff/interchange/models.py @@ -45,11 +45,20 @@ class TopologyKey(_BaseModel, abc.ABC): description="The indices of the atoms occupied by this interaction", ) + def _tuple(self) -> tuple[Any, ...]: + """Tuple representation of this key.""" + return tuple(self.atom_indices) + def __hash__(self) -> int: - return hash(tuple(self.atom_indices)) + return hash(self._tuple()) def __eq__(self, other: Any) -> bool: - return self.__hash__() == other.__hash__() + if isinstance(other, tuple): + return self._tuple() == other + elif isinstance(other, TopologyKey): + return self._tuple() == other._tuple() + else: + return NotImplemented def __repr__(self) -> str: return f"{self.__class__.__name__} with atom indices {self.atom_indices}" @@ -73,16 +82,11 @@ class BondKey(TopologyKey): ), ) - def __hash__(self) -> int: + def _tuple(self) -> tuple[int, ...] | tuple[tuple[int, ...], float]: if self.bond_order is None: - return hash(tuple(self.atom_indices)) + return tuple(self.atom_indices) else: - return hash((tuple(self.atom_indices), self.bond_order)) - - def __eq__(self, other) -> bool: - return super().__eq__(other) or ( - self.bond_order is None and other == self.atom_indices - ) + return (tuple(self.atom_indices), float(self.bond_order)) def __repr__(self) -> str: return ( @@ -100,11 +104,8 @@ class AngleKey(TopologyKey): description="The indices of the atoms occupied by this interaction", ) - def __hash__(self) -> int: - return hash(tuple(self.atom_indices)) - - def __eq__(self, other) -> bool: - return super().__eq__(other) or other == self.atom_indices + def _tuple(self) -> tuple[int, ...]: + return tuple(self.atom_indices) class ProperTorsionKey(TopologyKey): @@ -137,18 +138,22 @@ class ProperTorsionKey(TopologyKey): ), ) - def __hash__(self) -> int: - if self.mult is None and self.bond_order is None and self.phase is None: - return hash(tuple(self.atom_indices)) - return hash((tuple(self.atom_indices), self.mult, self.bond_order, self.phase)) - - def __eq__(self, other) -> bool: - return super().__eq__(other) or ( - self.mult is None - and self.bond_order is None - and self.phase is None - and other == tuple(self.atom_indices) - ) + def _tuple( + self, + ) -> ( + tuple[()] + | tuple[int, int, int, int] + | tuple[ + tuple[int, int, int, int] | tuple[()], + int | None, + float | None, + float | None, + ] + ): + if self.mult is None and self.phase is None and self.bond_order is None: + return tuple(self.atom_indices) + else: + return (tuple(self.atom_indices), self.mult, self.phase, self.bond_order) def __repr__(self) -> str: return ( @@ -248,14 +253,12 @@ class VirtualSiteKey(TopologyKey): description="The `match` attribute of the associated virtual site type", ) - def __hash__(self) -> int: - return hash( - ( - self.orientation_atom_indices, - self.name, - self.type, - self.match, - ), + def _tuple(self) -> tuple[tuple[int, ...], str, str, str]: + return ( + self.orientation_atom_indices, + self.name, + self.type, + self.match, )