Skip to content

Commit

Permalink
Merge pull request #964 from openforcefield/new-models
Browse files Browse the repository at this point in the history
Migrate to Pydantic v2
  • Loading branch information
mattwthompson authored Jun 18, 2024
2 parents 40bb3e8 + 3e6e837 commit 956fc87
Show file tree
Hide file tree
Showing 70 changed files with 1,147 additions and 752 deletions.
5 changes: 3 additions & 2 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@ on:
push:
branches:
- main
- develop
pull_request:
branches:
- main
- develop
schedule:
- cron: "0 0 * * *"
workflow_dispatch:
Expand Down Expand Up @@ -71,8 +73,7 @@ jobs:
- name: Install OpenMM
if: ${{ matrix.openmm == true }}
run: |
micromamba install openmm "smirnoff-plugins =2024" -c conda-forge
pip install git+https://github.com/jthorton/de-forcefields.git
micromamba install openmm -c conda-forge
- name: Uninstall OpenMM
if: ${{ matrix.openmm == false && matrix.openeye == true }}
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/examples.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@ on:
push:
branches:
- main
- v0.3.0-staging
- develop
pull_request:
branches:
- main
- v0.3.0-staging
- develop
schedule:
- cron: "0 0 * * *"
workflow_dispatch:
Expand Down
3 changes: 0 additions & 3 deletions devtools/conda-envs/beta_env.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ dependencies:
- openmm >=7.6
# OpenFF stack
- openff-toolkit >=0.15.2
- openff-models
- openff-nagl ~=0.3.7
- openff-nagl-models =0.1
# Optional features
Expand Down Expand Up @@ -42,5 +41,3 @@ dependencies:
- typing-extensions
- types-setuptools
- pandas-stubs >=1.2.0.56
- pip:
- git+https://github.com/jthorton/de-forcefields.git
7 changes: 2 additions & 5 deletions devtools/conda-envs/dev_env.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@ dependencies:
# OpenFF stack
- openff-toolkit ~=0.16
- openff-interchange-base
- openff-models
- smirnoff-plugins =2024
# smirnoff-plugins =2024
- openff-nagl
- openff-nagl-models
- ambertools =23
Expand All @@ -30,7 +29,7 @@ dependencies:
- pytest-xdist
- pytest-randomly
- nbval
# de-forcefields # needs new release
# de-forcefields # add back after smirnoff-plugins update
# Drivers
- gromacs
- lammps >=2023.08.02
Expand All @@ -53,5 +52,3 @@ dependencies:
- flake8
- snakeviz
- tuna
- pip:
- git+https://github.com/jthorton/de-forcefields.git
9 changes: 4 additions & 5 deletions devtools/conda-envs/docs_env.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,8 @@ dependencies:
- python =3.10
- pip
- numpy =1
- pydantic =1
- openff-toolkit-base =0.15.2
- openff-models
- pydantic =2
- openff-toolkit-base
- openmm >=7.6
- mbuild
- foyer >=0.12.1
Expand All @@ -20,8 +19,8 @@ dependencies:
# readthedocs dependencies
- myst-parser
- numpydoc
- autodoc-pydantic
- sphinx>=4.4.0,<5
- autodoc-pydantic =2
- sphinx ~=4.4
- sphinxcontrib-mermaid
- sphinx-notfound-page
- pip:
Expand Down
1 change: 0 additions & 1 deletion devtools/conda-envs/examples_env.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ dependencies:
- openmm
# OpenFF stack
- openff-toolkit
- openff-models
- openff-nagl
- openff-nagl-models
- ambertools =23
Expand Down
2 changes: 0 additions & 2 deletions devtools/conda-envs/test_env.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,10 @@ dependencies:
# OpenFF stack
- openff-toolkit-base >=0.16
- openff-units
- openff-models
- ambertools =23
# Needs to be explicitly listed to not be dropped when AmberTools is removed
- rdkit
# Optional features
# GMSO does not support Pydantic 2; should come in release after 0.12.0
- foyer >=0.12.1
- mbuild
- gmso =0.12
Expand Down
2 changes: 1 addition & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@
autodoc_default_options = {
"member-order": "bysource",
"undoc-members": True,
"inherited-members": False,
"inherited-members": [],
"show-inheritance": True,
}
autodoc_preserve_defaults = True
Expand Down
2 changes: 0 additions & 2 deletions docs/using/plugins.md
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,6 @@ from openff.toolkit.typing.engines.smirnoff.parameters import (

class BuckinghamHandler(ParameterHandler):
class BuckinghamType(ParameterType):
_VALENCE_TYPE = "Atom"
_ELEMENT_NAME = "Atom"

a = ParameterAttribute(default=None, unit=unit.kilojoule_per_mole)
Expand Down Expand Up @@ -138,7 +137,6 @@ Notice that
* `BuckinghamHandler` (the "handler class") is a subclass of `ParameterHandler`
* `BuckinghamType` (the "type class")
* is a subclass of `ParameterType`
* defines `"Atom"` as its `_VALENCE_TYPE`, or chemical environment
* defines `"Atom"` as its `_ELEMENT_TYPE`, which defines how it is serialized
* has unit-tagged attributes `a`, `b`, and `c`, corresponding to particular values for each parameter
* the handler class also
Expand Down
2 changes: 1 addition & 1 deletion examples/lammps/lammps.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
"version": "3.10.14"
}
},
"nbformat": 4,
Expand Down
5 changes: 4 additions & 1 deletion openff/interchange/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,10 @@ def __getattr__(name) -> ModuleType:
"""
module = _objects.get(name)
if module is not None:
return importlib.import_module(module).__dict__[name]
try:
return importlib.import_module(module).__dict__[name]
except ImportError as error:
raise ImportError from error

raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

Expand Down
226 changes: 226 additions & 0 deletions openff/interchange/_annotations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,226 @@
import functools
from collections.abc import Callable
from typing import Annotated, Any

import numpy
from openff.toolkit import Quantity
from pydantic import (
AfterValidator,
BeforeValidator,
ValidationInfo,
ValidatorFunctionWrapHandler,
WrapSerializer,
WrapValidator,
)


def _has_compatible_dimensionality(
quantity: Quantity,
unit: str,
convert: bool,
) -> Quantity:
"""Check if a Quantity has the same dimensionality as a given unit and optionally convert."""
if quantity.is_compatible_with(unit):
if convert:
return quantity.to(unit)
else:
return quantity
else:
raise ValueError(
f"Dimensionality of {quantity=} is not compatible with {unit=}",
)


def _dimensionality_valiator_factory(unit: str) -> Callable:
"""Return a function, meant to be passed to a validator, that checks for a specific unit."""
return functools.partial(_has_compatible_dimensionality, unit=unit, convert=False)


def _unit_validator_factory(unit: str) -> Callable:
"""Return a function, meant to be passed to a validator, that checks for a specific unit."""
return functools.partial(_has_compatible_dimensionality, unit=unit, convert=True)


(
_is_distance,
_is_velocity,
) = (
_dimensionality_valiator_factory(unit=_unit)
for _unit in [
"nanometer",
"nanometer / picosecond",
]
)

(
_is_dimensionless,
_is_kj_mol,
_is_nanometer,
_is_degree,
) = (
_unit_validator_factory(unit=_unit)
for _unit in [
"dimensionless",
"kilojoule / mole",
"nanometer",
"degree",
]
)


def quantity_validator(
value: str | Quantity | dict,
handler: ValidatorFunctionWrapHandler,
info: ValidationInfo,
) -> Quantity:
"""Take Quantity-like objects and convert them to Quantity objects."""
if info.mode == "json":
assert isinstance(value, dict), "Quantity must be in dict form here."

# this is coupled to how a Quantity looks in JSON
return Quantity(value["value"], value["unit"])

# some more work may be needed to work with arrays, lists, tuples, etc.

assert info.mode == "python"

if isinstance(value, Quantity):
return value
elif isinstance(value, str):
return Quantity(value)
elif isinstance(value, dict):
return Quantity(value["value"], value["unit"])
if "openmm" in str(type(value)):
from openff.units.openmm import from_openmm

return from_openmm(value)
else:
raise ValueError(f"Invalid type {type(value)} for Quantity")


def quantity_json_serializer(
quantity: Quantity,
nxt,
) -> dict:
"""Serialize a Quantity to a JSON-compatible dictionary."""
magnitude = quantity.m

if isinstance(magnitude, numpy.ndarray):
# This could be something fancier, list a bytestring
magnitude = magnitude.tolist()

return {
"value": magnitude,
"unit": str(quantity.units),
}


# Pydantic v2 likes to marry validators and serializers to types with Annotated
# https://docs.pydantic.dev/latest/concepts/validators/#annotated-validators
_Quantity = Annotated[
Quantity,
WrapValidator(quantity_validator),
WrapSerializer(quantity_json_serializer),
]

_DimensionlessQuantity = Annotated[
Quantity,
WrapValidator(quantity_validator),
AfterValidator(_is_dimensionless),
WrapSerializer(quantity_json_serializer),
]

_DistanceQuantity = Annotated[
Quantity,
WrapValidator(quantity_validator),
AfterValidator(_is_distance),
WrapSerializer(quantity_json_serializer),
]

_LengthQuantity = _DistanceQuantity

_VelocityQuantity = Annotated[
Quantity,
WrapValidator(quantity_validator),
AfterValidator(_is_velocity),
WrapSerializer(quantity_json_serializer),
]

_DegreeQuantity = Annotated[
Quantity,
WrapValidator(quantity_validator),
AfterValidator(_is_degree),
WrapSerializer(quantity_json_serializer),
]

_kJMolQuantity = Annotated[
Quantity,
WrapValidator(quantity_validator),
AfterValidator(_is_kj_mol),
WrapSerializer(quantity_json_serializer),
]


def _is_positions_shape(quantity: Quantity) -> Quantity:
if quantity.m.shape[1] == 3:
return quantity
else:
raise ValueError(
f"Quantity {quantity} of wrong shape ({quantity.shape}) to be positions.",
)


def _duck_to_nanometer(value: Any):
"""Cast list or ndarray without units to Quantity[ndarray] of nanometer."""
if isinstance(value, (list, numpy.ndarray)):
return Quantity(value, "nanometer")
else:
return value


_PositionsQuantity = Annotated[
Quantity,
WrapValidator(quantity_validator),
AfterValidator(_is_nanometer),
AfterValidator(_is_positions_shape),
BeforeValidator(_duck_to_nanometer),
WrapSerializer(quantity_json_serializer),
]


def _is_box_shape(quantity) -> Quantity:
if quantity.m.shape == (3, 3):
return quantity
elif quantity.m.shape == (3,):
return numpy.eye(3) * quantity
else:
raise ValueError(f"Quantity {quantity} is not a box.")


def _unwrap_list_of_openmm_quantities(value: Any):
"""Unwrap a list of OpenMM quantities to a single Quantity."""
if isinstance(value, list):
if any(["openmm" in str(type(element)) for element in value]):
from openff.units.openmm import from_openmm

if len({element.unit for element in value}) != 1:
raise ValueError("All units must be the same.")

return from_openmm(value)

else:
return value

else:
return value


_BoxQuantity = Annotated[
Quantity,
WrapValidator(quantity_validator),
AfterValidator(_is_distance),
AfterValidator(_is_box_shape),
BeforeValidator(_duck_to_nanometer),
BeforeValidator(_unwrap_list_of_openmm_quantities),
WrapSerializer(quantity_json_serializer),
]
Loading

0 comments on commit 956fc87

Please sign in to comment.