Skip to content

Commit

Permalink
Merge pull request #1066 from openforcefield/fix-1052
Browse files Browse the repository at this point in the history
Fix default `.charges`
  • Loading branch information
mattwthompson authored Oct 2, 2024
2 parents 273d27a + 2cace3f commit 360b60c
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 19 deletions.
1 change: 1 addition & 0 deletions docs/releasehistory.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ Please note that all releases prior to a version 1.0.0 are considered pre-releas
* Several classes and methods which were deprecated in the 0.3 line of releases are now removed.
* Previously-deprecated examples are removed.
* `ProperTorsionKey` no longer accepts an empty tuple as atom indices.
* Fixes a regression in which some `ElectrostaticsCollection.charges` properties did not return cached values.

## 0.3.30 - 2024-08

Expand Down
9 changes: 9 additions & 0 deletions openff/interchange/_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,15 @@ def _unit_validator_factory(unit: str) -> Callable:
_is_kj_mol,
_is_nanometer,
_is_degree,
_is_elementary_charge,
) = (
_unit_validator_factory(unit=_unit)
for _unit in [
"dimensionless",
"kilojoule / mole",
"nanometer",
"degree",
"elementary_charge",
]
)

Expand Down Expand Up @@ -153,6 +155,13 @@ def quantity_json_serializer(
WrapSerializer(quantity_json_serializer),
]

_ElementaryChargeQuantity = Annotated[
Quantity,
WrapValidator(quantity_validator),
AfterValidator(_is_elementary_charge),
WrapSerializer(quantity_json_serializer),
]

_kJMolQuantity = Annotated[
Quantity,
WrapValidator(quantity_validator),
Expand Down
7 changes: 7 additions & 0 deletions openff/interchange/_tests/test_issues.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,3 +112,10 @@ def test_issue_1031(monkeypatch):
# check a few atom names to ensure these didn't end up being empty sets
for atom_name in ("NE2", "H3", "HA", "CH3", "CA", "CB", "CE1"):
assert atom_name in openff_atom_names


def test_issue_1052(sage, ethanol):
"""Test that _SMIRNOFFElectrostaticsCollection.charges is populated."""
out = sage.create_interchange(ethanol.to_topology())

assert len(out["Electrostatics"].charges) > 0
19 changes: 8 additions & 11 deletions openff/interchange/common/_nonbonded.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
from typing import Literal

from openff.toolkit import Quantity, unit
from pydantic import Field, PrivateAttr
from pydantic import Field, PrivateAttr, computed_field

from openff.interchange._annotations import _DistanceQuantity
from openff.interchange._annotations import _DistanceQuantity, _ElementaryChargeQuantity
from openff.interchange.components.potentials import Collection
from openff.interchange.constants import _PME
from openff.interchange.models import (
Expand Down Expand Up @@ -101,20 +101,17 @@ class ElectrostaticsCollection(_NonbondedCollection):
nonperiodic_potential: Literal["Coulomb", "cutoff", "no-cutoff"] = Field("Coulomb")
exception_potential: Literal["Coulomb"] = Field("Coulomb")

_charges: dict[
TopologyKey | LibraryChargeTopologyKey,
Quantity,
] = PrivateAttr(
default_factory=dict,
)
# TODO: Charge caching doesn't work when this is defined in the model
# _charges: dict[Any, _ElementaryChargeQuantity] = PrivateAttr(default_factory=dict)
_charges_cached: bool = PrivateAttr(default=False)

@computed_field # type: ignore[misc]
@property
def charges(
self,
) -> dict[TopologyKey | LibraryChargeTopologyKey | VirtualSiteKey, Quantity]:
) -> dict[TopologyKey | LibraryChargeTopologyKey | VirtualSiteKey, _ElementaryChargeQuantity]:
"""Get the total partial charge on each atom, including virtual sites."""
if len(self._charges) == 0 or self._charges_cached is False:
if len(self._charges) == 0 or self._charges_cached is False: # type: ignore[has-type]
self._charges = self._get_charges(include_virtual_sites=False)
self._charges_cached = True

Expand All @@ -123,7 +120,7 @@ def charges(
def _get_charges(
self,
include_virtual_sites: bool = False,
) -> dict[TopologyKey | VirtualSiteKey | LibraryChargeTopologyKey, Quantity]:
) -> dict[TopologyKey | VirtualSiteKey | LibraryChargeTopologyKey, _ElementaryChargeQuantity]:
if include_virtual_sites:
raise NotImplementedError()

Expand Down
2 changes: 1 addition & 1 deletion openff/interchange/components/potentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
if has_package("jax"):
from jax import Array
else:
Array = Any
Array = Any # type: ignore


class Potential(_BaseModel):
Expand Down
17 changes: 10 additions & 7 deletions openff/interchange/smirnoff/_nonbonded.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@
ToolkitAM1BCCHandler,
vdWHandler,
)
from pydantic import Field, PrivateAttr
from pydantic import Field, PrivateAttr, computed_field
from typing_extensions import Self

from openff.interchange._annotations import _ElementaryChargeQuantity
from openff.interchange.common._nonbonded import (
ElectrostaticsCollection,
_NonbondedCollection,
Expand Down Expand Up @@ -272,8 +273,9 @@ class SMIRNOFFElectrostaticsCollection(ElectrostaticsCollection, SMIRNOFFCollect
) # type: ignore[assignment]
exception_potential: Literal["Coulomb"] = Field("Coulomb")

_charges = PrivateAttr(default_factory=dict)
_charges_cached: bool
# TODO: Charge caching doesn't work when this is defined in the model
# _charges: dict[Any, _ElementaryChargeQuantity] = PrivateAttr(default_factory=dict)
_charges_cached: bool = PrivateAttr(default=False)

@classmethod
def allowed_parameter_handlers(cls):
Expand All @@ -292,14 +294,15 @@ def supported_parameters(cls):
@property
def _charges_without_virtual_sites(
self,
) -> dict[TopologyKey | LibraryChargeTopologyKey, Quantity]:
) -> dict[TopologyKey | LibraryChargeTopologyKey, _ElementaryChargeQuantity]:
"""Get the total partial charge on each atom, excluding virtual sites."""
return self._get_charges(include_virtual_sites=False)

@computed_field # type: ignore[misc]
@property
def charges(
self,
) -> dict[TopologyKey | LibraryChargeTopologyKey | VirtualSiteKey, Quantity]:
) -> dict[TopologyKey | LibraryChargeTopologyKey | VirtualSiteKey, _ElementaryChargeQuantity]:
"""Get the total partial charge on each atom, including virtual sites."""
if len(self._charges) == 0 or self._charges_cached is False:
self._charges = self._get_charges(include_virtual_sites=True)
Expand All @@ -310,10 +313,10 @@ def charges(
def _get_charges(
self,
include_virtual_sites=True,
) -> dict[TopologyKey | LibraryChargeTopologyKey | VirtualSiteKey, Quantity]:
) -> dict[TopologyKey | LibraryChargeTopologyKey | VirtualSiteKey, _ElementaryChargeQuantity]:
"""Get the total partial charge on each atom or particle."""
# Keyed by index for atoms and by VirtualSiteKey for virtual sites.
charges: dict[VirtualSiteKey | int, Quantity] = dict()
charges: dict[VirtualSiteKey | int, _ElementaryChargeQuantity] = dict()

for topology_key, potential_key in self.key_map.items():
potential = self.potentials[potential_key]
Expand Down

0 comments on commit 360b60c

Please sign in to comment.