Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Collection ergonomics #990

Merged
merged 13 commits into from
Jul 15, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,11 @@ md.log
*stderr.txt
*stdout.txt

# LAMMPS
log.lammps
out.lmp
tmp.in

# OS
.DS_Store
.DS_Store?
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,14 @@ def test_getitem(self, sage):
with pytest.raises(LookupError, match="Could not find"):
out["CMAPs"]

first_bondkey = next(iter(out["Bonds"].key_map))
idx_a, idx_b = first_bondkey.atom_indices
assert (
out["Bonds"][idx_a, idx_b]
== out["Bonds"][idx_b, idx_a]
== out["Bonds"].potentials[out["Bonds"].key_map[first_bondkey]]
)

def test_get_parameters(self, sage):
mol = Molecule.from_smiles("CCO")
out = Interchange.from_smirnoff(force_field=sage, topology=[mol])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ def generate_v_site_coordinates(
(0, 1, 2, 3),
(
VirtualSiteMocking.sp2_conformer()[0]
+ Quantity( # noqa
+ Quantity(
numpy.array(
[[1.0, numpy.sqrt(2), 1.0], [1.0, -numpy.sqrt(2), -1.0]],
),
Expand Down
59 changes: 59 additions & 0 deletions openff/interchange/_tests/unit_tests/test_models.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from openff.interchange.models import (
AngleKey,
BondKey,
ImproperTorsionKey,
PotentialKey,
Expand Down Expand Up @@ -110,3 +111,61 @@ def test_reprs():
assert "blah" in repr(potential_key)
assert "mult 2" in repr(potential_key)
assert "bond order 1.111" in repr(potential_key)


def test_bondkey_eq_hash():
    """
    Check that a BondKey hashes like every tuple it compares equal to.

    Python only requires that equal objects share a hash. Unequal objects
    may collide; such collisions are resolved through __eq__, possibly
    with a small runtime cost.
    """
    unordered = BondKey(atom_indices=(1, 3))
    explicit_none = BondKey(atom_indices=(1, 3), bond_order=None)
    fractional = BondKey(atom_indices=(1, 3), bond_order=1.5)

    # No bond order given: the key behaves like its bare index tuple.
    assert unordered == (1, 3)
    assert hash(unordered) == hash((1, 3))
    for mismatch in ((1, 4), (3, 1)):
        assert unordered != mismatch

    # An explicit bond_order=None is treated the same as omitting it.
    assert explicit_none == (1, 3)
    assert hash(explicit_none) == hash((1, 3))
    assert explicit_none != ((1, 3), None)

    # With a bond order set, neither plain nor extended tuples compare equal.
    for mismatch in ((1, 3), ((1, 3), None), ((1, 3), 1.5)):
        assert fractional != mismatch
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm pretty sure these two should be equal?

    def __hash__(self) -> int:
        if self.bond_order is None:
            return hash(tuple(self.atom_indices))
        return hash((tuple(self.atom_indices), self.bond_order))

Copy link
Collaborator Author

@Yoshanuikabundi Yoshanuikabundi Jul 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe this test is correct - lines 131, 133 and 134 all compare against extended tuples, which have not been implemented (as we've discussed elsewhere), and in line 133 the bond order is not None, so the tuple doesn't match.

The second return line in that hash function is the existing behaviour - BondKey.__eq__ is

def __eq__(self, other) -> bool:
    return super().__eq__(other) or (
        self.bond_order is None and other == self.atom_indices
    )

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ahhh, I see what's going on: `TopologyKey.__eq__` was changed to compare hashes. I don't think this is technically correct in general, as you might get hash collisions - Python uses a 64-bit hash, so there are 2^64 possible hashes, and a typical simulation system might have 2^16 atoms (~65k), so if you need four of them to define a torsion then you might have (2^16)^4 = 2^64 possible atom index tuples... which is really filling up the pigeonholes. Granted, torsions tend to be defined between nearby atoms, but that's also assuming that Python's hash function is doing a good job of avoiding collisions, and it's had pathological cases in the past.

A better way might be to have a `_tuple()` method that gives the equivalent tuple, and then define `__hash__` as `return hash(self._tuple())` and `__eq__` as something like `return self._tuple() == other._tuple() or self._tuple() == other`.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comparing hashes in __eq__ can also run into issues when comparing objects of different types.



def test_anglekey_eq_hash():
    """
    Check that an AngleKey hashes like the tuple it compares equal to.

    Python only requires that equal objects share a hash. Unequal objects
    may collide; such collisions are resolved through __eq__, possibly
    with a small runtime cost.
    """
    angle = AngleKey(atom_indices=(1, 3, 16))

    assert angle == (1, 3, 16)
    assert hash(angle) == hash((1, 3, 16))

    # Reversed order, a shorter tuple, and a different index must all differ.
    for mismatch in ((16, 3, 1), (1, 3), (1, 3, 15)):
        assert angle != mismatch


def test_torsionkey_eq_hash():
    """
    Check that a ProperTorsionKey hashes like the tuple it compares equal to.

    Python only requires that equal objects share a hash. Unequal objects
    may collide; such collisions are resolved through __eq__, possibly
    with a small runtime cost.
    """
    indices = (1, 2, 3, 4)
    bare = ProperTorsionKey(atom_indices=indices)

    assert bare == (1, 2, 3, 4)
    assert hash(bare) == hash((1, 2, 3, 4))
    assert bare != (4, 3, 2, 1)

    # Spelling out every optional field as None changes nothing.
    all_none = ProperTorsionKey(
        atom_indices=indices,
        mult=None,
        phase=None,
        bond_order=None,
    )
    assert all_none == (1, 2, 3, 4)

    # Setting any one optional field stops the key matching the bare tuple.
    assert ProperTorsionKey(atom_indices=indices, mult=0) != (1, 2, 3, 4)
    assert ProperTorsionKey(atom_indices=indices, phase=1.5) != (1, 2, 3, 4)
    assert ProperTorsionKey(atom_indices=indices, bond_order=1.5) != (1, 2, 3, 4)
8 changes: 4 additions & 4 deletions openff/interchange/components/_packmol.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,17 +134,17 @@ def _validate_inputs(
"""
if (
box_vectors is None
and mass_density is None # noqa: W503
and (solute is None or solute.box_vectors is None) # noqa: W503
and mass_density is None
and (solute is None or solute.box_vectors is None)
):
raise PACKMOLValueError(
"One of `box_vectors`, `mass_density`, or"
+ " `solute.box_vectors` must be specified.", # noqa: W503
+ " `solute.box_vectors` must be specified.",
)
if box_vectors is not None and mass_density is not None:
raise PACKMOLValueError(
"`box_vectors` and `mass_density` cannot be specified together;"
+ " choose one or the other.", # noqa: W503
+ " choose one or the other.",
)

if box_vectors is not None and box_vectors.shape != (3, 3):
Expand Down
10 changes: 10 additions & 0 deletions openff/interchange/components/potentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,3 +316,13 @@ def __getattr__(self, attr: str):
return self.key_map
else:
return super().__getattribute__(attr)

def __getitem__(self, key) -> Potential:
    """
    Return the potential that `key` maps to.

    A tuple key that is not itself in `key_map` is retried in reversed
    order; if the reversed tuple is present, its potential is returned.
    Any other key is looked up directly, so a missing key raises from the
    `key_map` lookup.
    """
    lookup_key = key

    if isinstance(key, tuple) and key not in self.key_map:
        flipped = tuple(reversed(key))
        if flipped in self.key_map:
            lookup_key = flipped

    return self.potentials[self.key_map[lookup_key]]
2 changes: 1 addition & 1 deletion openff/interchange/interop/amber/export/_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -640,7 +640,7 @@ def to_prmtop(interchange: "Interchange", file_path: Path | str):
prmtop.write("%FLAG ANGLE_FORCE_CONSTANT\n" "%FORMAT(5E16.8)\n")
angle_k = [
interchange["Angles"].potentials[key].parameters["k"].m_as(kcal_mol_rad2)
/ 2 # noqa
/ 2
for key in potential_key_to_angle_type_mapping
]
text_blob = "".join([f"{val:16.8E}" for val in angle_k])
Expand Down
4 changes: 2 additions & 2 deletions openff/interchange/interop/openmm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,8 +204,8 @@ def _is_water(molecule: Molecule) -> bool:
# TODO: This should only skip rigid waters, even though HMR or flexible water is questionable
if (
(hydrogen_atom.atomic_number == 1)
and (heavy_atom.atomic_number != 1) # noqa: W503
and not (_is_water(hydrogen_atom.molecule)) # noqa: W503
and (heavy_atom.atomic_number != 1)
and not (_is_water(hydrogen_atom.molecule))
):

hydrogen_index = interchange.topology.atom_index(hydrogen_atom)
Expand Down
31 changes: 27 additions & 4 deletions openff/interchange/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ class BondKey(TopologyKey):
A unique identifier of the atoms associated in a bond potential.
"""

atom_indices: tuple[int, ...] = Field(
atom_indices: tuple[int, int] = Field(
description="The indices of the atoms occupied by this interaction",
)

Expand All @@ -71,8 +71,15 @@ class BondKey(TopologyKey):
)

def __hash__(self) -> int:
    """Hash as the bare index tuple when no bond order is set, else as (indices, bond_order)."""
    indices = tuple(self.atom_indices)
    return (
        hash(indices)
        if self.bond_order is None
        else hash((indices, self.bond_order))
    )

def __eq__(self, other) -> bool:
    """Equal if the parent comparison succeeds, or if no bond order is set and `other` equals the atom indices."""
    parent_result = super().__eq__(other)
    if parent_result:
        return parent_result

    # Fall back to comparing against the raw indices, but only for keys without a bond order.
    return self.bond_order is None and other == self.atom_indices

def __repr__(self) -> str:
return (
f"{self.__class__.__name__} with atom indices {self.atom_indices}"
Expand All @@ -85,17 +92,20 @@ class AngleKey(TopologyKey):
A unique identifier of the atoms associated in an angle potential.
"""

atom_indices: tuple[int, ...] = Field(
atom_indices: tuple[int, int, int] = Field(
description="The indices of the atoms occupied by this interaction",
)

def __eq__(self, other) -> bool:
    """Equal if the parent comparison succeeds, or if `other` equals the atom indices."""
    parent_result = super().__eq__(other)
    if parent_result:
        return parent_result

    return other == self.atom_indices


class ProperTorsionKey(TopologyKey):
"""
A unique identifier of the atoms associated in a proper torsion potential.
"""

atom_indices: tuple[int, ...] = Field(
atom_indices: tuple[int, int, int, int] | tuple[()] = Field(
description="The indices of the atoms occupied by this interaction",
)

Expand All @@ -121,8 +131,18 @@ class ProperTorsionKey(TopologyKey):
)

def __hash__(self) -> int:
    """Hash as the bare index tuple when mult, bond_order and phase are all None, else include them."""
    indices = tuple(self.atom_indices)
    has_extras = not (
        self.mult is None and self.bond_order is None and self.phase is None
    )

    if has_extras:
        return hash((indices, self.mult, self.bond_order, self.phase))
    return hash(indices)

def __eq__(self, other) -> bool:
    """Equal if the parent comparison succeeds, or if no optional field is set and `other` equals the index tuple."""
    parent_result = super().__eq__(other)
    if parent_result:
        return parent_result

    # A key carrying mult, bond order or phase never matches a bare tuple.
    if (
        self.mult is not None
        or self.bond_order is not None
        or self.phase is not None
    ):
        return False

    return other == tuple(self.atom_indices)

def __repr__(self) -> str:
return (
f"{self.__class__.__name__} with atom indices {self.atom_indices}"
Expand Down Expand Up @@ -154,13 +174,16 @@ class LibraryChargeTopologyKey(DefaultModel):
this_atom_index: int

@property
def atom_indices(self) -> tuple[int, ...]:
def atom_indices(self) -> tuple[int]:
"""Alias for `this_atom_index`."""
return (self.this_atom_index,)

def __hash__(self) -> int:
    """Hash as the one-element tuple holding `this_atom_index`."""
    single_index = (self.this_atom_index,)
    return hash(single_index)

def __eq__(self, other) -> bool:
    """
    Compare against another key, a bare atom index, or a one-element tuple.

    This class's __hash__ hashes the one-element tuple
    ``(this_atom_index,)``, so the tuple form is the one that satisfies
    "equal objects hash equally" and can be used for hash-based lookup.
    The bare-int comparison is kept for backward compatibility only.

    NOTE(review): a bare int generally does not hash like the one-element
    tuple, so equality with an int still cannot drive a dict lookup.
    """
    parent_result = super().__eq__(other)
    if parent_result:
        return parent_result

    return other == self.this_atom_index or other == (self.this_atom_index,)


class SingleAtomChargeTopologyKey(LibraryChargeTopologyKey):
"""
Expand Down
4 changes: 2 additions & 2 deletions openff/interchange/smirnoff/_gromacs.py
Original file line number Diff line number Diff line change
Expand Up @@ -758,8 +758,8 @@ def _is_water(molecule: Molecule) -> bool:
# TODO: This should only skip rigid waters, even though HMR or flexible water is questionable
if (
(hydrogen_atom.atomic_number == 1)
and (heavy_atom.atomic_number != 1) # noqa: W503
and not (_is_water(hydrogen_atom.molecule)) # noqa: W503
and (heavy_atom.atomic_number != 1)
and not (_is_water(hydrogen_atom.molecule))
):

# these are molecule indices, whereas in the OpenMM function they are topology indices
Expand Down
4 changes: 2 additions & 2 deletions openff/interchange/smirnoff/_virtual_sites.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,8 +446,8 @@ def _convert_local_coordinates(
# rather than 0 degrees from the z-axis.
vsite_positions = local_coordinate_frames[0] + d * (
cos_theta * cos_phi * local_coordinate_frames[1]
+ sin_theta * cos_phi * local_coordinate_frames[2] # noqa
+ sin_phi * local_coordinate_frames[3] # noqa
+ sin_theta * cos_phi * local_coordinate_frames[2]
+ sin_phi * local_coordinate_frames[3]
)

return vsite_positions
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ exclude_lines =

[flake8]
max-line-length = 119
ignore = E203,B028
ignore = E203,B028,W503
per-file-ignores =
openff/interchange/_tests/unit_tests/test_types.py:F821
openff/interchange/**/__init__.py:F401
Expand Down
Loading