Skip to content

Commit

Permalink
Merge pull request #1069 from openforcefield/charge-attrs
Browse files Browse the repository at this point in the history
Use update instead of assignment to set _charges
  • Loading branch information
mattwthompson authored Oct 2, 2024
2 parents 360b60c + 1f9608e commit bdd2177
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 7 deletions.
7 changes: 3 additions & 4 deletions openff/interchange/common/_nonbonded.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import abc
from collections.abc import Iterable
from typing import Literal
from typing import Any, Literal

from openff.toolkit import Quantity, unit
from pydantic import Field, PrivateAttr, computed_field
Expand Down Expand Up @@ -101,8 +101,7 @@ class ElectrostaticsCollection(_NonbondedCollection):
nonperiodic_potential: Literal["Coulomb", "cutoff", "no-cutoff"] = Field("Coulomb")
exception_potential: Literal["Coulomb"] = Field("Coulomb")

# TODO: Charge caching doesn't work when this is defined in the model
# _charges: dict[Any, _ElementaryChargeQuantity] = PrivateAttr(default_factory=dict)
_charges: dict[Any, _ElementaryChargeQuantity] = PrivateAttr(default_factory=dict)
_charges_cached: bool = PrivateAttr(default=False)

@computed_field # type: ignore[misc]
Expand All @@ -112,7 +111,7 @@ def charges(
) -> dict[TopologyKey | LibraryChargeTopologyKey | VirtualSiteKey, _ElementaryChargeQuantity]:
"""Get the total partial charge on each atom, including virtual sites."""
if len(self._charges) == 0 or self._charges_cached is False: # type: ignore[has-type]
self._charges = self._get_charges(include_virtual_sites=False)
self._charges.update(self._get_charges(include_virtual_sites=False))
self._charges_cached = True

return self._charges
Expand Down
5 changes: 2 additions & 3 deletions openff/interchange/smirnoff/_nonbonded.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,8 +273,7 @@ class SMIRNOFFElectrostaticsCollection(ElectrostaticsCollection, SMIRNOFFCollect
) # type: ignore[assignment]
exception_potential: Literal["Coulomb"] = Field("Coulomb")

# TODO: Charge caching doesn't work when this is defined in the model
# _charges: dict[Any, _ElementaryChargeQuantity] = PrivateAttr(default_factory=dict)
_charges: dict[Any, _ElementaryChargeQuantity] = PrivateAttr(default_factory=dict)
_charges_cached: bool = PrivateAttr(default=False)

@classmethod
Expand Down Expand Up @@ -305,7 +304,7 @@ def charges(
) -> dict[TopologyKey | LibraryChargeTopologyKey | VirtualSiteKey, _ElementaryChargeQuantity]:
"""Get the total partial charge on each atom, including virtual sites."""
if len(self._charges) == 0 or self._charges_cached is False:
self._charges = self._get_charges(include_virtual_sites=True)
self._charges.update(self._get_charges(include_virtual_sites=True))
self._charges_cached = True

return self._charges
Expand Down

0 comments on commit bdd2177

Please sign in to comment.