Skip to content

Commit

Permalink
Implement extended tuple TopologyKey equality and make equality testi…
Browse files Browse the repository at this point in the history
…ng more robust
  • Loading branch information
Yoshanuikabundi committed Jul 12, 2024
1 parent 4d3d81a commit 28d2c60
Showing 1 changed file with 38 additions and 35 deletions.
73 changes: 38 additions & 35 deletions openff/interchange/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,20 @@ class TopologyKey(_BaseModel, abc.ABC):
description="The indices of the atoms occupied by this interaction",
)

def _tuple(self) -> tuple[Any, ...]:
"""Tuple representation of this key."""
return tuple(self.atom_indices)

def __hash__(self) -> int:
return hash(tuple(self.atom_indices))
return hash(self._tuple())

def __eq__(self, other: Any) -> bool:
return self.__hash__() == other.__hash__()
if isinstance(other, tuple):
return self._tuple() == other
elif isinstance(other, TopologyKey):
return self._tuple() == other._tuple()
else:
return NotImplemented

def __repr__(self) -> str:
return f"{self.__class__.__name__} with atom indices {self.atom_indices}"
Expand All @@ -73,16 +82,11 @@ class BondKey(TopologyKey):
),
)

def __hash__(self) -> int:
def _tuple(self) -> tuple[int, ...] | tuple[tuple[int, ...], float]:
if self.bond_order is None:
return hash(tuple(self.atom_indices))
return tuple(self.atom_indices)
else:
return hash((tuple(self.atom_indices), self.bond_order))

def __eq__(self, other) -> bool:
return super().__eq__(other) or (
self.bond_order is None and other == self.atom_indices
)
return (tuple(self.atom_indices), float(self.bond_order))

def __repr__(self) -> str:
return (
Expand All @@ -100,11 +104,8 @@ class AngleKey(TopologyKey):
description="The indices of the atoms occupied by this interaction",
)

def __hash__(self) -> int:
return hash(tuple(self.atom_indices))

def __eq__(self, other) -> bool:
return super().__eq__(other) or other == self.atom_indices
def _tuple(self) -> tuple[int, ...]:
return tuple(self.atom_indices)


class ProperTorsionKey(TopologyKey):
Expand Down Expand Up @@ -137,18 +138,22 @@ class ProperTorsionKey(TopologyKey):
),
)

def __hash__(self) -> int:
if self.mult is None and self.bond_order is None and self.phase is None:
return hash(tuple(self.atom_indices))
return hash((tuple(self.atom_indices), self.mult, self.bond_order, self.phase))

def __eq__(self, other) -> bool:
return super().__eq__(other) or (
self.mult is None
and self.bond_order is None
and self.phase is None
and other == tuple(self.atom_indices)
)
def _tuple(
self,
) -> (
tuple[()]
| tuple[int, int, int, int]
| tuple[
tuple[int, int, int, int] | tuple[()],
int | None,
float | None,
float | None,
]
):
if self.mult is None and self.phase is None and self.bond_order is None:
return tuple(self.atom_indices)
else:
return (tuple(self.atom_indices), self.mult, self.phase, self.bond_order)

def __repr__(self) -> str:
return (
Expand Down Expand Up @@ -248,14 +253,12 @@ class VirtualSiteKey(TopologyKey):
description="The `match` attribute of the associated virtual site type",
)

def __hash__(self) -> int:
return hash(
(
self.orientation_atom_indices,
self.name,
self.type,
self.match,
),
def _tuple(self) -> tuple[tuple[int, ...], str, str, str]:
return (
self.orientation_atom_indices,
self.name,
self.type,
self.match,
)


Expand Down

0 comments on commit 28d2c60

Please sign in to comment.