From 1f9608e435233123dbbb9a2644feedc04f7dd281 Mon Sep 17 00:00:00 2001 From: Brent Westbrook Date: Wed, 2 Oct 2024 13:59:26 -0400 Subject: [PATCH] use update instead of assignment to set _charges --- openff/interchange/common/_nonbonded.py | 7 +++---- openff/interchange/smirnoff/_nonbonded.py | 5 ++--- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/openff/interchange/common/_nonbonded.py b/openff/interchange/common/_nonbonded.py index ecb47d60..d55b0de6 100644 --- a/openff/interchange/common/_nonbonded.py +++ b/openff/interchange/common/_nonbonded.py @@ -1,6 +1,6 @@ import abc from collections.abc import Iterable -from typing import Literal +from typing import Any, Literal from openff.toolkit import Quantity, unit from pydantic import Field, PrivateAttr, computed_field @@ -101,8 +101,7 @@ class ElectrostaticsCollection(_NonbondedCollection): nonperiodic_potential: Literal["Coulomb", "cutoff", "no-cutoff"] = Field("Coulomb") exception_potential: Literal["Coulomb"] = Field("Coulomb") - # TODO: Charge caching doesn't work when this is defined in the model - # _charges: dict[Any, _ElementaryChargeQuantity] = PrivateAttr(default_factory=dict) + _charges: dict[Any, _ElementaryChargeQuantity] = PrivateAttr(default_factory=dict) _charges_cached: bool = PrivateAttr(default=False) @computed_field # type: ignore[misc] @@ -112,7 +111,7 @@ def charges( ) -> dict[TopologyKey | LibraryChargeTopologyKey | VirtualSiteKey, _ElementaryChargeQuantity]: """Get the total partial charge on each atom, including virtual sites.""" if len(self._charges) == 0 or self._charges_cached is False: # type: ignore[has-type] - self._charges = self._get_charges(include_virtual_sites=False) + self._charges.update(self._get_charges(include_virtual_sites=False)) self._charges_cached = True return self._charges diff --git a/openff/interchange/smirnoff/_nonbonded.py b/openff/interchange/smirnoff/_nonbonded.py index 74f360f2..bde6bc3d 100644 --- a/openff/interchange/smirnoff/_nonbonded.py +++ b/openff/interchange/smirnoff/_nonbonded.py @@ -273,8 +273,7 @@ class SMIRNOFFElectrostaticsCollection(ElectrostaticsCollection, SMIRNOFFCollect ) # type: ignore[assignment] exception_potential: Literal["Coulomb"] = Field("Coulomb") - # TODO: Charge caching doesn't work when this is defined in the model - # _charges: dict[Any, _ElementaryChargeQuantity] = PrivateAttr(default_factory=dict) + _charges: dict[Any, _ElementaryChargeQuantity] = PrivateAttr(default_factory=dict) _charges_cached: bool = PrivateAttr(default=False) @classmethod @@ -305,7 +304,7 @@ def charges( ) -> dict[TopologyKey | LibraryChargeTopologyKey | VirtualSiteKey, _ElementaryChargeQuantity]: """Get the total partial charge on each atom, including virtual sites.""" if len(self._charges) == 0 or self._charges_cached is False: - self._charges = self._get_charges(include_virtual_sites=True) + self._charges.update(self._get_charges(include_virtual_sites=True)) self._charges_cached = True return self._charges