Skip to content

Commit

Permalink
Allow Topology.add_molecule to take list[Molecule]
Browse files Browse the repository at this point in the history
  • Loading branch information
mattwthompson committed Aug 2, 2024
1 parent c036e7b commit d6223aa
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 6 deletions.
1 change: 1 addition & 0 deletions devtools/conda-envs/test_env.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ dependencies:
- pytest-cov
- pytest-xdist
- pytest-rerunfailures
- pytest-timeout
- pyyaml
- toml
- bson
Expand Down
25 changes: 25 additions & 0 deletions openff/toolkit/_tests/test_topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,31 @@ def test_empty(self):
assert not topology.is_periodic
assert len(topology.constrained_atom_pairs.items()) == 0

@pytest.mark.timeout(5)
def test_from_molecule_multiple(self):
"""
Test that add_molecule on a list of many molecules is quick.
See issue #1916
"""
water = create_water()

topology = Topology()

topology.add_molecule(10_000 * [water])

assert topology.n_molecules == 10_000

def test_from_molecule_bad_argument(self):
with pytest.raises(
ValueError,
match="Invalid type.*Topology",
):

topology = Topology()

topology.add_molecule(create_water().to_topology())

def test_reinitialization_box_vectors(self):
topology = Topology()
assert Topology(topology).box_vectors is None
Expand Down
39 changes: 34 additions & 5 deletions openff/toolkit/topology/topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
Optional,
TextIO,
Union,
overload,
)

import numpy as np
Expand Down Expand Up @@ -2390,11 +2391,39 @@ def bond(self, bond_topology_index: int) -> "Bond": # type: ignore[return]
return molecule.bond(bond_molecule_index)
this_molecule_start_index += molecule.n_bonds

def add_molecule(self, molecule: MoleculeLike) -> int:
"""Add a copy of the molecule to the topology"""
idx = self._add_molecule_keep_cache(molecule)
self._invalidate_cached_properties()
return idx
@overload
def add_molecule(self, molecule: MoleculeLike) -> int: ...

@overload
def add_molecule(self, molecule: list[MoleculeLike]) -> list[int]: ...

def add_molecule(
self,
molecule: MoleculeLike | list[MoleculeLike],
) -> int | list[int]:
"""Add a molecule or multiple molecules to the topology."""
if isinstance(molecule, list):

indices = [
self._add_molecule_keep_cache(iter_molecule)
for iter_molecule in molecule
]

self._invalidate_cached_properties()

return indices

elif isinstance(molecule, (Molecule, _SimpleMolecule)):

idx = self._add_molecule_keep_cache(molecule)

self._invalidate_cached_properties()

return idx

else:

raise ValueError(f"Invalid type {type(molecule)} for Topology.add_molecule")

def _add_molecule_keep_cache(self, molecule: MoleculeLike) -> int:
self._molecules.append(deepcopy(molecule))
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ exclude_lines =

[flake8]
max-line-length = 119
ignore = E203,W605,W503
ignore = E203,W605,W503,E704
exclude =
openff/toolkit/_tests/_stale_tests.py
per-file-ignores =
Expand Down

0 comments on commit d6223aa

Please sign in to comment.