From 95b051c8260b60cff050ebf73dc49092c7140ed5 Mon Sep 17 00:00:00 2001 From: "Matthew W. Thompson" Date: Mon, 30 Sep 2024 14:18:00 -0500 Subject: [PATCH 1/2] BUG: Fix cached charges property --- openff/interchange/_annotations.py | 9 +++++++++ openff/interchange/_tests/test_issues.py | 7 +++++++ openff/interchange/common/_nonbonded.py | 17 +++++++---------- openff/interchange/smirnoff/_nonbonded.py | 17 ++++++++++------- 4 files changed, 33 insertions(+), 17 deletions(-) diff --git a/openff/interchange/_annotations.py b/openff/interchange/_annotations.py index 7cfd18eb..2387d080 100644 --- a/openff/interchange/_annotations.py +++ b/openff/interchange/_annotations.py @@ -57,6 +57,7 @@ def _unit_validator_factory(unit: str) -> Callable: _is_kj_mol, _is_nanometer, _is_degree, + _is_elementary_charge, ) = ( _unit_validator_factory(unit=_unit) for _unit in [ @@ -64,6 +65,7 @@ def _unit_validator_factory(unit: str) -> Callable: "kilojoule / mole", "nanometer", "degree", + "elementary_charge", ] ) @@ -153,6 +155,13 @@ def quantity_json_serializer( WrapSerializer(quantity_json_serializer), ] +_ElementaryChargeQuantity = Annotated[ + Quantity, + WrapValidator(quantity_validator), + AfterValidator(_is_elementary_charge), + WrapSerializer(quantity_json_serializer), +] + _kJMolQuantity = Annotated[ Quantity, WrapValidator(quantity_validator), diff --git a/openff/interchange/_tests/test_issues.py b/openff/interchange/_tests/test_issues.py index 7b381ec2..6a3919ad 100644 --- a/openff/interchange/_tests/test_issues.py +++ b/openff/interchange/_tests/test_issues.py @@ -112,3 +112,10 @@ def test_issue_1031(monkeypatch): # check a few atom names to ensure these didn't end up being empty sets for atom_name in ("NE2", "H3", "HA", "CH3", "CA", "CB", "CE1"): assert atom_name in openff_atom_names + + +def test_issue_1052(sage, ethanol): + """Test that _SMIRNOFFElectrostaticsCollection.charges is populated.""" + out = sage.create_interchange(ethanol.to_topology()) + + assert len(out["Electrostatics"].charges) > 0 diff --git a/openff/interchange/common/_nonbonded.py b/openff/interchange/common/_nonbonded.py index 7d5cb9da..7bc66281 100644 --- a/openff/interchange/common/_nonbonded.py +++ b/openff/interchange/common/_nonbonded.py @@ -3,9 +3,9 @@ from typing import Literal from openff.toolkit import Quantity, unit -from pydantic import Field, PrivateAttr +from pydantic import Field, PrivateAttr, computed_field -from openff.interchange._annotations import _DistanceQuantity +from openff.interchange._annotations import _DistanceQuantity, _ElementaryChargeQuantity from openff.interchange.components.potentials import Collection from openff.interchange.constants import _PME from openff.interchange.models import ( @@ -101,18 +101,15 @@ class ElectrostaticsCollection(_NonbondedCollection): nonperiodic_potential: Literal["Coulomb", "cutoff", "no-cutoff"] = Field("Coulomb") exception_potential: Literal["Coulomb"] = Field("Coulomb") - _charges: dict[ - TopologyKey | LibraryChargeTopologyKey, - Quantity, - ] = PrivateAttr( - default_factory=dict, - ) + # TODO: Charge caching doesn't work when this is defined in the model + # _charges: dict[Any, _ElementaryChargeQuantity] = PrivateAttr(default_factory=dict) _charges_cached: bool = PrivateAttr(default=False) + @computed_field @property def charges( self, - ) -> dict[TopologyKey | LibraryChargeTopologyKey | VirtualSiteKey, Quantity]: + ) -> dict[TopologyKey | LibraryChargeTopologyKey | VirtualSiteKey, _ElementaryChargeQuantity]: """Get the total partial charge on each atom, including virtual sites.""" if len(self._charges) == 0 or self._charges_cached is False: self._charges = self._get_charges(include_virtual_sites=False) @@ -123,7 +120,7 @@ def charges( def _get_charges( self, include_virtual_sites: bool = False, - ) -> dict[TopologyKey | VirtualSiteKey | LibraryChargeTopologyKey, Quantity]: + ) -> dict[TopologyKey | VirtualSiteKey | LibraryChargeTopologyKey, _ElementaryChargeQuantity]: if include_virtual_sites: raise NotImplementedError() diff --git a/openff/interchange/smirnoff/_nonbonded.py b/openff/interchange/smirnoff/_nonbonded.py index ee751f53..3be641cd 100644 --- a/openff/interchange/smirnoff/_nonbonded.py +++ b/openff/interchange/smirnoff/_nonbonded.py @@ -13,9 +13,10 @@ ToolkitAM1BCCHandler, vdWHandler, ) -from pydantic import Field, PrivateAttr +from pydantic import Field, PrivateAttr, computed_field from typing_extensions import Self +from openff.interchange._annotations import _ElementaryChargeQuantity from openff.interchange.common._nonbonded import ( ElectrostaticsCollection, _NonbondedCollection, @@ -272,8 +273,9 @@ class SMIRNOFFElectrostaticsCollection(ElectrostaticsCollection, SMIRNOFFCollect ) # type: ignore[assignment] exception_potential: Literal["Coulomb"] = Field("Coulomb") - _charges = PrivateAttr(default_factory=dict) - _charges_cached: bool + # TODO: Charge caching doesn't work when this is defined in the model + # _charges: dict[Any, _ElementaryChargeQuantity] = PrivateAttr(default_factory=dict) + _charges_cached: bool = PrivateAttr(default=False) @classmethod def allowed_parameter_handlers(cls): @@ -292,14 +294,15 @@ def supported_parameters(cls): @property def _charges_without_virtual_sites( self, - ) -> dict[TopologyKey | LibraryChargeTopologyKey, Quantity]: + ) -> dict[TopologyKey | LibraryChargeTopologyKey, _ElementaryChargeQuantity]: """Get the total partial charge on each atom, excluding virtual sites.""" return self._get_charges(include_virtual_sites=False) + @computed_field @property def charges( self, - ) -> dict[TopologyKey | LibraryChargeTopologyKey | VirtualSiteKey, Quantity]: + ) -> dict[TopologyKey | LibraryChargeTopologyKey | VirtualSiteKey, _ElementaryChargeQuantity]: """Get the total partial charge on each atom, including virtual sites.""" if len(self._charges) == 0 or self._charges_cached is False: self._charges = self._get_charges(include_virtual_sites=True) @@ -310,10 +313,10 @@ def charges( def _get_charges( self, include_virtual_sites=True, - ) -> dict[TopologyKey | LibraryChargeTopologyKey | VirtualSiteKey, Quantity]: + ) -> dict[TopologyKey | LibraryChargeTopologyKey | VirtualSiteKey, _ElementaryChargeQuantity]: """Get the total partial charge on each atom or particle.""" # Keyed by index for atoms and by VirtualSiteKey for virtual sites. - charges: dict[VirtualSiteKey | int, Quantity] = dict() + charges: dict[VirtualSiteKey | int, _ElementaryChargeQuantity] = dict() for topology_key, potential_key in self.key_map.items(): potential = self.potentials[potential_key] From 2cace3f0ac0733d3ac7cbade7ab9f3ffb1eb0668 Mon Sep 17 00:00:00 2001 From: "Matthew W. Thompson" Date: Mon, 30 Sep 2024 14:24:49 -0500 Subject: [PATCH 2/2] DOC: Update annotations, release history --- docs/releasehistory.md | 1 + openff/interchange/common/_nonbonded.py | 4 ++-- openff/interchange/components/potentials.py | 2 +- openff/interchange/smirnoff/_nonbonded.py | 2 +- 4 files changed, 5 insertions(+), 4 deletions(-) diff --git a/docs/releasehistory.md b/docs/releasehistory.md index 3a73dfcc..e41babb1 100644 --- a/docs/releasehistory.md +++ b/docs/releasehistory.md @@ -22,6 +22,7 @@ Please note that all releases prior to a version 1.0.0 are considered pre-releas * Several classes and methods which were deprecated in the 0.3 line of releases are now removed. * Previously-deprecated examples are removed. * `ProperTorsionKey` no longer accepts an empty tuple as atom indices. +* Fixes a regression in which some `ElectrostaticsCollection.charges` properties did not return cached values. ## 0.3.30 - 2024-08 diff --git a/openff/interchange/common/_nonbonded.py b/openff/interchange/common/_nonbonded.py index 7bc66281..ecb47d60 100644 --- a/openff/interchange/common/_nonbonded.py +++ b/openff/interchange/common/_nonbonded.py @@ -105,13 +105,13 @@ class ElectrostaticsCollection(_NonbondedCollection): # _charges: dict[Any, _ElementaryChargeQuantity] = PrivateAttr(default_factory=dict) _charges_cached: bool = PrivateAttr(default=False) - @computed_field + @computed_field # type: ignore[misc] @property def charges( self, ) -> dict[TopologyKey | LibraryChargeTopologyKey | VirtualSiteKey, _ElementaryChargeQuantity]: """Get the total partial charge on each atom, including virtual sites.""" - if len(self._charges) == 0 or self._charges_cached is False: + if len(self._charges) == 0 or self._charges_cached is False: # type: ignore[has-type] self._charges = self._get_charges(include_virtual_sites=False) self._charges_cached = True diff --git a/openff/interchange/components/potentials.py b/openff/interchange/components/potentials.py index 9d6b334e..5161f2cb 100644 --- a/openff/interchange/components/potentials.py +++ b/openff/interchange/components/potentials.py @@ -33,7 +33,7 @@ if has_package("jax"): from jax import Array else: - Array = Any + Array = Any # type: ignore class Potential(_BaseModel): diff --git a/openff/interchange/smirnoff/_nonbonded.py b/openff/interchange/smirnoff/_nonbonded.py index 3be641cd..74f360f2 100644 --- a/openff/interchange/smirnoff/_nonbonded.py +++ b/openff/interchange/smirnoff/_nonbonded.py @@ -298,7 +298,7 @@ def _charges_without_virtual_sites( """Get the total partial charge on each atom, excluding virtual sites.""" return self._get_charges(include_virtual_sites=False) - @computed_field + @computed_field # type: ignore[misc] @property def charges( self,