diff --git a/docs/source/models.rst b/docs/source/models.rst index 0e54a29b7..b5e8be5cd 100644 --- a/docs/source/models.rst +++ b/docs/source/models.rst @@ -36,7 +36,95 @@ Once you have trained a model you should have a checkpoint that you can load for .. note:: Some models take additional inputs such as the charge :code:`q` and the spin :code:`s` of the atoms depending on the chosen priors/outputs. Check the documentation of the model you are using to see if this is the case. .. note:: When periodic boundary conditions are required, modules typically offer the possibility of providing the box vectors at construction and/or as an argument to the forward pass. Check the documentation of the class you are using to see if this is the case. + + + + +Integration with MD packages +----------------------------- + +It is possible to use the Neural Network Potentials in TorchMD-Net as force fields for Molecular Dynamics. + +OpenMM +~~~~~~ + +The `OpenMM-Torch `_ plugin can be used to load :ref:`pretrained-models` as force fields in `OpenMM `_. In order to do that one needs a translation layer between :py:mod:`TorchMD_Net ` and `TorchForce `_. This wrapper needs to take into account the different parameters and units (depending on the :ref:`Dataset ` used to train the model) in both. + +We provide here a minimal example of the wrapper class, but a complete example is provided under the `examples` folder. + +.. code:: python + + import torch + from torch import Tensor, nn + import openmm + import openmmtorch + from torchmdnet.models.model import load_model + # This is a wrapper that links an OpenMM Force with a TorchMD-Net model + class Wrapper(nn.Module): + + def __init__(self, embeddings: Tensor, checkpoint_path: str): + super(Wrapper, self).__init__() + # The embeddings used to train the model, typically atomic numbers + self.embeddings = embeddings + # We let OpenMM compute the forces from the energies + self.model = load_model(checkpoint_path, derivative=False) + + def forward(self, positions: Tensor) -> Tensor: + # OpenMM works with nanometer positions and kilojoule per mole energies + # Depending on the model, you might need to convert the units + positions = positions.to(torch.float32) * 10.0 # nm -> A + energy = self.model(z=self.embeddings, pos=positions)[0] + return energy * 96.4916 # eV -> kJ/mol + + model = Wrapper(embeddings=torch.tensor([1, 6, 7, 8, 9]), checkpoint_path="/path/to/checkpoint/my_checkpoint.ckpt") + model = torch.jit.script(model) # Models need to be scripted to be used in OpenMM + # The model can be used as a force field in OpenMM + force = openmmtorch.TorchForce(model) + # Add the force to an OpenMM system + system = openmm.System() + system.addForce(force) + + +.. note:: See :ref:`training ` for more information on how to train a model. + +.. warning:: The conversion factors are specific to the dataset used to train the model. Check the documentation of the dataset you are using to see if this is the case. + +.. note:: See the `OpenMM-Torch `_ documentation for more information on additional functionality (such as periodic boundary conditions or CUDA graph support). + + +TorchMD +~~~~~~~ + +Integration with `TorchMD `_ is carried out via :py:mod:`torchmdnet.calculators.External`. Refer to its documentation for more information on additional functionality. + +.. code:: python + + import torch + import torchmd + from torchmdnet.calculators import External + # Load the model + embeddings = torch.tensor([1, 6, 7, 8, 9]) + model = External("/path/to/checkpoint/my_checkpoint.ckpt, embeddings) + # Use the calculator in a TorchMD simulation + from torchmd.forces import Forces + parameters = # Your TorchMD parameters here + torchmd_forces = Forces(parameters, external=model) + # You can now pass torchmd_forces to a TorchMD Integrator + +Additionally, the calculator can be specified in the configuration file of a TorchMD simulation via the `external` key. + + +.. code:: yaml + ... + external: + module: torchmdnet.calculators + file: /path/to/checkpoint/my_checkpoint.ckpt + embeddings: [1, 6, 7, 8, 9] + ... + +.. warning:: Unit conversion might be required depending on the dataset used to train the model. Check the documentation of the dataset you are using to see if this is the case. + Available Models ---------------- diff --git a/docs/source/usage.rst b/docs/source/usage.rst index 0248f55c0..ff7b1fd24 100644 --- a/docs/source/usage.rst +++ b/docs/source/usage.rst @@ -24,6 +24,8 @@ For example, to train the Equivariant Transformer on the QM9 dataset with the ar Run `torchmd-train --help` to see all available options and their descriptions. +.. _pretrained-models: + Pretrained Models ================= @@ -95,7 +97,8 @@ In order to train models on multiple nodes some environment variables have to be - Due to the way PyTorch Lightning calculates the number of required DDP processes, all nodes must use the same number of GPUs. Otherwise training will not start or crash. - We observe a 50x decrease in performance when mixing nodes with different GPU architectures (tested with RTX 2080 Ti and RTX 3090). - Some CUDA systems might hang during a multi-GPU parallel training. Try ``export NCCL_P2P_DISABLE=1``, which disables direct peer to peer GPU communication. - + + Developer Guide --------------- diff --git a/examples/README.md b/examples/README.md index 59c993992..3778f59be 100644 --- a/examples/README.md +++ b/examples/README.md @@ -36,3 +36,9 @@ pos = torch.rand(6, 3) batch = torch.tensor([0, 0, 0, 1, 1, 1], dtype=torch.long) energies, forces = model(z, pos, batch) ``` + +## Running MD simulations with OpenMM + +The example `openmm-integration.py` shows how to run an MD simulation using a pretrained TorchMD-Net module as a force field. +The connection between OpenMM and TorchMD-Net is done via [OpenMM-Torch](https://github.com/openmm/openmm-torch). +See the [documentation](https://torchmd-net.readthedocs.io/en/latest/models.html#neural-network-potentials) for more details on this integration. diff --git a/examples/openmm-integration.py b/examples/openmm-integration.py new file mode 100644 index 000000000..8bfdfa450 --- /dev/null +++ b/examples/openmm-integration.py @@ -0,0 +1,62 @@ +# This script shows how to use a TorchMD-Net model as a force field in OpenMM +# We will run some simulation steps with OpenMM on chignolin using a pretrained model. + +try: + import openmm + import openmmtorch +except ImportError: + raise ImportError("Please install OpenMM and OpenMM-Torch (you can use conda install -c conda-forge openmm openmm-torch)") + +import sys +import torch +from openmm.app import PDBFile, StateDataReporter, Simulation +from openmm import Platform, System +from openmm import LangevinMiddleIntegrator +from openmm.unit import * +from torchmdnet.models.model import load_model + + +# This is a wrapper that links an OpenMM Force with a TorchMD-Net model +class Wrapper(torch.nn.Module): + + def __init__(self, embeddings, model): + super(Wrapper, self).__init__() + self.embeddings = embeddings + # Load a model checkpoint from a previous training run. + # You can generate the checkpoint using the examples in this folder, for instance: + # torchmd-train --conf TensorNet-ANI1X.yaml + + # OpenMM will compute the forces by backpropagating the energy, + # so we can load the model with derivative=False + # In this particular example I find that the maximum number of neighbors required is around 40 + self.model = load_model(model, derivative=False, max_num_neighbors=40) + + def forward(self, positions): + # OpenMM works with nanometer positions and kilojoule per mole energies + # Depending on the model, you might need to convert the units + positions = positions.to(torch.float32) * 10.0 # nm -> A + energy = self.model(z=self.embeddings, pos=positions)[0] + return energy * 96.4916 # eV -> kJ/mol + + +pdb = PDBFile("../benchmarks/systems/chignolin.pdb") + +# Typically models are trained using atomic numbers as embeddings +z = [i.element.atomic_number for i in pdb.topology.atoms()] +z = torch.tensor(z, dtype=torch.long) + +model = torch.jit.script(Wrapper(z, "model.ckpt")) +# Create a TorchForce object from the model +torch_force = openmmtorch.TorchForce(model) + +system = System() +# Create an OpenMM system and add the TorchForce +for i in range(pdb.topology.getNumAtoms()): + system.addParticle(1.0) +system.addForce(torch_force) +integrator = LangevinMiddleIntegrator(298.15*kelvin, 1/picosecond, 2*femtosecond) +platform = Platform.getPlatformByName('CPU') +simulation = Simulation(pdb.topology, system, integrator, platform) +simulation.context.setPositions(pdb.positions) +simulation.reporters.append(StateDataReporter(sys.stdout, 1, step=True, potentialEnergy=True, temperature=True)) +simulation.step(10)