diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index f6473ce86..e3382dfe4 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -92,7 +92,6 @@ jobs: run: micromamba install "foyer >=0.12.1" -c conda-forge -yq - name: Run tests - if: always() run: | python -m pytest $COV openff/interchange/ \ -r fExs -n logical --durations=10 \ @@ -127,7 +126,6 @@ jobs: python devtools/scripts/molecule-regressions.py - name: Run mypy - continue-on-error: true if: ${{ matrix.python-version == '3.11' }} run: | # As of 01/23, JAX with mypy is too slow to use without a pre-built cache diff --git a/devtools/conda-envs/test_env.yaml b/devtools/conda-envs/test_env.yaml index 3547b5199..4cb1c7dd8 100644 --- a/devtools/conda-envs/test_env.yaml +++ b/devtools/conda-envs/test_env.yaml @@ -7,7 +7,7 @@ dependencies: - numpy - pydantic # OpenFF stack - - openff-toolkit-base >=0.16 + - openff-toolkit-base ~=0.16 - openff-units - ambertools =23 # Needs to be explicitly listed to not be dropped when AmberTools is removed @@ -27,7 +27,7 @@ dependencies: - nbval - nglview # Drivers - - gromacs + - gromacs =2024 - lammps >=2023.08.02 - panedr # Typing diff --git a/openff/interchange/_tests/unit_tests/interop/openmm/test_nonbonded.py b/openff/interchange/_tests/unit_tests/interop/openmm/test_nonbonded.py index 39920da23..fd6019beb 100644 --- a/openff/interchange/_tests/unit_tests/interop/openmm/test_nonbonded.py +++ b/openff/interchange/_tests/unit_tests/interop/openmm/test_nonbonded.py @@ -1,6 +1,7 @@ import pytest from openff.toolkit import Molecule, unit from openff.utilities.testing import skip_if_missing +from pydantic import ValidationError from openff.interchange.exceptions import UnsupportedCutoffMethodError @@ -25,14 +26,20 @@ def test_reaction_field(self, sage, periodic): if periodic: interchange.box = [4, 4, 4] * unit.nanometer interchange["Electrostatics"].periodic_potential = "reaction-field" - else: - interchange["Electrostatics"].nonperiodic_potential = "reaction-field" - with pytest.raises( - UnsupportedCutoffMethodError, - match="Reaction field electrostatics not supported. ", - ): - interchange.to_openmm(combine_nonbonded_forces=False) + with pytest.raises( + UnsupportedCutoffMethodError, + match="Reaction field electrostatics not supported. ", + ): + interchange.to_openmm(combine_nonbonded_forces=False) + + else: + # Not clear that reaction field works with periodic systems, so this can't be set + with pytest.raises( + ValidationError, + match="Input should be 'Coulomb', 'cutoff' or 'no-cutoff'", + ): + interchange["Electrostatics"].nonperiodic_potential = "reaction-field" @skip_if_missing("openmm") diff --git a/openff/interchange/_tests/unit_tests/interop/openmm/test_virtual_sites.py b/openff/interchange/_tests/unit_tests/interop/openmm/test_virtual_sites.py index cbd69359c..8e121b4c6 100644 --- a/openff/interchange/_tests/unit_tests/interop/openmm/test_virtual_sites.py +++ b/openff/interchange/_tests/unit_tests/interop/openmm/test_virtual_sites.py @@ -330,7 +330,6 @@ def test_tip5p_num_exceptions(self, water, tip5p, combine, n_molecules): # Safeguard against some of the behavior seen in #919 for index in range(num_exceptions): p1, p2, *_ = force.getExceptionParameters(index) - print(p1, p2) if sorted([p1, p2]) == [0, 3]: raise Exception( diff --git a/openff/interchange/common/_nonbonded.py b/openff/interchange/common/_nonbonded.py index 9d6d31965..b3529c053 100644 --- a/openff/interchange/common/_nonbonded.py +++ b/openff/interchange/common/_nonbonded.py @@ -93,6 +93,7 @@ class ElectrostaticsCollection(_NonbondedCollection): "Ewald3D-ConductingBoundary", "cutoff", "no-cutoff", + "reaction-field", ] = Field(_PME) nonperiodic_potential: Literal["Coulomb", "cutoff", "no-cutoff"] = Field("Coulomb") exception_potential: Literal["Coulomb"] = Field("Coulomb") @@ -106,7 +107,7 @@ class ElectrostaticsCollection(_NonbondedCollection): _charges_cached: bool = PrivateAttr(default=False) @property - def charges(self) -> dict[TopologyKey, Quantity]: + def charges(self) -> dict[TopologyKey | LibraryChargeTopologyKey, Quantity]: """Get the total partial charge on each atom, including virtual sites.""" if len(self._charges) == 0 or self._charges_cached is False: self._charges = self._get_charges(include_virtual_sites=False) @@ -117,7 +118,7 @@ def charges(self) -> dict[TopologyKey, Quantity]: def _get_charges( self, include_virtual_sites: bool = False, - ) -> dict[TopologyKey, Quantity]: + ) -> dict[TopologyKey | LibraryChargeTopologyKey, Quantity]: if include_virtual_sites: raise NotImplementedError() diff --git a/openff/interchange/components/_packmol.py b/openff/interchange/components/_packmol.py index eafa9798d..91ece3b48 100644 --- a/openff/interchange/components/_packmol.py +++ b/openff/interchange/components/_packmol.py @@ -92,7 +92,7 @@ def _check_add_positive_mass(mass_to_add): ) -def _check_box_shape_shape(box_shape: ArrayLike): +def _check_box_shape_shape(box_shape: NDArray): """Check the .shape of the box_shape argument.""" if box_shape.shape != (3, 3): raise PACKMOLValueError( @@ -531,27 +531,28 @@ def _build_input_file( def _center_topology_at( - center_solute: bool | Literal["BOX_VECS", "ORIGIN", "BRICK"], + center_solute: Literal["NO", "YES", "BOX_VECS", "ORIGIN", "BRICK"], topology: Topology, box_vectors: Quantity, brick_size: Quantity, ) -> Topology: """Return a copy of the topology centered as requested.""" - if isinstance(center_solute, str): - center_solute = center_solute.upper() + _center_solute = center_solute.upper() + topology = Topology(topology) - if center_solute is False: + if _center_solute == "NO": return topology - elif center_solute in [True, "BOX_VECS"]: + elif _center_solute in ["YES", "BOX_VECS"]: new_center = box_vectors.sum(axis=0) / 2.0 - elif center_solute == "ORIGIN": + elif _center_solute == "ORIGIN": new_center = numpy.zeros(3) - elif center_solute == "BRICK": + elif _center_solute == "BRICK": new_center = brick_size / 2.0 else: PACKMOLValueError( - f"center_solute must be a bool, 'BOX_VECS', 'ORIGIN', or 'BRICK', not {center_solute!r}", + "center_solute must be 'NO', 'YES', 'BOX_VECS', 'ORIGIN', or 'BRICK', " + f"not {center_solute!r}", ) positions = topology.get_positions() @@ -569,7 +570,7 @@ def pack_box( box_vectors: Quantity | None = None, mass_density: Quantity | None = None, box_shape: ArrayLike = RHOMBIC_DODECAHEDRON, - center_solute: bool | Literal["BOX_VECS", "ORIGIN", "BRICK"] = False, + center_solute: Literal["NO", "YES", "BOX_VECS", "ORIGIN", "BRICK"] = "NO", working_directory: str | None = None, retain_working_files: bool = False, ) -> Topology: @@ -609,12 +610,12 @@ def pack_box( `_. center_solute - How to center ``solute`` in the simulation box. If ``True`` + How to center ``solute`` in the simulation box. If ``"YES"`` or ``"box_vecs"``, the solute's center of geometry will be placed at the center of the box's parallelopiped representation. If ``"origin"``, the solute will centered at the origin. If ``"brick"``, the solute will be centered in the box's rectangular brick representation. If - ``False`` (the default), the solute will not be moved. + ``"NO"`` (the default), the solute will not be moved. working_directory: str, optional The directory in which to generate the temporary working files. If ``None``, a temporary one will be created. @@ -678,7 +679,7 @@ def pack_box( brick_size = _compute_brick_from_box_vectors(box_vectors) # Center the solute - if center_solute and solute is not None: + if center_solute != "NO" and solute is not None: solute = _center_topology_at( center_solute, solute, @@ -956,5 +957,5 @@ def solvate_topology_nonwater( solute=topology, tolerance=tolerance, box_vectors=box_vectors, - center_solute=True, + center_solute="YES", ) diff --git a/openff/interchange/components/_particles.py b/openff/interchange/components/_particles.py index 266030e89..d023eda1a 100644 --- a/openff/interchange/components/_particles.py +++ b/openff/interchange/components/_particles.py @@ -6,7 +6,7 @@ from openff.toolkit import Quantity -from openff.interchange._annotations import _DistanceQuantity +from openff.interchange._annotations import _DistanceQuantity, _Quantity from openff.interchange.pydantic import _BaseModel @@ -15,15 +15,16 @@ class _VirtualSite(_BaseModel, abc.ABC): distance: _DistanceQuantity orientations: tuple[int, ...] - @abc.abstractproperty + @property def local_frame_weights(self) -> tuple[list[float], ...]: raise NotImplementedError() + @property def local_frame_positions(self) -> Quantity: raise NotImplementedError() @property - def _local_frame_coordinates(self) -> Quantity: + def local_frame_coordinates(self) -> _Quantity: """ Return the position of this virtual site in its local frame in spherical coordinates. diff --git a/openff/interchange/components/interchange.py b/openff/interchange/components/interchange.py index 75d070f78..b41edea7f 100644 --- a/openff/interchange/components/interchange.py +++ b/openff/interchange/components/interchange.py @@ -427,7 +427,7 @@ def to_openmm_system( hydrogen_mass=hydrogen_mass, ) - to_openmm = to_openmm_system + to_openmm = to_openmm_system # type: ignore[pydantic-field] def to_openmm_topology( self, @@ -559,6 +559,8 @@ def to_pdb(self, file_path: Path | str, include_virtual_sites: bool = False): "Positions are required to write a `.pdb` file but found None.", ) + assert self.positions is not None + # TODO: Simply wire `include_virtual_sites` to `to_openmm_{topology|positions}`? if include_virtual_sites: from openff.interchange.interop._virtual_sites import ( diff --git a/openff/interchange/components/mdconfig.py b/openff/interchange/components/mdconfig.py index 38a509f66..506e50035 100644 --- a/openff/interchange/components/mdconfig.py +++ b/openff/interchange/components/mdconfig.py @@ -61,7 +61,7 @@ class MDConfig(_BaseModel): description="The distance at which the switching function is applied", ) coul_method: str = Field( - None, + "Unknown", description="The method used to compute pairwise electrostatic interactions", ) coul_cutoff: _DistanceQuantity = Field( @@ -137,7 +137,7 @@ def apply(self, interchange: "Interchange"): if "Electrostatics" in interchange.collections: electrostatics = interchange["Electrostatics"] if self.coul_method.lower() == "pme": - electrostatics.periodic_potential = _PME # type: ignore[assignment] + electrostatics.periodic_potential = _PME else: electrostatics.periodic_potential = self.coul_method # type: ignore[assignment] electrostatics.cutoff = self.coul_cutoff diff --git a/openff/interchange/components/potentials.py b/openff/interchange/components/potentials.py index c3c368c64..2f099bd56 100644 --- a/openff/interchange/components/potentials.py +++ b/openff/interchange/components/potentials.py @@ -6,6 +6,7 @@ from typing import Annotated, Any, Union import numpy +from numpy.typing import NDArray from openff.toolkit import Quantity from openff.utilities.utilities import has_package, requires_package from pydantic import ( @@ -28,12 +29,12 @@ from openff.interchange.warnings import InterchangeDeprecationWarning if has_package("jax"): - from jax import numpy as jax_numpy - -from numpy.typing import ArrayLike - -if has_package("jax"): + # JAX stubs seem very broken, not adding this to type annotations + # even though many should be NDArray | Array from jax import Array + from jax import numpy as jax_numpy +else: + Array = NDArray def __getattr__(name: str): @@ -102,7 +103,7 @@ def validate_potential_or_wrapped_potential( v: Any, handler: ValidatorFunctionWrapHandler, info: ValidationInfo, -) -> dict[str, Quantity]: +) -> Potential | WrappedPotential: """Validate the parameters field of a Potential object.""" if info.mode == "json": if "parameters" in v: @@ -110,6 +111,9 @@ def validate_potential_or_wrapped_potential( else: return WrappedPotential.model_validate(v) + else: + raise NotImplementedError(f"Validation mode {info.mode} not implemented.") + PotentialOrWrappedPotential = Annotated[ Union[Potential, WrappedPotential], @@ -133,6 +137,8 @@ def validate_key_map(v: Any, handler, info) -> dict: for key, val in v.items(): val_dict = json.loads(val) + key_class: type[TopologyKey | LibraryChargeTopologyKey] + match val_dict["associated_handler"]: case "Bonds": key_class = BondKey @@ -171,7 +177,11 @@ def validate_key_map(v: Any, handler, info) -> dict: return v -def serialize_key_map(value: dict[str, str], handler, info) -> dict[str, str]: +def serialize_key_map( + value: dict[TopologyKey, PotentialKey], + handler, + info, +) -> dict[str, str]: """Serialize the parameters field of a Potential object.""" if info.mode == "json": return { @@ -217,7 +227,7 @@ def validate_potential_dict( def serialize_potential_dict( - value: dict[str, Quantity], + value: dict[PotentialKey, Potential], handler, info, ) -> dict[str, str]: @@ -227,6 +237,8 @@ def serialize_potential_dict( key.model_dump_json(): value.model_dump_json() for key, value in value.items() } + else: + raise NotImplementedError(f"Serialization mode {info.mode} not implemented.") Potentials = Annotated[ @@ -285,7 +297,7 @@ def _get_parameters(self, atom_indices: tuple[int]) -> dict: def get_force_field_parameters( self, use_jax: bool = False, - ) -> Union["ArrayLike", "Array"]: + ) -> NDArray: """Return a flattened representation of the force field parameters.""" # TODO: Handle WrappedPotential if any( @@ -309,20 +321,20 @@ def get_force_field_parameters( ], ) - def set_force_field_parameters(self, new_p: "ArrayLike") -> None: + def set_force_field_parameters(self, new_p: NDArray) -> None: """Set the force field parameters from a flattened representation.""" mapping = self.get_mapping() - if new_p.shape[0] != len(mapping): # type: ignore + if new_p.shape[0] != len(mapping): raise RuntimeError for potential_key, potential_index in self.get_mapping().items(): potential = self.potentials[potential_key] - if len(new_p[potential_index, :]) != len(potential.parameters): # type: ignore + if len(new_p[potential_index, :]) != len(potential.parameters): raise RuntimeError for parameter_index, parameter_key in enumerate(potential.parameters): parameter_units = potential.parameters[parameter_key].units - modified_parameter = new_p[potential_index, parameter_index] # type: ignore + modified_parameter = new_p[potential_index, parameter_index] self.potentials[potential_key].parameters[parameter_key] = ( modified_parameter * parameter_units @@ -332,7 +344,7 @@ def get_system_parameters( self, p=None, use_jax: bool = False, - ) -> Union["ArrayLike", "Array"]: + ) -> NDArray: """ Return a flattened representation of system parameters. @@ -374,7 +386,7 @@ def parametrize( self, p=None, use_jax: bool = True, - ) -> Union["ArrayLike", "Array"]: + ) -> NDArray: """Return an array of system parameters, given an array of force field parameters.""" if p is None: p = self.get_force_field_parameters(use_jax=use_jax) @@ -391,7 +403,7 @@ def parametrize_partial(self): ) @requires_package("jax") - def get_param_matrix(self) -> Union["Array", "ArrayLike"]: + def get_param_matrix(self) -> Array: """Get a matrix representing the mapping between force field and system parameters.""" from functools import partial @@ -406,7 +418,7 @@ def get_param_matrix(self) -> Union["Array", "ArrayLike"]: jac_parametrize = jax.jacfwd(parametrize_partial) jac_res = jac_parametrize(p) - return jac_res.reshape(-1, p.flatten().shape[0]) # type: ignore[union-attr] + return jac_res.reshape(-1, p.flatten().shape[0]) def __getattr__(self, attr: str): if attr == "slot_map": @@ -429,6 +441,7 @@ def validate_collections( from openff.interchange.smirnoff import ( SMIRNOFFAngleCollection, SMIRNOFFBondCollection, + SMIRNOFFCollection, SMIRNOFFConstraintCollection, SMIRNOFFElectrostaticsCollection, SMIRNOFFImproperTorsionCollection, @@ -437,7 +450,7 @@ def validate_collections( SMIRNOFFVirtualSiteCollection, ) - _class_mapping = { + _class_mapping: dict[str, type[SMIRNOFFCollection]] = { "Bonds": SMIRNOFFBondCollection, "Angles": SMIRNOFFAngleCollection, "Constraints": SMIRNOFFConstraintCollection, @@ -456,6 +469,9 @@ def validate_collections( for collection_name, collection_data in v.items() } + else: + raise NotImplementedError(f"Validation mode {info.mode} not implemented.") + _AnnotatedCollections = Annotated[ dict[str, Collection], diff --git a/openff/interchange/constants.py b/openff/interchange/constants.py index a6ebdc20b..afe3d7a03 100644 --- a/openff/interchange/constants.py +++ b/openff/interchange/constants.py @@ -2,9 +2,11 @@ Commonly-used constants. """ +from typing import Literal + from openff.toolkit import unit -_PME = "Ewald3D-ConductingBoundary" +_PME: Literal["Ewald3D-ConductingBoundary"] = "Ewald3D-ConductingBoundary" kj_mol = unit.Unit("kilojoule / mol") kcal_mol = unit.kilocalorie_per_mole diff --git a/openff/interchange/drivers/lammps.py b/openff/interchange/drivers/lammps.py index 9412a32c3..ec78928b9 100644 --- a/openff/interchange/drivers/lammps.py +++ b/openff/interchange/drivers/lammps.py @@ -56,6 +56,8 @@ def _get_lammps_energies( ) -> dict[str, Quantity]: import lammps + assert interchange.positions is not None + if round_positions is not None: interchange.positions = numpy.round(interchange.positions, round_positions) diff --git a/openff/interchange/drivers/report.py b/openff/interchange/drivers/report.py index 37ce14123..1bd1a3f4d 100644 --- a/openff/interchange/drivers/report.py +++ b/openff/interchange/drivers/report.py @@ -183,8 +183,8 @@ def diff( self["Electrostatics"] and other["Electrostatics"] ) is not None: for key in ("vdW", "Electrostatics"): - energy_differences[key] = self[key] - other[key] - energy_differences[key] = self[key] - other[key] + energy_differences[key] = self[key] - other[key] # type: ignore[operator] + energy_differences[key] = self[key] - other[key] # type: ignore[operator] nonbondeds_processed = True diff --git a/openff/interchange/foyer/_create.py b/openff/interchange/foyer/_create.py index 4aa9a19d6..f18dbd4b0 100644 --- a/openff/interchange/foyer/_create.py +++ b/openff/interchange/foyer/_create.py @@ -17,7 +17,7 @@ FoyerRBImproperHandler, FoyerRBProperHandler, ) -from openff.interchange.models import TopologyKey +from openff.interchange.models import LibraryChargeTopologyKey, TopologyKey if has_package("foyer"): from foyer.forcefield import Forcefield @@ -53,7 +53,7 @@ def _create_interchange( # This block is from a mega merge, unclear if it's still needed for name, handler_class in get_handlers_callable().items(): - interchange.collections[name] = handler_class() + interchange.collections[name] = handler_class(type=name) vdw_handler = interchange["vdW"] vdw_handler.scale_14 = force_field.lj14scale @@ -76,7 +76,9 @@ def _create_interchange( # TODO: Populate .mdconfig, but only after a reasonable number of state mutations have been tested - charges = electrostatics.charges + charges: dict[TopologyKey | LibraryChargeTopologyKey, Quantity] = ( + electrostatics.charges + ) for molecule in interchange.topology.molecules: molecule_charges = [ @@ -85,7 +87,9 @@ def _create_interchange( ].m for atom in molecule.atoms ] - molecule.partial_charges = Quantity( + + # Quantity(list[Quantity]) works ... but is a big magical to mypy + molecule.partial_charges = Quantity( # type: ignore[call-overload] molecule_charges, unit.elementary_charge, ) diff --git a/openff/interchange/foyer/_nonbonded.py b/openff/interchange/foyer/_nonbonded.py index 75a7e2fbd..9873efbe2 100644 --- a/openff/interchange/foyer/_nonbonded.py +++ b/openff/interchange/foyer/_nonbonded.py @@ -8,7 +8,11 @@ from openff.interchange.common._nonbonded import ElectrostaticsCollection, vdWCollection from openff.interchange.components.potentials import Potential from openff.interchange.foyer._base import _copy_params -from openff.interchange.models import PotentialKey, TopologyKey +from openff.interchange.models import ( + LibraryChargeTopologyKey, + PotentialKey, + TopologyKey, +) if has_package("foyer"): from foyer.forcefield import Forcefield @@ -60,7 +64,9 @@ class FoyerElectrostaticsHandler(ElectrostaticsCollection): force_field_key: str = "atoms" cutoff: _DistanceQuantity = 9.0 * unit.angstrom - _charges: dict[TopologyKey, Quantity] = PrivateAttr(default_factory=dict) + _charges: dict[TopologyKey | LibraryChargeTopologyKey, Quantity] = PrivateAttr( + default_factory=dict, + ) def store_charges( self, diff --git a/openff/interchange/foyer/_valence.py b/openff/interchange/foyer/_valence.py index 2817042a3..5d2244f45 100644 --- a/openff/interchange/foyer/_valence.py +++ b/openff/interchange/foyer/_valence.py @@ -22,7 +22,7 @@ class FoyerHarmonicBondHandler(FoyerConnectedAtomsHandler, BondCollection): """Handler storing bond potentials as produced by a Foyer force field.""" type: Literal["Bonds"] = "Bonds" - expression: str = "k/2*(r-length)**2" + expression: Literal["k/2*(r-length)**2"] = "k/2*(r-length)**2" force_field_key: str = "harmonic_bonds" connection_attribute: str = "bonds" @@ -59,7 +59,7 @@ class FoyerHarmonicAngleHandler(FoyerConnectedAtomsHandler, AngleCollection): """Handler storing angle potentials as produced by a Foyer force field.""" type: Literal["Angles"] = "Angles" - expression: str = "k/2*(theta-angle)**2" + expression: Literal["k/2*(theta-angle)**2"] = "k/2*(theta-angle)**2" force_field_key: str = "harmonic_angles" connection_attribute: str = "angles" @@ -139,7 +139,7 @@ def store_matches( class FoyerRBImproperHandler(FoyerRBProperHandler): """Handler storing Ryckaert-Bellemans improper torsion potentials as produced by a Foyer force field.""" - type: Literal["RBImpropers"] = "RBImpropers" + type: Literal["RBImpropers"] = "RBImpropers" # type: ignore[assignment] connection_attribute: str = "impropers" @@ -149,8 +149,8 @@ class FoyerPeriodicProperHandler(FoyerConnectedAtomsHandler, ProperTorsionCollec force_field_key: str = "periodic_propers" connection_attribute: str = "propers" raise_on_missing_params: bool = False - type: str = "ProperTorsions" - expression: str = "k*(1+cos(periodicity*theta-phase))" + type: str = "ProperTorsions" # type: ignore[assignment] + expression: str = "k*(1+cos(periodicity*theta-phase))" # type: ignore[assignment] def get_params_with_units(self, params): """Get the parameters of this handler, tagged with units.""" diff --git a/openff/interchange/interop/_virtual_sites.py b/openff/interchange/interop/_virtual_sites.py index 2633325ad..39ff9e414 100644 --- a/openff/interchange/interop/_virtual_sites.py +++ b/openff/interchange/interop/_virtual_sites.py @@ -3,7 +3,6 @@ """ from collections import defaultdict -from collections.abc import Iterable from typing import DefaultDict import numpy @@ -34,7 +33,7 @@ def _virtual_site_parent_molecule_mapping( A dictionary mapping virtual site keys to the index of the molecule they belong to. """ - mapping = dict() + mapping: dict[VirtualSiteKey, int] = dict() if "VirtualSites" not in interchange.collections: return mapping @@ -164,7 +163,7 @@ def get_positions_with_virtual_sites( def _get_separation_by_atom_indices( interchange: Interchange, - atom_indices: Iterable[int], + atom_indices: tuple[int, ...], prioritize_geometry: bool = False, ) -> Quantity: """ @@ -175,6 +174,8 @@ def _get_separation_by_atom_indices( This is slow, but often necessary for converting virtual site "distances" to weighted averages (unitless) of orientation atom positions. """ + assert interchange.positions is not None + if prioritize_geometry: p1 = interchange.positions[atom_indices[1]] p0 = interchange.positions[atom_indices[0]] @@ -182,41 +183,37 @@ def _get_separation_by_atom_indices( return p1 - p0 if "Constraints" in interchange.collections: - collection = interchange["Constraints"] + constraints = interchange["Constraints"] - for key in collection.key_map: + for key in constraints.key_map: if (key.atom_indices == atom_indices) or ( key.atom_indices[::-1] == atom_indices ): - return collection.potentials[collection.key_map[key]].parameters[ + return constraints.potentials[constraints.key_map[key]].parameters[ "distance" ] if "Bonds" in interchange.collections: - collection = interchange["Bonds"] + bonds = interchange["Bonds"] - for key in collection.key_map: + for key in bonds.key_map: if (key.atom_indices == atom_indices) or ( key.atom_indices[::-1] == atom_indices ): - return collection.potentials[collection.key_map[key]].parameters[ - "length" - ] + return bonds.potentials[bonds.key_map[key]].parameters["length"] # Two heavy atoms may be on opposite ends of an angle, in which case it's still # possible to determine their separation as defined by the geometry of the force field if "Angles" in interchange.collections: - collection = interchange["Angles"] + angles = interchange["Angles"] index0 = atom_indices[0] index1 = atom_indices[1] - for key in collection.key_map: + for key in angles.key_map: if (key.atom_indices[0] == index0 and key.atom_indices[2] == index1) or ( key.atom_indices[2] == index0 and key.atom_indices[0] == index1 ): - gamma = collection.potentials[collection.key_map[key]].parameters[ - "angle" - ] + gamma = angles.potentials[angles.key_map[key]].parameters["angle"] a = _get_separation_by_atom_indices( interchange, diff --git a/openff/interchange/interop/amber/export/_export.py b/openff/interchange/interop/amber/export/_export.py index 818fa4b81..2281fcb81 100644 --- a/openff/interchange/interop/amber/export/_export.py +++ b/openff/interchange/interop/amber/export/_export.py @@ -659,7 +659,8 @@ def to_prmtop(interchange: "Interchange", file_path: Path | str): dihedral_phase: list[int] = list() for key_ in potential_key_to_dihedral_type_mapping: - params = interchange[key_.associated_handler].potentials[key_].parameters # type: ignore + assert key_.associated_handler is not None + params = interchange[key_.associated_handler].potentials[key_].parameters idivf = int(params["idivf"]) if "idivf" in params else 1 dihedral_k.append((params["k"] / idivf).m_as(kcal_mol)) dihedral_periodicity.append(params["periodicity"].m_as(unit.dimensionless)) @@ -755,6 +756,8 @@ def to_prmtop(interchange: "Interchange", file_path: Path | str): _write_text_blob(prmtop, text_blob) if IFBOX == 1: + assert interchange.box is not None + if (interchange.box.m != np.diag(np.diagonal(interchange.box.m))).any(): raise NotImplementedError( "Interchange does not yet support exporting non-rectangular boxes to Amber", @@ -772,7 +775,7 @@ def to_prmtop(interchange: "Interchange", file_path: Path | str): prmtop.write("%FLAG BOX_DIMENSIONS\n" "%FORMAT(5E16.8)\n") box = [90.0] for i in range(3): - box.append(interchange.box[i, i].m_as(unit.angstrom)) # type: ignore + box.append(interchange.box[i, i].m_as(unit.angstrom)) text_blob = "".join([f"{val:16.8E}" for val in box]) _write_text_blob(prmtop, text_blob) @@ -809,6 +812,11 @@ def to_inpcrd(interchange: "Interchange", file_path: Path | str): with open(path, "w") as inpcrd: inpcrd.write(f"\n{n_atoms:5d}{time:15.7e}\n") + if interchange.positions is None: + raise UnsupportedExportError( + "Positions are required to write `.inpcrd` files, found `None`.", + ) + coords = interchange.positions.m_as(unit.angstrom) blob = "".join([f"{val:12.7f}".rjust(12) for val in coords.flatten()]) diff --git a/openff/interchange/interop/gromacs/_interchange.py b/openff/interchange/interop/gromacs/_interchange.py index b59f12233..97da77349 100644 --- a/openff/interchange/interop/gromacs/_interchange.py +++ b/openff/interchange/interop/gromacs/_interchange.py @@ -236,7 +236,7 @@ def to_interchange( _key_assigned = False mult = 0 while not _key_assigned: - topology_key = key_type( + torsion_key: ProperTorsionKey = key_type( atom_indices=( dihedral.atom1 + molecule_start_index - 1, dihedral.atom2 + molecule_start_index - 1, @@ -245,14 +245,14 @@ def to_interchange( ), mult=mult, ) - if topology_key not in collection.key_map: + if torsion_key not in collection.key_map: _key_assigned = True else: mult += 1 potential_key = PotentialKey( - id="-".join(map(str, topology_key.atom_indices)), - mult=topology_key.mult, + id="-".join(map(str, torsion_key.atom_indices)), + mult=torsion_key.mult, associated_handler="ExternalSource", ) @@ -287,7 +287,7 @@ def to_interchange( else: raise NotImplementedError() - collection.key_map.update({topology_key: potential_key}) + collection.key_map.update({torsion_key: potential_key}) collection.potentials.update({potential_key: potential}) molecule_start_index += len(molecule_type.atoms) diff --git a/openff/interchange/interop/gromacs/export/_export.py b/openff/interchange/interop/gromacs/export/_export.py index f39d08371..ec680c1ce 100644 --- a/openff/interchange/interop/gromacs/export/_export.py +++ b/openff/interchange/interop/gromacs/export/_export.py @@ -73,8 +73,8 @@ def _write_atomtypes(self, top, merge_atom_types: bool) -> dict[str, str]: ";type, bondingtype, atomic_number, mass, charge, ptype, sigma, epsilon\n", ) - reduced_atom_types = [] - mapping_to_reduced_atom_types = {} + reduced_atom_types: list[tuple[str, LennardJonesAtomType]] = list() + mapping_to_reduced_atom_types: dict[str, str] = dict() def _is_atom_type_in_list( atom_type, @@ -98,7 +98,9 @@ def _is_atom_type_in_list( return _at_name return False - def _get_new_entry_name(atom_type_list) -> str: + def _get_new_entry_name( + atom_type_list: list[tuple[str, LennardJonesAtomType]], + ) -> str: """ Entry name for atom type to be added. """ @@ -116,13 +118,13 @@ def _get_new_entry_name(atom_type_list) -> str: ) if merge_atom_types: - if _is_atom_type_in_list(atom_type, reduced_atom_types): - mapping_to_reduced_atom_types[atom_type.name] = ( - _is_atom_type_in_list( - atom_type, - reduced_atom_types, - ) - ) + atom_type_is_in_list: bool | str = _is_atom_type_in_list( + atom_type, + reduced_atom_types, + ) + + if isinstance(atom_type_is_in_list, str): + mapping_to_reduced_atom_types[atom_type.name] = atom_type_is_in_list else: _at_name = _get_new_entry_name(reduced_atom_types) reduced_atom_types.append((_at_name, atom_type)) @@ -153,6 +155,7 @@ def _get_new_entry_name(atom_type_list) -> str: f"{atom_type.epsilon.m :.16g}\n", ) top.write("\n") + return mapping_to_reduced_atom_types def _write_moleculetypes( diff --git a/openff/interchange/interop/lammps/export/export.py b/openff/interchange/interop/lammps/export/export.py index 8eb1a551e..a67027a89 100644 --- a/openff/interchange/interop/lammps/export/export.py +++ b/openff/interchange/interop/lammps/export/export.py @@ -18,6 +18,13 @@ def to_lammps(interchange: Interchange, file_path: Path | str): if isinstance(file_path, Path): path = file_path + if interchange.positions is None: + raise UnsupportedExportError( + "Interchange object must have positions to export to LAMMPS data file", + ) + + assert interchange.positions is not None + n_atoms = interchange.topology.n_atoms if "Bonds" in interchange.collections: n_bonds = len(interchange["Bonds"].key_map.keys()) @@ -267,6 +274,8 @@ def _write_improper_coeffs(lmp_file: IO, interchange: Interchange): def _write_atoms(lmp_file: IO, interchange: Interchange, atom_type_map: dict): """Write the Atoms section of a LAMMPS data file.""" + assert interchange.positions is not None + lmp_file.write("\nAtoms\n\n") atom_type_map_inv = dict({v: k for k, v in atom_type_map.items()}) diff --git a/openff/interchange/interop/openmm/_import/_nonbonded.py b/openff/interchange/interop/openmm/_import/_nonbonded.py index 3cbedc79f..a0eefa005 100644 --- a/openff/interchange/interop/openmm/_import/_nonbonded.py +++ b/openff/interchange/interop/openmm/_import/_nonbonded.py @@ -1,15 +1,19 @@ from openff.toolkit import Quantity from openff.interchange.common._nonbonded import ElectrostaticsCollection +from openff.interchange.models import LibraryChargeTopologyKey, TopologyKey class BasicElectrostaticsCollection(ElectrostaticsCollection): """A slightly more complete collection than the base class.""" + _charges: dict[TopologyKey | LibraryChargeTopologyKey, Quantity] = dict() + _charges_cached: bool = False + @property def charges( self, - ) -> dict[int, Quantity]: + ) -> dict[TopologyKey | LibraryChargeTopologyKey, Quantity]: """Get the total partial charge on each atom, including virtual sites.""" if len(self._charges) == 0 or self._charges_cached is False: self._charges = self._get_charges() @@ -17,8 +21,10 @@ def charges( return self._charges - def _get_charges(self): - charges: dict[int, Quantity] = dict() + def _get_charges( # type: ignore[override] + self, + ) -> dict[TopologyKey | LibraryChargeTopologyKey, Quantity]: + charges: dict[TopologyKey | LibraryChargeTopologyKey, Quantity] = dict() for topology_key, potential_key in self.key_map.items(): potential = self.potentials[potential_key] diff --git a/openff/interchange/interop/openmm/_nonbonded.py b/openff/interchange/interop/openmm/_nonbonded.py index 415b1cee4..cba521ef1 100644 --- a/openff/interchange/interop/openmm/_nonbonded.py +++ b/openff/interchange/interop/openmm/_nonbonded.py @@ -220,27 +220,28 @@ def _add_particles_to_system( def _prepare_input_data(interchange: "Interchange") -> _NonbondedData: try: - vdw: "vdWCollection" = interchange["vdW"] + vdw = interchange["vdW"] except LookupError: for collection in interchange.collections.values(): if collection.is_plugin: if collection.acts_as == "vdW": # We can't be completely sure all plugins subclass out of vdWCollection here - vdw = collection # type: ignore[assignment] + assert isinstance(collection, vdWCollection) + vdw = collection break else: - vdw = None # type: ignore[assignment] + vdw = None if vdw: vdw_cutoff: Quantity | None = vdw.cutoff if interchange.box is None: - vdw_method: str | None = vdw.nonperiodic_method.lower() + vdw_method = vdw.nonperiodic_method.lower() else: - vdw_method: str | None = vdw.periodic_method.lower() + vdw_method = vdw.periodic_method.lower() - mixing_rule: str | None = getattr(vdw, "mixing_rule", None) - vdw_expression: str | None = vdw.expression.replace("**", "^") + mixing_rule = getattr(vdw, "mixing_rule", None) + vdw_expression = vdw.expression.replace("**", "^") else: vdw_cutoff = None vdw_method = None @@ -248,9 +249,9 @@ def _prepare_input_data(interchange: "Interchange") -> _NonbondedData: vdw_expression = None try: - electrostatics: "ElectrostaticsCollection" = interchange["Electrostatics"] + electrostatics: ElectrostaticsCollection | None = interchange["Electrostatics"] except LookupError: - electrostatics = None # type: ignore[assignment] + electrostatics = None if electrostatics is None: electrostatics_method: str | None = None @@ -265,12 +266,12 @@ def _prepare_input_data(interchange: "Interchange") -> _NonbondedData: electrostatics_method = getattr(electrostatics, "periodic_potential", _PME) return _NonbondedData( - vdw_collection=vdw, + vdw_collection=vdw, # type: ignore[arg-type] vdw_cutoff=vdw_cutoff, vdw_method=vdw_method, vdw_expression=vdw_expression, mixing_rule=mixing_rule, - electrostatics_collection=electrostatics, + electrostatics_collection=electrostatics, # type: ignore[arg-type] electrostatics_method=electrostatics_method, periodic=interchange.box is None, ) @@ -342,15 +343,15 @@ def _create_single_nonbonded_force( non_bonded_force.setNonbondedMethod(openmm.NonbondedForce.LJPME) non_bonded_force.setEwaldErrorTolerance(ewald_tolerance) - elif data["vdw_method"] == data["electrostatics_method"] == "cutoff": - if data["vdw_cutoff"] != data["electrostatics_collection"].cutoff: + elif data.vdw_method == data.electrostatics_method == "cutoff": + if data.vdw_cutoff != data.electrostatics_collection.cutoff: raise UnsupportedExportError( "If using cutoff vdW and electrostatics, cutoffs must match.", ) non_bonded_force.setNonbondedMethod(openmm.NonbondedForce.CutoffPeriodic) non_bonded_force.setCutoffDistance( - to_openmm_quantity(data["vdw_cutoff"]), + to_openmm_quantity(data.vdw_cutoff), ) else: @@ -624,7 +625,7 @@ def _create_multiple_nonbonded_forces( if vdw.is_plugin: # TODO: Custom mixing rules in plugins is untested vdw_14_force = openmm.CustomBondForce( - _get_scaled_potential_function(data.vdw_expression), + _get_scaled_potential_function(data.vdw_expression), # type: ignore[arg-type] ) vdw_14_force.setName("vdW 1-4 force") @@ -645,7 +646,7 @@ def _create_multiple_nonbonded_forces( vdw_14_force.addGlobalParameter(term, value) else: - vdw_expression: str = data.vdw_expression + vdw_expression = data.vdw_expression vdw_14_force = openmm.CustomBondForce(vdw_expression) vdw_14_force.setName("vdW 1-4 force") @@ -708,7 +709,7 @@ def _create_multiple_nonbonded_forces( eps_14 = (eps1 * eps2) ** 0.5 * vdw_14 else: raise UnsupportedExportError( - f"Unsupported mixing rule: {data['mixing_rule']}", + f"Unsupported mixing rule: {data.mixing_rule}", ) # ... and set the 1-4 interactions @@ -759,7 +760,7 @@ def _create_vdw_force( vdw_expression: str = data.vdw_expression # type: ignore[assignment] mixing_rule_expression: str = _MIXING_RULE_EXPRESSIONS.get( - data.mixing_rule, + data.mixing_rule, # type: ignore[arg-type] "", ) @@ -1044,12 +1045,12 @@ def _set_particle_parameters( else: vdw_force.setParticleParameters(particle_index, [sigma, epsilon]) - partial_charge = partial_charges[virtual_site_key].m_as(unit.e) + partial_charge = partial_charges[virtual_site_key].m_as(unit.e) # type: ignore[index] if electrostatics_force is not None: electrostatics_force.setParticleParameters( particle_index, - partial_charges[virtual_site_key].m_as(unit.e), + partial_charges[virtual_site_key].m_as(unit.e), # type: ignore[index] 0.0, 0.0, ) diff --git a/openff/interchange/interop/openmm/_positions.py b/openff/interchange/interop/openmm/_positions.py index b3bd4f0cf..77e78ee44 100644 --- a/openff/interchange/interop/openmm/_positions.py +++ b/openff/interchange/interop/openmm/_positions.py @@ -17,6 +17,8 @@ def to_openmm_positions( include_virtual_sites: bool = True, ) -> "openmm.unit.Quantity": """Generate an array of positions of all particles, optionally including virtual sites.""" + assert interchange.positions is not None + if include_virtual_sites: from openff.interchange.interop._virtual_sites import ( get_positions_with_virtual_sites, diff --git a/openff/interchange/operations/_combine.py b/openff/interchange/operations/_combine.py index 4548ae931..4e8140597 100644 --- a/openff/interchange/operations/_combine.py +++ b/openff/interchange/operations/_combine.py @@ -12,6 +12,7 @@ SwitchingFunctionMismatchError, UnsupportedCombinationError, ) +from openff.interchange.models import LibraryChargeTopologyKey if TYPE_CHECKING: from openff.interchange.components.interchange import Interchange @@ -92,10 +93,12 @@ def _combine( new_atom_indices = tuple(idx + atom_offset for idx in top_key.atom_indices) new_top_key = top_key.__class__(**top_key.model_dump()) try: - new_top_key.atom_indices = new_atom_indices + new_top_key.atom_indices = new_atom_indices # type: ignore[misc] except (ValueError, AttributeError): assert len(new_atom_indices) == 1 + assert isinstance(new_top_key, LibraryChargeTopologyKey) new_top_key.this_atom_index = new_atom_indices[0] + # If interchange was not created with SMIRNOFF, we need avoid merging potentials with same key if pot_key.associated_handler == "ExternalSource": _mult = 0 @@ -115,7 +118,7 @@ def _combine( # Ensure the charge cache is rebuilt if handler_name == "Electrostatics": - self_handler._charges_cached = False + self_handler._charges_cached = False # type: ignore[attr-defined] self_handler._get_charges() result.collections[handler_name] = self_handler diff --git a/openff/interchange/operations/minimize/openmm.py b/openff/interchange/operations/minimize/openmm.py index d37d1342f..3b7ef0cbf 100644 --- a/openff/interchange/operations/minimize/openmm.py +++ b/openff/interchange/operations/minimize/openmm.py @@ -47,6 +47,8 @@ def minimize_openmm( else: raise MinimizationError("OpenMM Minimization failed.") from error + assert interchange.positions is not None + # Assume that all virtual sites are placed at the _end_, so the 0th through # (number of atoms)th positions are the massive particles return from_openmm( diff --git a/openff/interchange/smirnoff/_base.py b/openff/interchange/smirnoff/_base.py index 4ac65206a..f750de111 100644 --- a/openff/interchange/smirnoff/_base.py +++ b/openff/interchange/smirnoff/_base.py @@ -1,5 +1,6 @@ import abc -from typing import Literal, TypeVar +import builtins +from typing import TypeVar from openff.toolkit import Topology from openff.toolkit.typing.engines.smirnoff.parameters import ( @@ -114,10 +115,12 @@ def _check_all_valence_terms_assigned( class SMIRNOFFCollection(Collection, abc.ABC): """Base class for handlers storing potentials produced by SMIRNOFF force fields.""" - type: Literal["Bonds"] = "Bonds" + type: str = "BASE_CLASS" is_plugin: bool = False + expression: str = "BASE_CLASS" + def modify_openmm_forces(self, *args, **kwargs): """Optionally modify, create, or delete forces. Currently only available to plugins.""" raise NotImplementedError() @@ -213,7 +216,7 @@ def store_potentials(self, parameter_handler: TP): @classmethod def create( - cls, # type[T], + cls: builtins.type[T], parameter_handler: TP, topology: "Topology", ) -> T: @@ -240,6 +243,6 @@ def create( def __repr__(self) -> str: return ( - f"Handler '{self.type}' with expression '{self.expression}', {len(self.key_map)} mapping keys, " + f"Handler '{self.type_}' with expression '{self.expression}', {len(self.key_map)} mapping keys, " f"and {len(self.potentials)} potentials" ) diff --git a/openff/interchange/smirnoff/_create.py b/openff/interchange/smirnoff/_create.py index d0329ef9b..5c38dafca 100644 --- a/openff/interchange/smirnoff/_create.py +++ b/openff/interchange/smirnoff/_create.py @@ -405,7 +405,9 @@ def _plugins( topology=topology, ) except TypeError: - tagnames = [x._TAGNAME for x in collection.allowed_parameter_handlers()] + tagnames = [ + x._TAGNAME for x in collection_class.allowed_parameter_handlers() + ] if len(tagnames) > 1: raise NotImplementedError( @@ -413,14 +415,14 @@ def _plugins( ) try: - collection = collection_class.create( + collection = collection_class.create( # type: ignore[call-arg] parameter_handler=force_field[handler_class._TAGNAME], topology=topology, vdw_collection=interchange[tagnames[0]], electrostatics_collection=interchange["Electrostatics"], ) except TypeError: - collection = collection_class.create( + collection = collection_class.create( # type: ignore[call-arg] parameter_handler=force_field[handler_class._TAGNAME], topology=topology, vdw_collection=interchange[tagnames[0]], diff --git a/openff/interchange/smirnoff/_gromacs.py b/openff/interchange/smirnoff/_gromacs.py index fe683aabe..cb9d1985f 100644 --- a/openff/interchange/smirnoff/_gromacs.py +++ b/openff/interchange/smirnoff/_gromacs.py @@ -187,11 +187,11 @@ def _convert( _partial_charges: dict[int | VirtualSiteKey, float] = dict() # Indexed by particle (atom or virtual site) indices - for key, charge in interchange["Electrostatics"]._get_charges().items(): - if type(key) is TopologyKey: - _partial_charges[key.atom_indices[0]] = charge - elif type(key) is VirtualSiteKey: - _partial_charges[key] = charge + for key_, charge in interchange["Electrostatics"]._get_charges().items(): + if type(key_) is TopologyKey: + _partial_charges[key_.atom_indices[0]] = charge + elif type(key_) is VirtualSiteKey: + _partial_charges[key_] = charge else: raise RuntimeError() diff --git a/openff/interchange/smirnoff/_nonbonded.py b/openff/interchange/smirnoff/_nonbonded.py index 64f2bd171..f10844468 100644 --- a/openff/interchange/smirnoff/_nonbonded.py +++ b/openff/interchange/smirnoff/_nonbonded.py @@ -1,3 +1,4 @@ +import builtins import copy import functools import warnings @@ -146,6 +147,11 @@ class _SMIRNOFFNonbondedCollection(SMIRNOFFCollection, _NonbondedCollection): class SMIRNOFFvdWCollection(vdWCollection, SMIRNOFFCollection): """Handler storing vdW potentials as produced by a SMIRNOFF force field.""" + type: Literal["vdW"] = Field("vdW") + expression: Literal["4*epsilon*((sigma/r)**12-(sigma/r)**6)"] = Field( + "4*epsilon*((sigma/r)**12-(sigma/r)**6)", + ) + @classmethod def allowed_parameter_handlers(cls): """Return a list of allowed types of ParameterHandler classes.""" @@ -193,7 +199,7 @@ def store_potentials(self, parameter_handler: vdWHandler) -> None: @classmethod def create( - cls: type[T], + cls: builtins.type[T], parameter_handler: vdWHandler, topology: Topology, ) -> T: @@ -256,18 +262,23 @@ class SMIRNOFFElectrostaticsCollection(ElectrostaticsCollection, SMIRNOFFCollect rather than having each in their own handler. """ + type: Literal["Electrostatics"] = Field("Electrostatics") + + expression: Literal["coul"] = "coul" + periodic_potential: Literal[ "Ewald3D-ConductingBoundary", "cutoff", "no-cutoff", "reaction-field", ] = Field(_PME) + nonperiodic_potential: Literal[ "Coulomb", "cutoff", "no-cutoff", - "reaction-field", ] = Field("Coulomb") + exception_potential: Literal["Coulomb"] = Field("Coulomb") _charges = PrivateAttr(default_factory=dict) @@ -290,14 +301,14 @@ def supported_parameters(cls): @property def _charges_without_virtual_sites( self, - ) -> dict[TopologyKey, Quantity]: + ) -> dict[TopologyKey | LibraryChargeTopologyKey, Quantity]: """Get the total partial charge on each atom, excluding virtual sites.""" return self._get_charges(include_virtual_sites=False) @property def charges( self, - ) -> dict[TopologyKey | VirtualSiteKey, Quantity]: + ) -> dict[TopologyKey | LibraryChargeTopologyKey, Quantity]: """Get the total partial charge on each atom, including virtual sites.""" if len(self._charges) == 0 or self._charges_cached is False: self._charges = self._get_charges(include_virtual_sites=True) @@ -308,7 +319,7 @@ def charges( def _get_charges( self, include_virtual_sites=True, - ) -> dict[TopologyKey | VirtualSiteKey, Quantity]: + ) -> dict[TopologyKey | LibraryChargeTopologyKey, Quantity]: """Get the total partial charge on each atom or particle.""" # Keyed by index for atoms and by VirtualSiteKey for virtual sites. charges: dict[VirtualSiteKey | int, Quantity] = dict() @@ -408,7 +419,7 @@ def parameter_handler_precedence(cls) -> list[str]: @classmethod def create( - cls: type[T], + cls: builtins.type[T], parameter_handler: Any, topology: Topology, charge_from_molecules=None, @@ -485,8 +496,8 @@ def _library_charge_to_potentials( """ Map a matched library charge parameter to a set of potentials. """ - matches = {} - potentials = {} + matches: dict[LibraryChargeTopologyKey, PotentialKey] = dict() + potentials: dict[PotentialKey, Potential] = dict() for i, (atom_index, charge) in enumerate(zip(atom_indices, parameter.charge)): topology_key = LibraryChargeTopologyKey(this_atom_index=atom_index) @@ -548,7 +559,10 @@ def _find_slot_matches( cls, parameter_handler: Union["LibraryChargeHandler", "ChargeIncrementModelHandler"], unique_molecule: Molecule, - ) -> tuple[dict[TopologyKey, PotentialKey], dict[PotentialKey, Potential]]: + ) -> tuple[ + dict[TopologyKey | SingleAtomChargeTopologyKey, PotentialKey], + dict[PotentialKey, Potential], + ]: """ Construct a slot and potential map for a slot based parameter handler. """ @@ -688,11 +702,14 @@ def _find_reference_matches( cls, parameter_handlers: dict[str, ElectrostaticsHandlerType], unique_molecule: Molecule, - ) -> tuple[dict[TopologyKey, PotentialKey], dict[PotentialKey, Potential]]: + ) -> tuple[ + dict[TopologyKey | SingleAtomChargeTopologyKey, PotentialKey], + dict[PotentialKey, Potential], + ]: """ Construct a slot and potential map for a particular reference molecule and set of parameter handlers. """ - matches: dict[TopologyKey, PotentialKey] = dict() + matches: dict[TopologyKey | SingleAtomChargeTopologyKey, PotentialKey] = dict() potentials: dict[PotentialKey, Potential] = dict() expected_matches = {i for i in range(unique_molecule.n_atoms)} diff --git a/openff/interchange/smirnoff/_valence.py b/openff/interchange/smirnoff/_valence.py index 3fccf3974..e48f1e632 100644 --- a/openff/interchange/smirnoff/_valence.py +++ b/openff/interchange/smirnoff/_valence.py @@ -137,7 +137,7 @@ def store_matches( if self.key_map: # TODO: Should the key_map always be reset, or should we be able to partially # update it? Also Note the duplicated code in the child classes - self.key_map: dict[BondKey, PotentialKey] = dict() # type: ignore[assignment] + self.key_map: dict[BondKey, PotentialKey] = dict() matches = parameter_handler.find_matches(topology) for key, val in matches.items(): parameter: BondHandler.BondType = val.parameter_type @@ -194,7 +194,10 @@ def store_potentials(self, parameter_handler: BondHandler) -> None: smirks = potential_key.id force_field_parameters = parameter_handler.parameters[smirks] + assert isinstance(topology_key, BondKey) + if topology_key.bond_order: + bond_order = topology_key.bond_order if force_field_parameters.k_bondorder: data = force_field_parameters.k_bondorder @@ -497,7 +500,8 @@ def store_matches( """ if self.key_map: - self.key_map: dict[ProperTorsionKey, PotentialKey] = dict() # type: ignore[assignment] + self.key_map: dict[ProperTorsionKey, PotentialKey] = dict() + matches = parameter_handler.find_matches(topology) for key, val in matches.items(): parameter: ProperTorsionHandler.ProperTorsionType = val.parameter_type @@ -559,6 +563,8 @@ def store_potentials(self, parameter_handler: ProperTorsionHandler) -> None: n = potential_key.mult parameter = parameter_handler.parameters[smirks] + assert isinstance(topology_key, ProperTorsionKey) + if topology_key.bond_order: bond_order = topology_key.bond_order data = parameter.k_bondorder[n] @@ -581,7 +587,7 @@ def store_potentials(self, parameter_handler: ProperTorsionHandler) -> None: map_key=map_key, ), ) - potential = WrappedPotential( + potential: Potential | WrappedPotential = WrappedPotential( {pot: coeff for pot, coeff in zip(pots, coeffs)}, ) else: @@ -591,7 +597,7 @@ def store_potentials(self, parameter_handler: ProperTorsionHandler) -> None: "phase": parameter.phase[n], "idivf": parameter.idivf[n] * unit.dimensionless, } - potential = Potential(parameters=parameters) # type: ignore[assignment] + potential = Potential(parameters=parameters) self.potentials[potential_key] = potential @classmethod diff --git a/openff/interchange/smirnoff/_virtual_sites.py b/openff/interchange/smirnoff/_virtual_sites.py index d9207cc2c..b91880cf6 100644 --- a/openff/interchange/smirnoff/_virtual_sites.py +++ b/openff/interchange/smirnoff/_virtual_sites.py @@ -35,10 +35,10 @@ class SMIRNOFFVirtualSiteCollection(SMIRNOFFCollection): A handler which stores the information necessary to construct virtual sites (virtual particles). """ - key_map: dict[VirtualSiteKey, PotentialKey] = Field( + key_map: dict[VirtualSiteKey, PotentialKey] = Field( # type: ignore[assignment] dict(), description="A mapping between VirtualSiteKey objects and PotentialKey objects.", - ) # type: ignore[assignment] + ) type: Literal["VirtualSites"] = "VirtualSites" expression: Literal[""] = "" diff --git a/setup.cfg b/setup.cfg index 7652dc8c4..34fc1ae07 100644 --- a/setup.cfg +++ b/setup.cfg @@ -126,3 +126,6 @@ ignore_missing_imports = True [mypy-nonbonded_plugins.*] ignore_missing_imports = True + +[mypy-lammps] +ignore_missing_imports = True diff --git a/stubs/nglview/__init__.pyi b/stubs/nglview/__init__.pyi index 419b0c793..45281ef05 100644 --- a/stubs/nglview/__init__.pyi +++ b/stubs/nglview/__init__.pyi @@ -1,11 +1,13 @@ -from typing import Iterable, Union +from typing import Any, Iterable, Union class NGLWidget(object): - ... + def __init__( + self, structure: Any, representations: Any, parameters: Any = None, **kwargs + ): ... def add_representation( self, repr_type: str, - selection: Union[str, Iterable], + selection: Union[str, Iterable] = "all", **kwargs, ): ...